Browse Source

Calculate agent cost

StrayWarrior 2 days ago
parent
commit
8b0f185d52
3 changed files with 109 additions and 20 deletions
  1. 15 0
      pqai_agent/agents/simple_chat_agent.py
  2. 92 20
      pqai_agent/chat_service.py
  3. 2 0
      pqai_agent/push_service.py

+ 15 - 0
pqai_agent/agents/simple_chat_agent.py

@@ -23,6 +23,8 @@ class SimpleOpenAICompatibleChatAgent:
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
         logger.debug(self.tool_map)
 
     def add_tool(self, tool: FunctionTool):
@@ -42,6 +44,8 @@ class SimpleOpenAICompatibleChatAgent:
         while n_steps < self.max_run_step:
             response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
             message = response.choices[0].message
+            self.total_input_tokens += response.usage.prompt_tokens
+            self.total_output_tokens += response.usage.completion_tokens
             messages.append(message)
             logger.debug(f"current step content: {message.content}")
 
@@ -71,3 +75,14 @@ class SimpleOpenAICompatibleChatAgent:
             n_steps += 1
 
         raise Exception("Max run steps exceeded")
+
+    def get_total_input_tokens(self) -> int:
+        """Return the total number of input (prompt) tokens accumulated by this agent."""
+        return self.total_input_tokens
+
+    def get_total_output_tokens(self) -> int:
+        """Return the total number of output (completion) tokens accumulated by this agent."""
+        return self.total_output_tokens
+
+    def get_total_cost(self) -> float:
+        """Return the cost of the tokens accumulated so far, priced for ``self.model``.
+
+        Delegates to ``OpenAICompatible.calculate_cost`` with its default
+        ``convert_to_cny=True``, so prices quoted in another currency are
+        converted to CNY.
+
+        :raises ValueError: if ``self.model`` has no entry in the price list
+        """
+        return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)

+ 92 - 20
pqai_agent/chat_service.py

@@ -37,35 +37,81 @@ OPENAI_MODEL_GPT_4o_mini = 'gpt-4o-mini'
 OPENROUTER_API_TOKEN = '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'
 OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1/'
 OPENROUTER_MODEL_CLAUDE_3_7_SONNET = 'anthropic/claude-3.7-sonnet'
+ALIYUN_API_TOKEN = 'sk-47381479425f4485af7673d3d2fd92b6'
+# NOTE(review): ALIYUN_API_TOKEN above is a credential hardcoded in source — consider
+# loading it from configuration or the environment, and rotating the committed key.
+ALIYUN_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+
 
 class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
 
