- """
- Qwen LLM provider using OpenAI SDK.
- """
- import os
- import logging
- from typing import Any, Callable, Dict, List, Optional
- from openai import AsyncOpenAI
- # 这里的导入根据你的项目结构调整
- from .usage import TokenUsage
- from .pricing import PricingCalculator
- logger = logging.getLogger(__name__)
- # 2026 推荐:如果 qwen3.5-plus 报 404,请先用 qwen-plus 测试
- # 阿里有时要求兼容模式下的 ID 必须是特定的字符串
- DEFAULT_QWEN_MODEL = "qwen-plus"
- def create_qwen_llm_call(
- model: str = DEFAULT_QWEN_MODEL,
- base_url: Optional[str] = None,
- api_key: Optional[str] = None,
- ) -> Callable:
- """
- Create a Qwen LLM call function using the OpenAI SDK.
- """
- # 获取配置
- # 注意:使用 OpenAI SDK 时,base_url 必须包含到 /v1
- api_key = api_key or os.getenv("QWEN_API_KEY")
- base_url = base_url or os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
- if not api_key:
- raise ValueError("QWEN_API_KEY is required")
- # 初始化 OpenAI 异步客户端
- # SDK 会自动处理 /chat/completions 的拼接
- client = AsyncOpenAI(
- api_key=api_key,
- base_url=base_url
- )
- pricing_calc = PricingCalculator()
- async def llm_call(
- messages: List[Dict[str, Any]],
- model: str = model,
- tools: Optional[List[Dict]] = None,
- temperature: float = 0.2,
- max_tokens: int = 16384,
- **kwargs
- ) -> Dict[str, Any]:
-
- try:
- response = await client.chat.completions.create(
- model=model,
- messages=messages,
- tools=tools,
- temperature=temperature,
- max_tokens=max_tokens,
- **kwargs
- )
- # 获取内容
- content = response.choices[0].message.content or ""
-
- # --- 关键修正位置 ---
- # 将 Pydantic 对象转换为原始 Dict 列表,这样 runner.py 的 .get() 才不会报错
- tool_calls = None
- if response.choices[0].message.tool_calls:
- tool_calls = [
- tc.model_dump() for tc in response.choices[0].message.tool_calls
- ]
- # ------------------
- usage = TokenUsage(
- input_tokens=response.usage.prompt_tokens,
- output_tokens=response.usage.completion_tokens,
- )
- cost = pricing_calc.calculate_cost(model=model, usage=usage)
- return {
- "content": content,
- "tool_calls": tool_calls, # 现在这里是 List[Dict] 了
- "prompt_tokens": usage.input_tokens,
- "completion_tokens": usage.output_tokens,
- "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0),
- "finish_reason": response.choices[0].finish_reason,
- "cost": cost,
- "usage": usage,
- }
- except Exception as e:
- logger.error(f"Qwen SDK Call Failed: {str(e)}")
- raise
- return llm_call
- async def qwen_llm_call(
- messages: List[Dict[str, Any]],
- model: str = DEFAULT_QWEN_MODEL,
- tools: Optional[List[Dict]] = None,
- **kwargs
- ) -> Dict[str, Any]:
- """
- Qwen LLM 调用函数(独立函数,可直接使用)
- Args:
- messages: OpenAI 格式消息列表
- model: 模型名称(如 "qwen-plus", "qwen-turbo", "qwen-max")
- tools: OpenAI 格式工具定义
- **kwargs: 其他参数(temperature, max_tokens 等)
- Returns:
- {
- "content": str,
- "tool_calls": List[Dict] | None,
- "prompt_tokens": int,
- "completion_tokens": int,
- "reasoning_tokens": int,
- "cache_creation_tokens": int,
- "cache_read_tokens": int,
- "finish_reason": str,
- "cost": float,
- "usage": TokenUsage,
- }
- """
- import asyncio
- from .pricing import calculate_cost
- api_key = os.getenv("QWEN_API_KEY")
- if not api_key:
- raise ValueError("QWEN_API_KEY environment variable not set")
- base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
- client = AsyncOpenAI(api_key=api_key, base_url=base_url)
- # 构建请求参数
- create_kwargs = {
- "model": model,
- "messages": messages,
- }
- if tools:
- create_kwargs["tools"] = tools
- if "temperature" in kwargs:
- create_kwargs["temperature"] = kwargs["temperature"]
- if "max_tokens" in kwargs:
- create_kwargs["max_tokens"] = kwargs["max_tokens"]
- # 带重试的调用
- max_retries = 3
- last_exception = None
- for attempt in range(max_retries):
- try:
- response = await client.chat.completions.create(**create_kwargs)
- break
- except (ConnectionError, TimeoutError, OSError) as e:
- last_exception = e
- if attempt < max_retries - 1:
- wait = 2 ** attempt * 2
- logger.warning(
- "[Qwen] %s (attempt %d/%d), retrying in %ds",
- type(e).__name__, attempt + 1, max_retries, wait,
- )
- await asyncio.sleep(wait)
- continue
- logger.error("[Qwen] Request failed after %d attempts: %s", max_retries, e)
- raise
- except Exception as e:
- logger.error("[Qwen] Request failed: %s", e)
- raise
- else:
- raise last_exception # type: ignore[misc]
- # 解析响应
- choice = response.choices[0]
- content = choice.message.content or ""
- finish_reason = choice.finish_reason
- # tool_calls: Pydantic 对象转 dict
- tool_calls = None
- if choice.message.tool_calls:
- tool_calls = [tc.model_dump() for tc in choice.message.tool_calls]
- # 解析 usage
- usage = TokenUsage(
- input_tokens=response.usage.prompt_tokens,
- output_tokens=response.usage.completion_tokens,
- )
- cost = calculate_cost(model, usage)
- return {
- "content": content,
- "tool_calls": tool_calls,
- "prompt_tokens": usage.input_tokens,
- "completion_tokens": usage.output_tokens,
- "reasoning_tokens": getattr(response.usage, "reasoning_tokens", 0) or 0,
- "cache_creation_tokens": 0,
- "cache_read_tokens": 0,
- "finish_reason": finish_reason,
- "cost": cost,
- "usage": usage,
- }