
fix: token usage statistics and token pricing

tanjingyu, 3 weeks ago
parent
commit
602252c86e
5 changed files with 165 additions and 20 deletions
  1. agent/llm/__init__.py (+21 −1)
  2. agent/llm/gemini.py (+14 −6)
  3. agent/llm/openrouter.py (+72 −9)
  4. agent/trace/models.py (+46 −2)
  5. agent/trace/store.py (+12 −2)

+ 21 - 1
agent/llm/__init__.py

@@ -6,5 +6,25 @@ LLM Providers
 
 from .gemini import create_gemini_llm_call
 from .openrouter import create_openrouter_llm_call
+from .usage import TokenUsage, TokenUsageAccumulator, create_usage_from_response
+from .pricing import (
+    ModelPricing,
+    PricingCalculator,
+    get_pricing_calculator,
+    calculate_cost,
+)
 
-__all__ = ["create_gemini_llm_call", "create_openrouter_llm_call"]
+__all__ = [
+    # Providers
+    "create_gemini_llm_call",
+    "create_openrouter_llm_call",
+    # Usage
+    "TokenUsage",
+    "TokenUsageAccumulator",
+    "create_usage_from_response",
+    # Pricing
+    "ModelPricing",
+    "PricingCalculator",
+    "get_pricing_calculator",
+    "calculate_cost",
+]
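
The new agent/llm/usage.py module itself is not shown in this commit. As a reading aid, here is a minimal sketch of what TokenUsage presumably looks like, inferred from the call sites in the diffs below; the field names mirror how the providers consume them, and the from_gemini mapping assumes Gemini's documented usageMetadata keys:

# Hypothetical sketch of agent/llm/usage.py -- fields and the constructor
# are inferred from how this commit uses them, not copied from the module.
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class TokenUsage:
    input_tokens: int = 0
    output_tokens: int = 0
    reasoning_tokens: int = 0        # o1/o3, DeepSeek R1, Gemini thinking
    cached_content_tokens: int = 0   # Gemini context caching
    cache_creation_tokens: int = 0   # Claude prompt-cache writes
    cache_read_tokens: int = 0       # Claude prompt-cache reads

    @classmethod
    def from_gemini(cls, meta: Dict[str, Any]) -> "TokenUsage":
        # Gemini's usageMetadata reports thinking tokens separately
        # from candidate (output) tokens.
        return cls(
            input_tokens=meta.get("promptTokenCount", 0),
            output_tokens=meta.get("candidatesTokenCount", 0),
            reasoning_tokens=meta.get("thoughtsTokenCount", 0),
            cached_content_tokens=meta.get("cachedContentTokenCount", 0),
        )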

+ 14 - 6
agent/llm/gemini.py

@@ -12,6 +12,9 @@ import sys
 import httpx
 from typing import List, Dict, Any, Optional
 
+from .usage import TokenUsage
+from .pricing import calculate_cost
+
 
 def _dump_llm_request(endpoint: str, payload: Dict[str, Any], model: str):
     """
@@ -430,18 +433,23 @@ def create_gemini_llm_call(
                             }
                         })
 
-        # Extract usage
+        # Extract usage (full version)
         usage_meta = gemini_resp.get("usageMetadata", {})
-        prompt_tokens = usage_meta.get("promptTokenCount", 0)
-        completion_tokens = usage_meta.get("candidatesTokenCount", 0)
+        usage = TokenUsage.from_gemini(usage_meta)
+
+        # Compute the cost
+        cost = calculate_cost(model, usage)
 
         return {
             "content": content,
             "tool_calls": tool_calls,
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
+            "prompt_tokens": usage.input_tokens,
+            "completion_tokens": usage.output_tokens,
+            "reasoning_tokens": usage.reasoning_tokens,
+            "cached_content_tokens": usage.cached_content_tokens,
             "finish_reason": finish_reason,
-            "cost": 0.0
+            "cost": cost,
+            "usage": usage,  # 完整的 TokenUsage 对象
         }
 
     return gemini_llm_call
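
agent/llm/pricing.py is likewise new and not shown here. A rough sketch of the calculate_cost side, assuming a per-million-token price table; the model id, prices, and the choice to bill reasoning tokens at the output rate are illustrative assumptions, not the commit's actual values:

# Hypothetical sketch of agent/llm/pricing.py -- prices and billing rules
# below are placeholders, not the values shipped in this commit.
from dataclasses import dataclass

from .usage import TokenUsage

@dataclass
class ModelPricing:
    input_per_m: float             # USD per 1M input tokens
    output_per_m: float            # USD per 1M output tokens
    cache_read_per_m: float = 0.0  # discounted rate for cached reads

# Placeholder price table keyed by model id.
_PRICING = {
    "anthropic/claude-sonnet-4.5": ModelPricing(3.0, 15.0, 0.3),
}

def calculate_cost(model: str, usage: TokenUsage) -> float:
    pricing = _PRICING.get(model)
    if pricing is None:
        return 0.0  # unknown model: report zero rather than guess
    # Assumption: reasoning tokens bill at the output rate.
    return (
        usage.input_tokens * pricing.input_per_m
        + (usage.output_tokens + usage.reasoning_tokens) * pricing.output_per_m
        + usage.cache_read_tokens * pricing.cache_read_per_m
    ) / 1_000_000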

+ 72 - 9
agent/llm/openrouter.py

@@ -3,6 +3,11 @@ OpenRouter Provider
 
 Calls models (including Claude Sonnet 4.5) through the OpenRouter API
 Supports the OpenAI-compatible API format
+
+OpenRouter proxies many models, so usage must be parsed per the actual model (see the sample payloads after this diff):
+- OpenAI models: prompt_tokens, completion_tokens, completion_tokens_details.reasoning_tokens
+- Claude models: input_tokens, output_tokens, cache_creation_input_tokens, cache_read_input_tokens
+- DeepSeek models: prompt_tokens, completion_tokens, reasoning_tokens
 """
 
 import os
@@ -10,6 +15,61 @@ import json
 import httpx
 from typing import List, Dict, Any, Optional
 
+from .usage import TokenUsage, create_usage_from_response
+from .pricing import calculate_cost
+
+
+def _detect_provider_from_model(model: str) -> str:
+    """Detect the provider from the model name"""
+    model_lower = model.lower()
+    if model_lower.startswith("anthropic/") or "claude" in model_lower:
+        return "anthropic"
+    elif model_lower.startswith("openai/") or model_lower.startswith(("gpt", "o1", "o3")):
+        return "openai"
+    elif model_lower.startswith("deepseek/") or "deepseek" in model_lower:
+        return "deepseek"
+    elif model_lower.startswith("google/") or "gemini" in model_lower:
+        return "gemini"
+    else:
+        return "openai"  # default to the OpenAI format
+
+
+def _parse_openrouter_usage(usage: Dict[str, Any], model: str) -> TokenUsage:
+    """
+    Parse the usage payload returned by OpenRouter
+
+    OpenRouter returns usage in a different shape depending on the underlying model
+    """
+    provider = _detect_provider_from_model(model)
+
+    # OpenRouter usually returns the OpenAI shape, but may include extra fields
+    if provider == "anthropic":
+        # Claude models may carry prompt-cache fields
+        return TokenUsage(
+            input_tokens=usage.get("prompt_tokens") or usage.get("input_tokens", 0),
+            output_tokens=usage.get("completion_tokens") or usage.get("output_tokens", 0),
+            cache_creation_tokens=usage.get("cache_creation_input_tokens", 0),
+            cache_read_tokens=usage.get("cache_read_input_tokens", 0),
+        )
+    elif provider == "deepseek":
+        # DeepSeek may report reasoning_tokens
+        return TokenUsage(
+            input_tokens=usage.get("prompt_tokens", 0),
+            output_tokens=usage.get("completion_tokens", 0),
+            reasoning_tokens=usage.get("reasoning_tokens", 0),
+        )
+    else:
+        # OpenAI shape (including reasoning_tokens for o1/o3)
+        reasoning = 0
+        if details := usage.get("completion_tokens_details"):
+            reasoning = details.get("reasoning_tokens", 0)
+
+        return TokenUsage(
+            input_tokens=usage.get("prompt_tokens", 0),
+            output_tokens=usage.get("completion_tokens", 0),
+            reasoning_tokens=reasoning,
+        )
+
 
 async def openrouter_llm_call(
     messages: List[Dict[str, Any]],
@@ -88,21 +148,24 @@ async def openrouter_llm_call(
     tool_calls = message.get("tool_calls")
     finish_reason = choice.get("finish_reason")  # stop, length, tool_calls, content_filter, etc.
 
-    # Extract usage
-    usage = result.get("usage", {})
-    prompt_tokens = usage.get("prompt_tokens", 0)
-    completion_tokens = usage.get("completion_tokens", 0)
+    # Extract usage (full version, parsed according to the model type)
+    raw_usage = result.get("usage", {})
+    usage = _parse_openrouter_usage(raw_usage, model)
 
-    # Compute the cost (OpenRouter usually provides it in the response, but it is simplified to 0 here)
-    cost = 0.0
+    # Compute the cost
+    cost = calculate_cost(model, usage)
 
     return {
         "content": content,
         "tool_calls": tool_calls,
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
+        "prompt_tokens": usage.input_tokens,
+        "completion_tokens": usage.output_tokens,
+        "reasoning_tokens": usage.reasoning_tokens,
+        "cache_creation_tokens": usage.cache_creation_tokens,
+        "cache_read_tokens": usage.cache_read_tokens,
         "finish_reason": finish_reason,
-        "cost": cost
+        "cost": cost,
+        "usage": usage,  # 完整的 TokenUsage 对象
     }
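
To make the three usage shapes from the module docstring concrete, here are representative raw payloads and what _parse_openrouter_usage maps them to (all token counts are made up):

# OpenAI o-series via OpenRouter:
openai_usage = {
    "prompt_tokens": 1200,
    "completion_tokens": 450,
    "completion_tokens_details": {"reasoning_tokens": 300},
}
# -> TokenUsage(input_tokens=1200, output_tokens=450, reasoning_tokens=300)

# Claude via OpenRouter, with prompt caching active:
claude_usage = {
    "prompt_tokens": 800,
    "completion_tokens": 220,
    "cache_creation_input_tokens": 500,
    "cache_read_input_tokens": 2000,
}
# -> TokenUsage(input_tokens=800, output_tokens=220,
#               cache_creation_tokens=500, cache_read_tokens=2000)

# DeepSeek R1 via OpenRouter:
deepseek_usage = {
    "prompt_tokens": 600,
    "completion_tokens": 350,
    "reasoning_tokens": 900,
}
# -> TokenUsage(input_tokens=600, output_tokens=350, reasoning_tokens=900)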
 
 

+ 46 - 2
agent/trace/models.py

@@ -10,6 +10,11 @@ from datetime import datetime
 from typing import Dict, Any, List, Optional, Literal
 import uuid
 
+# Import TokenUsage lazily to avoid a circular dependency
+def _get_token_usage_class():
+    from ..llm.usage import TokenUsage
+    return TokenUsage
+
 
 @dataclass
 class Trace:
@@ -44,6 +49,9 @@ class Trace:
     total_tokens: int = 0        # total tokens (backward compatible, = prompt + completion)
     total_prompt_tokens: int = 0      # total input tokens
     total_completion_tokens: int = 0  # total output tokens
+    total_reasoning_tokens: int = 0   # total reasoning tokens (o1/o3, DeepSeek R1, Gemini thinking)
+    total_cache_creation_tokens: int = 0  # total cache-creation tokens (Claude)
+    total_cache_read_tokens: int = 0      # total cache-read tokens (Claude)
     total_cost: float = 0.0
     total_duration_ms: int = 0   # total duration (ms)
 
@@ -97,6 +105,9 @@ class Trace:
             "total_tokens": self.total_tokens,
             "total_prompt_tokens": self.total_prompt_tokens,
             "total_completion_tokens": self.total_completion_tokens,
+            "total_reasoning_tokens": self.total_reasoning_tokens,
+            "total_cache_creation_tokens": self.total_cache_creation_tokens,
+            "total_cache_read_tokens": self.total_cache_read_tokens,
             "total_cost": self.total_cost,
             "total_duration_ms": self.total_duration_ms,
             "last_sequence": self.last_sequence,
@@ -139,6 +150,9 @@ class Message:
     # Metadata
     prompt_tokens: Optional[int] = None  # input tokens
     completion_tokens: Optional[int] = None  # output tokens
+    reasoning_tokens: Optional[int] = None   # reasoning tokens (o1/o3, DeepSeek R1, Gemini thinking)
+    cache_creation_tokens: Optional[int] = None  # cache-creation tokens (Claude)
+    cache_read_tokens: Optional[int] = None      # cache-read tokens (Claude)
     cost: Optional[float] = None
     duration_ms: Optional[int] = None
     created_at: datetime = field(default_factory=datetime.now)
@@ -148,9 +162,25 @@ class Message:
 
     @property
     def tokens(self) -> int:
-        """动态计算总 tokens(向后兼容)"""
+        """动态计算总 tokens(向后兼容,input + output)"""
         return (self.prompt_tokens or 0) + (self.completion_tokens or 0)
 
+    @property
+    def all_tokens(self) -> int:
+        """所有 tokens(包括 reasoning)"""
+        return self.tokens + (self.reasoning_tokens or 0)
+
+    def get_usage(self):
+        """获取 TokenUsage 对象"""
+        TokenUsage = _get_token_usage_class()
+        return TokenUsage(
+            input_tokens=self.prompt_tokens or 0,
+            output_tokens=self.completion_tokens or 0,
+            reasoning_tokens=self.reasoning_tokens or 0,
+            cache_creation_tokens=self.cache_creation_tokens or 0,
+            cache_read_tokens=self.cache_read_tokens or 0,
+        )
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "Message":
         """从字典创建 Message(处理向后兼容)"""
@@ -174,6 +204,9 @@ class Message:
         tool_call_id: Optional[str] = None,
         prompt_tokens: Optional[int] = None,
         completion_tokens: Optional[int] = None,
+        reasoning_tokens: Optional[int] = None,
+        cache_creation_tokens: Optional[int] = None,
+        cache_read_tokens: Optional[int] = None,
         cost: Optional[float] = None,
         duration_ms: Optional[int] = None,
         finish_reason: Optional[str] = None,
@@ -192,6 +225,9 @@ class Message:
             tool_call_id=tool_call_id,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
+            reasoning_tokens=reasoning_tokens,
+            cache_creation_tokens=cache_creation_tokens,
+            cache_read_tokens=cache_read_tokens,
             cost=cost,
             duration_ms=duration_ms,
             finish_reason=finish_reason,
@@ -261,7 +297,7 @@ class Message:
 
     def to_dict(self) -> Dict[str, Any]:
         """转换为字典"""
-        return {
+        result = {
             "message_id": self.message_id,
             "trace_id": self.trace_id,
             "role": self.role,
@@ -278,6 +314,14 @@ class Message:
             "finish_reason": self.finish_reason,
             "created_at": self.created_at.isoformat() if self.created_at else None,
         }
+        # Only include optional fields that are non-empty
+        if self.reasoning_tokens:
+            result["reasoning_tokens"] = self.reasoning_tokens
+        if self.cache_creation_tokens:
+            result["cache_creation_tokens"] = self.cache_creation_tokens
+        if self.cache_read_tokens:
+            result["cache_read_tokens"] = self.cache_read_tokens
+        return result
 
 
 # ===== Deprecated: Step model (kept for backward compatibility) =====
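
A quick illustration of how the new Message fields compose; the numbers are made up, and only the keyword arguments extended in this commit are shown (the full Message.create signature is not visible in this hunk):

msg = Message.create(
    trace_id="t-123",
    role="assistant",
    content="...",
    prompt_tokens=1000,
    completion_tokens=400,
    reasoning_tokens=250,
    cache_read_tokens=800,
)
msg.tokens       # 1400 -- prompt + completion (backward compatible)
msg.all_tokens   # 1650 -- adds reasoning tokens
msg.get_usage()  # TokenUsage(input_tokens=1000, output_tokens=400, ...)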

+ 12 - 2
agent/trace/store.py

@@ -316,11 +316,18 @@ class FileSystemTraceStore:
             trace.total_messages += 1
             trace.last_sequence = max(trace.last_sequence, message.sequence)
 
-            # Accumulate tokens (split)
+            # Accumulate tokens (full version)
             if message.prompt_tokens:
                 trace.total_prompt_tokens += message.prompt_tokens
             if message.completion_tokens:
                 trace.total_completion_tokens += message.completion_tokens
+            if message.reasoning_tokens:
+                trace.total_reasoning_tokens += message.reasoning_tokens
+            if message.cache_creation_tokens:
+                trace.total_cache_creation_tokens += message.cache_creation_tokens
+            if message.cache_read_tokens:
+                trace.total_cache_read_tokens += message.cache_read_tokens
+
             # Backward compatibility: also update total_tokens
             if message.tokens:
                 trace.total_tokens += message.tokens
@@ -332,7 +339,7 @@ class FileSystemTraceStore:
             if message.duration_ms:
                 trace.total_duration_ms += message.duration_ms
 
-            # Update the Trace (do not pass trace_id; it is already a method parameter)
+            # Update the Trace
             await self.update_trace(
                 trace_id,
                 total_messages=trace.total_messages,
@@ -340,6 +347,9 @@ class FileSystemTraceStore:
                 total_tokens=trace.total_tokens,
                 total_prompt_tokens=trace.total_prompt_tokens,
                 total_completion_tokens=trace.total_completion_tokens,
+                total_reasoning_tokens=trace.total_reasoning_tokens,
+                total_cache_creation_tokens=trace.total_cache_creation_tokens,
+                total_cache_read_tokens=trace.total_cache_read_tokens,
                 total_cost=trace.total_cost,
                 total_duration_ms=trace.total_duration_ms
             )
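
One subtlety in the accumulation above: the backward-compatible total_tokens counter sums only message.tokens (prompt + completion), so reasoning and cache tokens are tracked in their own counters and never double-counted. A worked example with made-up numbers:

# After appending two assistant messages to a fresh trace:
#   msg1: prompt=1000, completion=400, reasoning=250
#   msg2: prompt=500,  completion=200, cache_read=800
#
# trace.total_prompt_tokens      == 1500
# trace.total_completion_tokens  == 600
# trace.total_reasoning_tokens   == 250
# trace.total_cache_read_tokens  == 800
# trace.total_tokens             == 2100   # prompt + completion only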