+class ModelPrice:
+    """Input/output token pricing of a model, quoted per million tokens in one currency."""
+
+    # Multipliers applied to an amount in the keyed currency to express it in CNY.
+    EXCHANGE_RATE_TO_CNY = {
+        "USD": 7.2,  # Example conversion rate, adjust as needed
+    }
+
+    def __init__(self, input_price: float, output_price: float, currency: str = 'CNY'):
+        """
+        :param input_price: input price per million tokens
+        :param output_price: output price per million tokens
+        :param currency: currency code the prices are quoted in (default is 'CNY')
+        """
+        self.input_price = input_price
+        self.output_price = output_price
+        self.currency = currency
+
+    def get_total_cost(self, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the total cost based on input and output tokens.
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost, converted to CNY when convert_to_cny is True, otherwise in this price's own currency
+        """
+        # Prices are quoted per million tokens, hence the division by 1_000_000.
+        total_cost = (self.input_price * input_tokens / 1_000_000) + (self.output_price * output_tokens / 1_000_000)
+        if convert_to_cny and self.currency != 'CNY':
+            # NOTE(review): a currency missing from EXCHANGE_RATE_TO_CNY falls back to a rate
+            # of 1.0, so the amount is returned unconverted — confirm this is intended.
+            conversion_rate = self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
+            total_cost *= conversion_rate
+        return total_cost
+
+    def __repr__(self):
+        """Return a constructor-style string showing both prices and the currency."""
+        return f"ModelPrice(input_price={self.input_price}, output_price={self.output_price}, currency={self.currency})"
+
 class OpenAICompatible:
+    # Model names grouped by provider; create_client() checks these lists to choose
+    # which API key and base URL to build the client with.
+    volcengine_models = [
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+        VOLCENGINE_MODEL_DEEPSEEK_V3
+    ]
+    deepseek_models = [
+        DEEPSEEK_CHAT_MODEL,
+    ]
+    openai_models = [
+        OPENAI_MODEL_GPT_4o_mini,
+        OPENAI_MODEL_GPT_4o
+    ]
+    openrouter_models = [
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
+    ]
+
+    # Price table keyed by model name. Prices are per million tokens and are in CNY
+    # (the ModelPrice default) unless a currency is passed explicitly.
+    model_prices = {
+        VOLCENGINE_MODEL_DEEPSEEK_V3: ModelPrice(input_price=2, output_price=8),
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO: ModelPrice(input_price=3, output_price=9),
+        DEEPSEEK_CHAT_MODEL: ModelPrice(input_price=2, output_price=8),
+        OPENAI_MODEL_GPT_4o: ModelPrice(input_price=2.5, output_price=10, currency='USD'),
+        OPENAI_MODEL_GPT_4o_mini: ModelPrice(input_price=0.15, output_price=0.6, currency='USD'),
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET: ModelPrice(input_price=3, output_price=15, currency='USD'),
+    }
+
     @staticmethod
     def create_client(model_name, **kwargs) -> OpenAI:
-        volcengine_models = [
-            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
-            VOLCENGINE_MODEL_DEEPSEEK_V3
-        ]
-        deepseek_models = [
-            DEEPSEEK_CHAT_MODEL,
-        ]
-        openai_models = [
-            OPENAI_MODEL_GPT_4o_mini,
-            OPENAI_MODEL_GPT_4o
-        ]
-        openrouter_models = [
-            OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
-        ]
-        if model_name in volcengine_models:
+        if model_name in OpenAICompatible.volcengine_models:
             llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL, **kwargs)
-        elif model_name in deepseek_models:
+        elif model_name in OpenAICompatible.deepseek_models:
             llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL, **kwargs)
-        elif model_name in openai_models:
+        elif model_name in OpenAICompatible.openai_models:
             socks_conf = configs.get().get('system', {}).get('outside_proxy', {}).get('socks5', {})
             if socks_conf:
                 http_client = httpx.Client(
@@ -74,12 +120,38 @@ class OpenAICompatible:
                 )
                 kwargs['http_client'] = http_client
             llm_client = OpenAI(api_key=OPENAI_API_TOKEN, base_url=OPENAI_BASE_URL, **kwargs)
-        elif model_name in openrouter_models:
+        elif model_name in OpenAICompatible.openrouter_models:
             llm_client = OpenAI(api_key=OPENROUTER_API_TOKEN, base_url=OPENROUTER_BASE_URL, **kwargs)
         else:
             raise Exception("Unsupported model: %s" % model_name)
         return llm_client
 
+    @staticmethod
+    def get_price(model_name: str) -> ModelPrice:
+        """
+        Get the price for a given model.
+        :param model_name: Name of the model
+        :return: ModelPrice object containing input and output prices
+        :raises ValueError: If the model has no entry in model_prices
+        """
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        return OpenAICompatible.model_prices[model_name]
+
+    @staticmethod
+    def calculate_cost(model_name: str, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the cost for a given model based on input and output tokens.
+        :param model_name: Name of the model
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost, converted to CNY when convert_to_cny is True, otherwise in the model's price currency
+        :raises ValueError: If the model has no entry in model_prices
+        """
+        # NOTE(review): this is the same lookup and error as get_price(); it could delegate to it.
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        price = OpenAICompatible.model_prices[model_name]
+        return price.get_total_cost(input_tokens, output_tokens, convert_to_cny)
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id

+ 2 - 0
pqai_agent/push_service.py

@@ -248,6 +248,8 @@ class PushTaskWorkerPool:
                 ),
                 query_prompt_template=query_prompt_template
             )
+            cost = push_agent.get_total_cost()
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: push message generation cost: {cost}")
             if message_to_user:
                 rmq_message = generate_task_rmq_message(
                     self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))