llm_helper.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
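"""
Shared helper for calling an LLM with JSON validation and retries.
All Phase 2 workflow scripts reuse this module to avoid duplicated code.
New: supports a schema_name parameter, which automatically passes the schema
to the LLM as a structured-output constraint.
"""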
  6. import json
  7. import re
  8. from typing import Any, Callable, Dict, Optional, Tuple
  9. async def call_llm_with_retry(
  10. llm_call,
  11. messages: list,
  12. model: str,
  13. temperature: float = 0.1,
  14. max_tokens: int = 4000,
  15. max_retries: int = 3,
  16. validate_fn: Optional[Callable[[dict], Optional[str]]] = None,
  17. task_name: str = "",
  18. schema_name: Optional[str] = None,
  19. ) -> Tuple[Optional[Dict[str, Any]], float]:
  20. """
  21. 调用 LLM 并自动校验 JSON 输出,失败时重试。
  22. Args:
  23. llm_call: LLM 调用函数
  24. messages: 初始消息列表
  25. model: 模型名称
  26. temperature: 温度
  27. max_tokens: 最大 token 数
  28. max_retries: 最大重试次数
  29. validate_fn: 可选的 schema 校验函数,接收 dict 返回 error string 或 None
  30. task_name: 任务名称(用于日志)
  31. schema_name: 可选的 schema 名称,如果提供则自动加载 schema 并传给 LLM
  32. Returns:
  33. (parsed_data, total_cost) — parsed_data 为 None 表示全部重试失败
  34. """
  35. total_cost = 0.0
  36. last_error = None
  37. # 如果提供了 schema_name,加载 schema 并自动设置 validate_fn
  38. response_schema = None
  39. if schema_name:
  40. try:
  41. from examples.process_pipeline.script.schema_manager import get_schema_manager, validate_with_schema
  42. manager = get_schema_manager()
  43. response_schema = manager.get_stripped_schema(schema_name)
  44. if response_schema and not validate_fn:
  45. # 自动设置校验函数
  46. validate_fn = lambda data: validate_with_schema(data, schema_name)
  47. except Exception as e:
  48. if task_name:
  49. print(f" [{task_name}] Warning: Failed to load schema '{schema_name}': {e}")
  50. for attempt in range(max_retries):
  51. current_messages = list(messages)
  52. # 如果是重试,把上次的错误信息附加到消息中
  53. if attempt > 0 and last_error:
  54. if "JSON 解析失败" in last_error:
  55. fix_hint = (
  56. f"你上次的输出存在 JSON 格式错误:{last_error}\n\n"
  57. f"常见原因:字符串值中包含了未转义的英文双引号。\n"
  58. f"修复方法:所有字符串值中的英文双引号必须转义为 \\\",或改用中文引号「」。\n\n"
  59. f"请重新输出完整且格式正确的 JSON,不要包含任何其他内容。"
  60. )
  61. else:
  62. fix_hint = (
  63. f"你上次的输出未通过校验。错误:{last_error}\n\n"
  64. f"请修正后重新输出完整的 JSON,不要包含其他内容。"
  65. )
  66. current_messages.append({"role": "user", "content": fix_hint})
  67. if task_name:
  68. print(f" [{task_name}] Retry {attempt}/{max_retries-1}: {last_error[:80]}...")
  69. try:
  70. # 构建 LLM 调用参数
  71. call_kwargs = {
  72. "messages": current_messages,
  73. "model": model,
  74. "temperature": temperature,
  75. "max_tokens": max_tokens,
  76. }
  77. # 如果有 schema,传给 LLM(支持 response_schema 参数的 LLM 会使用)
  78. if response_schema:
  79. call_kwargs["response_schema"] = response_schema
  80. response = await llm_call(**call_kwargs)
  81. # 计算成本
  82. usage = response.get("usage", {})
  83. if hasattr(usage, "__dict__"):
  84. input_tokens = getattr(usage, "input_tokens", 0) or getattr(usage, "prompt_tokens", 0)
  85. output_tokens = getattr(usage, "output_tokens", 0) or getattr(usage, "completion_tokens", 0)
  86. else:
  87. input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
  88. output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
  89. total_cost += (input_tokens / 1e6 * 3.0) + (output_tokens / 1e6 * 15.0)
  90. # 提取内容
  91. content = response.get("content", "")
  92. if isinstance(content, list):
  93. first = content[0] if content else ""
  94. content = first.get("text", "") if isinstance(first, dict) else str(first)
  95. elif not isinstance(content, str):
  96. content = str(content)
  97. # 尝试解析 JSON
  98. json_match = re.search(r"\{[\s\S]*\}", content)
  99. if not json_match:
  100. last_error = "LLM 输出中未找到有效的 JSON 对象"
  101. continue
  102. raw_json = json_match.group()
  103. try:
  104. parsed_data = json.loads(raw_json)
  105. except json.JSONDecodeError as e:
  106. # 尝试自动修复 JSON 语法错误
  107. try:
  108. from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse
  109. success, parsed_data, fix_desc = try_fix_and_parse(raw_json)
  110. if not success:
  111. last_error = f"JSON 解析失败且自动修复无效: {e}"
  112. print(f" [DEBUG] fix failed, raw_json:\n{raw_json}", flush=True)
  113. continue
  114. if task_name:
  115. print(f" [{task_name}] Auto-fixed JSON: {fix_desc}", flush=True)
  116. except ImportError:
  117. last_error = f"JSON 解析失败: {e}"
  118. continue
  119. # Schema 校验
  120. if validate_fn:
  121. schema_err = validate_fn(parsed_data)
  122. if schema_err:
  123. last_error = f"Schema 校验失败: {schema_err}"
  124. continue
  125. # 全部通过
  126. return parsed_data, total_cost
  127. except Exception as e:
  128. last_error = f"LLM 调用异常: {type(e).__name__}: {e}"
  129. if task_name:
  130. print(f" [{task_name}] Error: {last_error}")
  131. # 全部重试失败
  132. if task_name:
  133. print(f" [{task_name}] All {max_retries} attempts failed. Last error: {last_error}")
  134. return None, total_cost