|
|
@@ -0,0 +1,98 @@
|
|
|
+"""
|
|
|
+Qwen LLM provider using OpenAI SDK.
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import logging
|
|
|
+from typing import Any, Callable, Dict, List, Optional
|
|
|
+from openai import AsyncOpenAI
|
|
|
+
|
|
|
+# Adjust these relative imports to match your project layout.
|
|
|
+from .usage import TokenUsage
|
|
|
+from .pricing import PricingCalculator
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+# 2026 note: if qwen3.5-plus returns 404, test with qwen-plus first.
+# The compatible-mode endpoint sometimes requires specific model ID strings.
|
|
|
+DEFAULT_QWEN_MODEL = "qwen-plus"
|
|
|
+
|
|
|
def create_qwen_llm_call(
    model: str = DEFAULT_QWEN_MODEL,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Callable:
    """Create an async Qwen LLM call function backed by the OpenAI SDK.

    Args:
        model: Default model ID, overridable per call.
        base_url: OpenAI-compatible endpoint; must include the ``/v1`` suffix.
            Falls back to ``QWEN_BASE_URL``, then DashScope's compatible-mode URL.
        api_key: API key; falls back to the ``QWEN_API_KEY`` environment variable.

    Returns:
        An async callable ``llm_call(messages, model=..., tools=None,
        temperature=0.7, max_tokens=4096, **kwargs)`` that returns a dict with
        ``content``, ``tool_calls``, token counts, ``finish_reason``, ``cost``,
        and ``usage``.

    Raises:
        ValueError: If no API key is provided or found in the environment.
    """
    api_key = api_key or os.getenv("QWEN_API_KEY")
    base_url = base_url or os.getenv(
        "QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    if not api_key:
        raise ValueError("QWEN_API_KEY is required")

    # The SDK appends /chat/completions itself, so base_url stops at /v1.
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    pricing_calc = PricingCalculator()

    async def llm_call(
        messages: List[Dict[str, Any]],
        model: str = model,
        tools: Optional[List[Dict]] = None,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        **kwargs,
    ) -> Dict[str, Any]:
        """Send one chat completion request and normalize the response to a dict."""
        try:
            # Build the payload so `tools` is omitted entirely when not
            # supplied: passing tools=None serializes `"tools": null`, which
            # some OpenAI-compatible backends (DashScope included) may reject.
            request: Dict[str, Any] = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                **kwargs,
            }
            if tools is not None:
                request["tools"] = tools

            response = await client.chat.completions.create(**request)

            choice = response.choices[0]
            content = choice.message.content or ""

            # Convert Pydantic tool-call objects into plain dicts so that
            # downstream consumers (e.g. runner.py) can use .get() safely.
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [tc.model_dump() for tc in choice.message.tool_calls]

            usage = TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            )

            cost = pricing_calc.calculate_cost(model=model, usage=usage)

            return {
                "content": content,
                "tool_calls": tool_calls,  # List[Dict] or None
                "prompt_tokens": usage.input_tokens,
                "completion_tokens": usage.output_tokens,
                # Not every response reports reasoning tokens; default to 0.
                "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0),
                "finish_reason": choice.finish_reason,
                "cost": cost,
                "usage": usage,
            }

        except Exception as e:
            # logger.exception keeps the traceback; lazy %-formatting avoids
            # building the message unless the record is actually emitted.
            logger.exception("Qwen SDK Call Failed: %s", e)
            raise

    return llm_call
|