# qwen.py
"""
Qwen LLM provider using OpenAI SDK.
"""
import os
import logging
from typing import Any, Callable, Dict, List, Optional
from openai import AsyncOpenAI
# NOTE: adjust these relative imports to match your project structure.
from .usage import TokenUsage
from .pricing import PricingCalculator

logger = logging.getLogger(__name__)

# If "qwen3.5-plus" returns a 404, test with "qwen-plus" first:
# DashScope's compatible mode sometimes requires specific model ID strings.
DEFAULT_QWEN_MODEL = "qwen-plus"
  15. def create_qwen_llm_call(
  16. model: str = DEFAULT_QWEN_MODEL,
  17. base_url: Optional[str] = None,
  18. api_key: Optional[str] = None,
  19. ) -> Callable:
  20. """
  21. Create a Qwen LLM call function using the OpenAI SDK.
  22. """
  23. # 获取配置
  24. # 注意:使用 OpenAI SDK 时,base_url 必须包含到 /v1
  25. api_key = api_key or os.getenv("QWEN_API_KEY")
  26. base_url = base_url or os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  27. if not api_key:
  28. raise ValueError("QWEN_API_KEY is required")
  29. # 初始化 OpenAI 异步客户端
  30. # SDK 会自动处理 /chat/completions 的拼接
  31. client = AsyncOpenAI(
  32. api_key=api_key,
  33. base_url=base_url
  34. )
  35. pricing_calc = PricingCalculator()
  36. async def llm_call(
  37. messages: List[Dict[str, Any]],
  38. model: str = model,
  39. tools: Optional[List[Dict]] = None,
  40. temperature: float = 0.2,
  41. max_tokens: int = 16384,
  42. **kwargs
  43. ) -> Dict[str, Any]:
  44. try:
  45. response = await client.chat.completions.create(
  46. model=model,
  47. messages=messages,
  48. tools=tools,
  49. temperature=temperature,
  50. max_tokens=max_tokens,
  51. **kwargs
  52. )
  53. # 获取内容
  54. content = response.choices[0].message.content or ""
  55. # --- 关键修正位置 ---
  56. # 将 Pydantic 对象转换为原始 Dict 列表,这样 runner.py 的 .get() 才不会报错
  57. tool_calls = None
  58. if response.choices[0].message.tool_calls:
  59. tool_calls = [
  60. tc.model_dump() for tc in response.choices[0].message.tool_calls
  61. ]
  62. # ------------------
  63. usage = TokenUsage(
  64. input_tokens=response.usage.prompt_tokens,
  65. output_tokens=response.usage.completion_tokens,
  66. )
  67. cost = pricing_calc.calculate_cost(model=model, usage=usage)
  68. return {
  69. "content": content,
  70. "tool_calls": tool_calls, # 现在这里是 List[Dict] 了
  71. "prompt_tokens": usage.input_tokens,
  72. "completion_tokens": usage.output_tokens,
  73. "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0),
  74. "finish_reason": response.choices[0].finish_reason,
  75. "cost": cost,
  76. "usage": usage,
  77. }
  78. except Exception as e:
  79. logger.error(f"Qwen SDK Call Failed: {str(e)}")
  80. raise
  81. return llm_call
  82. async def qwen_llm_call(
  83. messages: List[Dict[str, Any]],
  84. model: str = DEFAULT_QWEN_MODEL,
  85. tools: Optional[List[Dict]] = None,
  86. **kwargs
  87. ) -> Dict[str, Any]:
  88. """
  89. Qwen LLM 调用函数(独立函数,可直接使用)
  90. Args:
  91. messages: OpenAI 格式消息列表
  92. model: 模型名称(如 "qwen-plus", "qwen-turbo", "qwen-max")
  93. tools: OpenAI 格式工具定义
  94. **kwargs: 其他参数(temperature, max_tokens 等)
  95. Returns:
  96. {
  97. "content": str,
  98. "tool_calls": List[Dict] | None,
  99. "prompt_tokens": int,
  100. "completion_tokens": int,
  101. "reasoning_tokens": int,
  102. "cache_creation_tokens": int,
  103. "cache_read_tokens": int,
  104. "finish_reason": str,
  105. "cost": float,
  106. "usage": TokenUsage,
  107. }
  108. """
  109. import asyncio
  110. from .pricing import calculate_cost
  111. api_key = os.getenv("QWEN_API_KEY")
  112. if not api_key:
  113. raise ValueError("QWEN_API_KEY environment variable not set")
  114. base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  115. client = AsyncOpenAI(api_key=api_key, base_url=base_url)
  116. # 构建请求参数
  117. create_kwargs = {
  118. "model": model,
  119. "messages": messages,
  120. }
  121. if tools:
  122. create_kwargs["tools"] = tools
  123. if "temperature" in kwargs:
  124. create_kwargs["temperature"] = kwargs["temperature"]
  125. if "max_tokens" in kwargs:
  126. create_kwargs["max_tokens"] = kwargs["max_tokens"]
  127. # 带重试的调用
  128. max_retries = 3
  129. last_exception = None
  130. for attempt in range(max_retries):
  131. try:
  132. response = await client.chat.completions.create(**create_kwargs)
  133. break
  134. except (ConnectionError, TimeoutError, OSError) as e:
  135. last_exception = e
  136. if attempt < max_retries - 1:
  137. wait = 2 ** attempt * 2
  138. logger.warning(
  139. "[Qwen] %s (attempt %d/%d), retrying in %ds",
  140. type(e).__name__, attempt + 1, max_retries, wait,
  141. )
  142. await asyncio.sleep(wait)
  143. continue
  144. logger.error("[Qwen] Request failed after %d attempts: %s", max_retries, e)
  145. raise
  146. except Exception as e:
  147. logger.error("[Qwen] Request failed: %s", e)
  148. raise
  149. else:
  150. raise last_exception # type: ignore[misc]
  151. # 解析响应
  152. choice = response.choices[0]
  153. content = choice.message.content or ""
  154. finish_reason = choice.finish_reason
  155. # tool_calls: Pydantic 对象转 dict
  156. tool_calls = None
  157. if choice.message.tool_calls:
  158. tool_calls = [tc.model_dump() for tc in choice.message.tool_calls]
  159. # 解析 usage
  160. usage = TokenUsage(
  161. input_tokens=response.usage.prompt_tokens,
  162. output_tokens=response.usage.completion_tokens,
  163. )
  164. cost = calculate_cost(model, usage)
  165. return {
  166. "content": content,
  167. "tool_calls": tool_calls,
  168. "prompt_tokens": usage.input_tokens,
  169. "completion_tokens": usage.output_tokens,
  170. "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0) or 0,
  171. "cache_creation_tokens": 0,
  172. "cache_read_tokens": 0,
  173. "finish_reason": finish_reason,
  174. "cost": cost,
  175. "usage": usage,
  176. }