qwen.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. """
  2. Qwen LLM provider using OpenAI SDK.
  3. """
  4. import os
  5. import logging
  6. from typing import Any, Callable, Dict, List, Optional
  7. from openai import AsyncOpenAI
# Adjust these relative imports to match your project's package structure.
  9. from .usage import TokenUsage
  10. from .pricing import PricingCalculator
  11. logger = logging.getLogger(__name__)
  12. # 2026 推荐:如果 qwen3.5-plus 报 404,请先用 qwen-plus 测试
  13. # 阿里有时要求兼容模式下的 ID 必须是特定的字符串
  14. DEFAULT_QWEN_MODEL = "qwen-plus"
  15. def create_qwen_llm_call(
  16. model: str = DEFAULT_QWEN_MODEL,
  17. base_url: Optional[str] = None,
  18. api_key: Optional[str] = None,
  19. ) -> Callable:
  20. """
  21. Create a Qwen LLM call function using the OpenAI SDK.
  22. """
  23. # 获取配置
  24. # 注意:使用 OpenAI SDK 时,base_url 必须包含到 /v1
  25. api_key = api_key or os.getenv("QWEN_API_KEY")
  26. base_url = base_url or os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  27. if not api_key:
  28. raise ValueError("QWEN_API_KEY is required")
  29. # 初始化 OpenAI 异步客户端
  30. # SDK 会自动处理 /chat/completions 的拼接
  31. client = AsyncOpenAI(
  32. api_key=api_key,
  33. base_url=base_url
  34. )
  35. pricing_calc = PricingCalculator()
  36. async def llm_call(
  37. messages: List[Dict[str, Any]],
  38. model: str = model,
  39. tools: Optional[List[Dict]] = None,
  40. temperature: float = 0.7,
  41. max_tokens: int = 4096,
  42. **kwargs
  43. ) -> Dict[str, Any]:
  44. try:
  45. response = await client.chat.completions.create(
  46. model=model,
  47. messages=messages,
  48. tools=tools,
  49. temperature=temperature,
  50. max_tokens=max_tokens,
  51. **kwargs
  52. )
  53. # 获取内容
  54. content = response.choices[0].message.content or ""
  55. # --- 关键修正位置 ---
  56. # 将 Pydantic 对象转换为原始 Dict 列表,这样 runner.py 的 .get() 才不会报错
  57. tool_calls = None
  58. if response.choices[0].message.tool_calls:
  59. tool_calls = [
  60. tc.model_dump() for tc in response.choices[0].message.tool_calls
  61. ]
  62. # ------------------
  63. usage = TokenUsage(
  64. input_tokens=response.usage.prompt_tokens,
  65. output_tokens=response.usage.completion_tokens,
  66. )
  67. cost = pricing_calc.calculate_cost(model=model, usage=usage)
  68. return {
  69. "content": content,
  70. "tool_calls": tool_calls, # 现在这里是 List[Dict] 了
  71. "prompt_tokens": usage.input_tokens,
  72. "completion_tokens": usage.output_tokens,
  73. "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0),
  74. "finish_reason": response.choices[0].finish_reason,
  75. "cost": cost,
  76. "usage": usage,
  77. }
  78. except Exception as e:
  79. logger.error(f"Qwen SDK Call Failed: {str(e)}")
  80. raise
  81. return llm_call