| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- """
- 通用 LLM 调用 + JSON 校验 + 重试 helper
- 所有 Phase 2 workflow 脚本复用此模块,避免重复代码。
- 新增:支持 schema_name 参数,自动将 schema 传给 LLM 作为结构化输出约束。
- """
- import json
- import re
- from typing import Any, Callable, Dict, Optional, Tuple
- async def call_llm_with_retry(
- llm_call,
- messages: list,
- model: str,
- temperature: float = 0.1,
- max_tokens: int = 4000,
- max_retries: int = 3,
- validate_fn: Optional[Callable[[dict], Optional[str]]] = None,
- task_name: str = "",
- schema_name: Optional[str] = None,
- ) -> Tuple[Optional[Dict[str, Any]], float]:
- """
- 调用 LLM 并自动校验 JSON 输出,失败时重试。
- Args:
- llm_call: LLM 调用函数
- messages: 初始消息列表
- model: 模型名称
- temperature: 温度
- max_tokens: 最大 token 数
- max_retries: 最大重试次数
- validate_fn: 可选的 schema 校验函数,接收 dict 返回 error string 或 None
- task_name: 任务名称(用于日志)
- schema_name: 可选的 schema 名称,如果提供则自动加载 schema 并传给 LLM
- Returns:
- (parsed_data, total_cost) — parsed_data 为 None 表示全部重试失败
- """
- total_cost = 0.0
- last_error = None
- # 如果提供了 schema_name,加载 schema 并自动设置 validate_fn
- response_schema = None
- if schema_name:
- try:
- from examples.process_pipeline.script.schema_manager import get_schema_manager, validate_with_schema
- manager = get_schema_manager()
- response_schema = manager.get_stripped_schema(schema_name)
- if response_schema and not validate_fn:
- # 自动设置校验函数
- validate_fn = lambda data: validate_with_schema(data, schema_name)
- except Exception as e:
- if task_name:
- print(f" [{task_name}] Warning: Failed to load schema '{schema_name}': {e}")
- for attempt in range(max_retries):
- current_messages = list(messages)
- # 如果是重试,把上次的错误信息附加到消息中
- if attempt > 0 and last_error:
- if "JSON 解析失败" in last_error:
- fix_hint = (
- f"你上次的输出存在 JSON 格式错误:{last_error}\n\n"
- f"常见原因:字符串值中包含了未转义的英文双引号。\n"
- f"修复方法:所有字符串值中的英文双引号必须转义为 \\\",或改用中文引号「」。\n\n"
- f"请重新输出完整且格式正确的 JSON,不要包含任何其他内容。"
- )
- else:
- fix_hint = (
- f"你上次的输出未通过校验。错误:{last_error}\n\n"
- f"请修正后重新输出完整的 JSON,不要包含其他内容。"
- )
- current_messages.append({"role": "user", "content": fix_hint})
- if task_name:
- print(f" [{task_name}] Retry {attempt}/{max_retries-1}: {last_error[:80]}...")
- try:
- # 构建 LLM 调用参数
- call_kwargs = {
- "messages": current_messages,
- "model": model,
- "temperature": temperature,
- "max_tokens": max_tokens,
- }
- # 如果有 schema,传给 LLM(支持 response_schema 参数的 LLM 会使用)
- if response_schema:
- call_kwargs["response_schema"] = response_schema
- response = await llm_call(**call_kwargs)
- # 计算成本
- usage = response.get("usage", {})
- if hasattr(usage, "__dict__"):
- input_tokens = getattr(usage, "input_tokens", 0) or getattr(usage, "prompt_tokens", 0)
- output_tokens = getattr(usage, "output_tokens", 0) or getattr(usage, "completion_tokens", 0)
- else:
- input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
- output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
- total_cost += (input_tokens / 1e6 * 3.0) + (output_tokens / 1e6 * 15.0)
- # 提取内容
- content = response.get("content", "")
- if isinstance(content, list):
- first = content[0] if content else ""
- content = first.get("text", "") if isinstance(first, dict) else str(first)
- elif not isinstance(content, str):
- content = str(content)
- # 尝试解析 JSON
- json_match = re.search(r"\{[\s\S]*\}", content)
- if not json_match:
- last_error = "LLM 输出中未找到有效的 JSON 对象"
- continue
- raw_json = json_match.group()
- try:
- parsed_data = json.loads(raw_json)
- except json.JSONDecodeError as e:
- # 尝试自动修复 JSON 语法错误
- try:
- from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse
- success, parsed_data, fix_desc = try_fix_and_parse(raw_json)
- if not success:
- last_error = f"JSON 解析失败且自动修复无效: {e}"
- print(f" [DEBUG] fix failed, raw_json:\n{raw_json}", flush=True)
- continue
- if task_name:
- print(f" [{task_name}] Auto-fixed JSON: {fix_desc}", flush=True)
- except ImportError:
- last_error = f"JSON 解析失败: {e}"
- continue
- # Schema 校验
- if validate_fn:
- schema_err = validate_fn(parsed_data)
- if schema_err:
- last_error = f"Schema 校验失败: {schema_err}"
- continue
- # 全部通过
- return parsed_data, total_cost
- except Exception as e:
- last_error = f"LLM 调用异常: {type(e).__name__}: {e}"
- if task_name:
- print(f" [{task_name}] Error: {last_error}")
- # 全部重试失败
- if task_name:
- print(f" [{task_name}] All {max_retries} attempts failed. Last error: {last_error}")
- return None, total_cost
|