llm_helper.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. """
  2. 通用 LLM 调用 + JSON 校验 + 重试 helper
  3. 所有 Phase 2 workflow 脚本复用此模块,避免重复代码。
  4. 新增:支持 schema_name 参数,自动将 schema 传给 LLM 作为结构化输出约束。
  5. """
  6. import json
  7. import re
  8. from typing import Any, Callable, Dict, Optional, Tuple
  9. async def call_llm_with_retry(
  10. llm_call,
  11. messages: list,
  12. model: str,
  13. temperature: float = 0.1,
  14. max_tokens: int = 4000,
  15. max_retries: int = 3,
  16. validate_fn: Optional[Callable[[dict], Optional[str]]] = None,
  17. task_name: str = "",
  18. schema_name: Optional[str] = None,
  19. ) -> Tuple[Optional[Dict[str, Any]], float]:
  20. """
  21. 调用 LLM 并自动校验 JSON 输出,失败时重试。
  22. Args:
  23. llm_call: LLM 调用函数
  24. messages: 初始消息列表
  25. model: 模型名称
  26. temperature: 温度
  27. max_tokens: 最大 token 数
  28. max_retries: 最大重试次数
  29. validate_fn: 可选的 schema 校验函数,接收 dict 返回 error string 或 None
  30. task_name: 任务名称(用于日志)
  31. schema_name: 可选的 schema 名称,如果提供则自动加载 schema 并传给 LLM
  32. Returns:
  33. (parsed_data, total_cost) — parsed_data 为 None 表示全部重试失败
  34. """
  35. total_cost = 0.0
  36. last_error = None
  37. # 如果提供了 schema_name,加载 schema 并自动设置 validate_fn
  38. response_schema = None
  39. if schema_name:
  40. try:
  41. from examples.process_pipeline.script.schema_manager import get_schema_manager, validate_with_schema
  42. manager = get_schema_manager()
  43. response_schema = manager.get_stripped_schema(schema_name)
  44. if response_schema and not validate_fn:
  45. # 自动设置校验函数
  46. validate_fn = lambda data: validate_with_schema(data, schema_name)
  47. except Exception as e:
  48. if task_name:
  49. print(f" [{task_name}] Warning: Failed to load schema '{schema_name}': {e}")
  50. for attempt in range(max_retries):
  51. current_messages = list(messages)
  52. # 如果是重试,把上次的错误信息附加到消息中
  53. if attempt > 0 and last_error:
  54. if "JSON 解析失败" in last_error:
  55. fix_hint = (
  56. f"你上次的输出存在 JSON 格式错误:{last_error}\n\n"
  57. f"常见原因:字符串值中包含了未转义的英文双引号。\n"
  58. f"修复方法:所有字符串值中的英文双引号必须转义为 \\\",或改用中文引号「」。\n\n"
  59. f"请重新输出完整且格式正确的 JSON,不要包含任何其他内容。"
  60. )
  61. else:
  62. fix_hint = (
  63. f"你上次的输出未通过校验。错误:{last_error}\n\n"
  64. f"请修正后重新输出完整的 JSON,不要包含其他内容。"
  65. )
  66. current_messages.append({"role": "user", "content": fix_hint})
  67. if task_name:
  68. print(f" [{task_name}] Retry {attempt}/{max_retries-1}: {last_error[:80]}...")
  69. try:
  70. # 构建 LLM 调用参数
  71. call_kwargs = {
  72. "messages": current_messages,
  73. "model": model,
  74. "temperature": temperature,
  75. "max_tokens": max_tokens,
  76. }
  77. # 如果有 schema,传给 LLM(支持 response_schema 参数的 LLM 会使用)
  78. if response_schema:
  79. call_kwargs["response_schema"] = response_schema
  80. response = await llm_call(**call_kwargs)
  81. # 计算成本:优先用 provider 自带的准确 cost(qwen / openrouter 都按各自单价算过),
  82. # 没有才回退到粗略估算(Claude 单价 $3/$15 per M tokens)——避免按 Claude 单价高估 qwen。
  83. provider_cost = response.get("cost")
  84. if isinstance(provider_cost, (int, float)) and provider_cost > 0:
  85. total_cost += provider_cost
  86. else:
  87. usage = response.get("usage", {})
  88. if hasattr(usage, "__dict__"):
  89. input_tokens = getattr(usage, "input_tokens", 0) or getattr(usage, "prompt_tokens", 0)
  90. output_tokens = getattr(usage, "output_tokens", 0) or getattr(usage, "completion_tokens", 0)
  91. else:
  92. input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
  93. output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
  94. total_cost += (input_tokens / 1e6 * 3.0) + (output_tokens / 1e6 * 15.0)
  95. # 提取内容
  96. content = response.get("content", "")
  97. if isinstance(content, list):
  98. first = content[0] if content else ""
  99. content = first.get("text", "") if isinstance(first, dict) else str(first)
  100. elif not isinstance(content, str):
  101. content = str(content)
  102. # 尝试解析 JSON
  103. json_match = re.search(r"\{[\s\S]*\}", content)
  104. if not json_match:
  105. last_error = "LLM 输出中未找到有效的 JSON 对象"
  106. continue
  107. raw_json = json_match.group()
  108. try:
  109. parsed_data = json.loads(raw_json)
  110. except json.JSONDecodeError as e:
  111. # 尝试自动修复 JSON 语法错误
  112. try:
  113. from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse
  114. success, parsed_data, fix_desc = try_fix_and_parse(raw_json)
  115. if not success:
  116. last_error = f"JSON 解析失败且自动修复无效: {e}"
  117. print(f" [DEBUG] fix failed, raw_json:\n{raw_json}", flush=True)
  118. continue
  119. if task_name:
  120. print(f" [{task_name}] Auto-fixed JSON: {fix_desc}", flush=True)
  121. except ImportError:
  122. last_error = f"JSON 解析失败: {e}"
  123. continue
  124. # Schema 校验
  125. if validate_fn:
  126. schema_err = validate_fn(parsed_data)
  127. if schema_err:
  128. last_error = f"Schema 校验失败: {schema_err}"
  129. # 完整 dump LLM 输出(不截断)便于定位失败位置
  130. if task_name:
  131. print(f" [{task_name}] === SCHEMA FAIL on attempt {attempt + 1} ===", flush=True)
  132. print(f" [{task_name}] error: {schema_err}", flush=True)
  133. print(f" [{task_name}] full LLM output ({len(content)} chars):", flush=True)
  134. print(content, flush=True)
  135. print(f" [{task_name}] === end LLM output ===", flush=True)
  136. continue
  137. # 全部通过
  138. return parsed_data, total_cost
  139. except Exception as e:
  140. last_error = f"LLM 调用异常: {type(e).__name__}: {e}"
  141. if task_name:
  142. print(f" [{task_name}] Error: {last_error}")
  143. # 全部重试失败
  144. if task_name:
  145. print(f" [{task_name}] All {max_retries} attempts failed. Last error: {last_error}")
  146. return None, total_cost