qwen.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """
  2. Qwen LLM provider using OpenAI SDK.
  3. """
  4. import os
  5. import logging
  6. from typing import Any, Callable, Dict, List, Optional
  7. from openai import AsyncOpenAI
# Adjust these imports to match your project structure.
  9. from .usage import TokenUsage
  10. from .pricing import PricingCalculator
  11. logger = logging.getLogger(__name__)
# Note (2026): if "qwen3.5-plus" returns a 404, test with "qwen-plus" first.
# DashScope's compatible mode sometimes requires specific model-ID strings.
DEFAULT_QWEN_MODEL = "qwen-plus"
  15. def create_qwen_llm_call(
  16. model: str = DEFAULT_QWEN_MODEL,
  17. base_url: Optional[str] = None,
  18. api_key: Optional[str] = None,
  19. ) -> Callable:
  20. """
  21. Create a Qwen LLM call function using the OpenAI SDK.
  22. """
  23. # 获取配置
  24. # 注意:使用 OpenAI SDK 时,base_url 必须包含到 /v1
  25. api_key = api_key or os.getenv("QWEN_API_KEY")
  26. base_url = base_url or os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  27. if not api_key:
  28. raise ValueError("QWEN_API_KEY is required")
  29. # 初始化 OpenAI 异步客户端
  30. # SDK 会自动处理 /chat/completions 的拼接
  31. client = AsyncOpenAI(
  32. api_key=api_key,
  33. base_url=base_url
  34. )
  35. pricing_calc = PricingCalculator()
  36. async def llm_call(
  37. messages: List[Dict[str, Any]],
  38. model: str = model,
  39. tools: Optional[List[Dict]] = None,
  40. temperature: float = 0.2,
  41. max_tokens: int = 16384,
  42. **kwargs
  43. ) -> Dict[str, Any]:
  44. try:
  45. response = await client.chat.completions.create(
  46. model=model,
  47. messages=messages,
  48. tools=tools,
  49. temperature=temperature,
  50. max_tokens=max_tokens,
  51. **kwargs
  52. )
  53. # 获取内容
  54. content = response.choices[0].message.content or ""
  55. # 捕获 thinking 模式的推理内容(不影响 tool_calls 解构)
  56. reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) or ""
  57. # --- 关键修正位置 ---
  58. # 将 Pydantic 对象转换为原始 Dict 列表,这样 runner.py 的 .get() 才不会报错
  59. tool_calls = None
  60. if response.choices[0].message.tool_calls:
  61. tool_calls = [
  62. tc.model_dump() for tc in response.choices[0].message.tool_calls
  63. ]
  64. # ------------------
  65. usage = TokenUsage(
  66. input_tokens=response.usage.prompt_tokens,
  67. output_tokens=response.usage.completion_tokens,
  68. )
  69. cost = pricing_calc.calculate_cost(model=model, usage=usage)
  70. return {
  71. "content": content,
  72. "reasoning_content": reasoning_content,
  73. "tool_calls": tool_calls, # 现在这里是 List[Dict] 了
  74. "prompt_tokens": usage.input_tokens,
  75. "completion_tokens": usage.output_tokens,
  76. "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0),
  77. "finish_reason": response.choices[0].finish_reason,
  78. "cost": cost,
  79. "usage": usage,
  80. }
  81. except Exception as e:
  82. logger.error(f"Qwen SDK Call Failed: {str(e)}")
  83. raise
  84. return llm_call
  85. async def qwen_llm_call(
  86. messages: List[Dict[str, Any]],
  87. model: str = DEFAULT_QWEN_MODEL,
  88. tools: Optional[List[Dict]] = None,
  89. **kwargs
  90. ) -> Dict[str, Any]:
  91. """
  92. Qwen LLM 调用函数(独立函数,可直接使用)
  93. Args:
  94. messages: OpenAI 格式消息列表
  95. model: 模型名称(如 "qwen-plus", "qwen-turbo", "qwen-max")
  96. tools: OpenAI 格式工具定义
  97. **kwargs: 其他参数(temperature, max_tokens 等)
  98. Returns:
  99. {
  100. "content": str,
  101. "tool_calls": List[Dict] | None,
  102. "prompt_tokens": int,
  103. "completion_tokens": int,
  104. "reasoning_tokens": int,
  105. "cache_creation_tokens": int,
  106. "cache_read_tokens": int,
  107. "finish_reason": str,
  108. "cost": float,
  109. "usage": TokenUsage,
  110. }
  111. """
  112. import asyncio
  113. from .pricing import calculate_cost
  114. api_key = os.getenv("QWEN_API_KEY")
  115. if not api_key:
  116. raise ValueError("QWEN_API_KEY environment variable not set")
  117. base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  118. client = AsyncOpenAI(api_key=api_key, base_url=base_url)
  119. # 构建请求参数
  120. create_kwargs = {
  121. "model": model,
  122. "messages": messages,
  123. }
  124. if tools:
  125. create_kwargs["tools"] = tools
  126. if "temperature" in kwargs:
  127. create_kwargs["temperature"] = kwargs["temperature"]
  128. if "max_tokens" in kwargs:
  129. create_kwargs["max_tokens"] = kwargs["max_tokens"]
  130. # 带重试的调用
  131. max_retries = 3
  132. last_exception = None
  133. for attempt in range(max_retries):
  134. try:
  135. response = await client.chat.completions.create(**create_kwargs)
  136. break
  137. except (ConnectionError, TimeoutError, OSError) as e:
  138. last_exception = e
  139. if attempt < max_retries - 1:
  140. wait = 2 ** attempt * 2
  141. logger.warning(
  142. "[Qwen] %s (attempt %d/%d), retrying in %ds",
  143. type(e).__name__, attempt + 1, max_retries, wait,
  144. )
  145. await asyncio.sleep(wait)
  146. continue
  147. logger.error("[Qwen] Request failed after %d attempts: %s", max_retries, e)
  148. raise
  149. except Exception as e:
  150. logger.error("[Qwen] Request failed: %s", e)
  151. raise
  152. else:
  153. raise last_exception # type: ignore[misc]
  154. # 解析响应
  155. choice = response.choices[0]
  156. content = choice.message.content or ""
  157. finish_reason = choice.finish_reason
  158. # tool_calls: Pydantic 对象转 dict
  159. tool_calls = None
  160. if choice.message.tool_calls:
  161. tool_calls = [tc.model_dump() for tc in choice.message.tool_calls]
  162. # 解析 usage
  163. usage = TokenUsage(
  164. input_tokens=response.usage.prompt_tokens,
  165. output_tokens=response.usage.completion_tokens,
  166. )
  167. cost = calculate_cost(model, usage)
  168. return {
  169. "content": content,
  170. "tool_calls": tool_calls,
  171. "prompt_tokens": usage.input_tokens,
  172. "completion_tokens": usage.output_tokens,
  173. "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0) or 0,
  174. "cache_creation_tokens": 0,
  175. "cache_read_tokens": 0,
  176. "finish_reason": finish_reason,
  177. "cost": cost,
  178. "usage": usage,
  179. }