Просмотр исходного кода

feat: implement tool pricing settings UI and enhance tool call quota calculations

CaIon 1 месяц назад
Родитель
Сommit
6e3ef48c9b

+ 2 - 1
model/option.go

@@ -486,8 +486,9 @@ func handleConfigUpdate(key, value string) bool {
 
 	// 特定配置的后处理
 	if configName == "performance_setting" {
-		// 同步磁盘缓存配置到 common 包
 		performance_setting.UpdateAndSync()
+	} else if configName == "tool_price_setting" {
+		operation_setting.RebuildToolPriceIndex()
 	}
 
 	return true // 已处理

+ 8 - 3
relay/compatible_handler.go

@@ -290,21 +290,26 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 
 	// Collect tool call usage from context and relayInfo
 	toolUsage := service.ToolCallUsage{
-		WebSearchModelName:     modelName,
-		ClaudeWebSearchCalls:   ctx.GetInt("claude_web_search_requests"),
+		ModelName:              modelName,
 		ImageGenerationCall:    ctx.GetBool("image_generation_call"),
 		ImageGenerationQuality: ctx.GetString("image_generation_call_quality"),
 		ImageGenerationSize:    ctx.GetString("image_generation_call_size"),
 	}
 	if relayInfo.ResponsesUsageInfo != nil {
-		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
 			toolUsage.WebSearchCalls = webSearchTool.CallCount
+			toolUsage.WebSearchToolName = dto.BuildInToolWebSearchPreview
 		}
 		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
 			toolUsage.FileSearchCalls = fileSearchTool.CallCount
 		}
 	} else if strings.HasSuffix(modelName, "search-preview") {
 		toolUsage.WebSearchCalls = 1
+		toolUsage.WebSearchToolName = dto.BuildInToolWebSearchPreview
+	}
+	if claudeSearchCalls := ctx.GetInt("claude_web_search_requests"); claudeSearchCalls > 0 {
+		toolUsage.WebSearchCalls = claudeSearchCalls
+		toolUsage.WebSearchToolName = "web_search"
 	}
 	toolResult := service.ComputeToolCallQuota(toolUsage, groupRatio)
 	for _, item := range toolResult.Items {

+ 14 - 28
service/tool_billing.go

@@ -2,7 +2,6 @@ package service
 
 import (
 	"math"
-	"strings"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -10,9 +9,9 @@ import (
 
 // ToolCallUsage captures all tool call counts from a single request.
 type ToolCallUsage struct {
+	ModelName              string
 	WebSearchCalls         int
-	WebSearchModelName     string
-	ClaudeWebSearchCalls   int
+	WebSearchToolName      string // "web_search_preview", "web_search", etc.
 	FileSearchCalls        int
 	ImageGenerationCall    bool
 	ImageGenerationQuality string
@@ -34,33 +33,25 @@ type ToolCallResult struct {
 	Items      []ToolCallItem `json:"items,omitempty"`
 }
 
-func getWebSearchPriceKey(modelName string) string {
-	isNormalPrice :=
-		strings.HasPrefix(modelName, "o3") ||
-			strings.HasPrefix(modelName, "o4") ||
-			strings.HasPrefix(modelName, "gpt-5")
-	if isNormalPrice {
-		return "web_search"
-	}
-	return "web_search_high"
-}
-
 // ComputeToolCallQuota calculates the total quota for all tool calls in a
-// request. All tool prices are $/1K calls (configurable via ToolCallPrices
-// option). groupRatio is applied. Per-call billing (UsePrice) callers should
-// NOT add this result — per-call price already includes everything.
+// request. Tool prices are resolved via GetToolPriceForModel which supports
+// model-prefix overrides. groupRatio is applied.
 func ComputeToolCallQuota(usage ToolCallUsage, groupRatio float64) ToolCallResult {
 	var items []ToolCallItem
 	totalQuota := 0
 
-	addItem := func(name string, count int, pricePer1K float64) {
-		if count <= 0 || pricePer1K <= 0 {
+	addItem := func(toolName string, count int) {
+		if count <= 0 {
+			return
+		}
+		pricePer1K := operation_setting.GetToolPriceForModel(toolName, usage.ModelName)
+		if pricePer1K <= 0 {
 			return
 		}
 		totalPrice := pricePer1K * float64(count) / 1000
 		quota := int(math.Round(totalPrice * common.QuotaPerUnit * groupRatio))
 		items = append(items, ToolCallItem{
-			Name:       name,
+			Name:       toolName,
 			CallCount:  count,
 			PricePer1K: pricePer1K,
 			TotalPrice: totalPrice,
@@ -69,17 +60,12 @@ func ComputeToolCallQuota(usage ToolCallUsage, groupRatio float64) ToolCallResul
 		totalQuota += quota
 	}
 
-	if usage.WebSearchCalls > 0 {
-		priceKey := getWebSearchPriceKey(usage.WebSearchModelName)
-		addItem("web_search", usage.WebSearchCalls, operation_setting.GetToolPrice(priceKey))
-	}
-
-	if usage.ClaudeWebSearchCalls > 0 {
-		addItem("claude_web_search", usage.ClaudeWebSearchCalls, operation_setting.GetToolPrice("claude_web_search"))
+	if usage.WebSearchCalls > 0 && usage.WebSearchToolName != "" {
+		addItem(usage.WebSearchToolName, usage.WebSearchCalls)
 	}
 
 	if usage.FileSearchCalls > 0 {
-		addItem("file_search", usage.FileSearchCalls, operation_setting.GetToolPrice("file_search"))
+		addItem("file_search", usage.FileSearchCalls)
 	}
 
 	if usage.ImageGenerationCall {

+ 108 - 13
setting/operation_setting/tools.go

@@ -1,21 +1,36 @@
 package operation_setting
 
 import (
+	"sort"
 	"strings"
+	"sync/atomic"
 
 	"github.com/QuantumNous/new-api/setting/config"
 )
 
 // ---------------------------------------------------------------------------
 // Tool call prices ($/1K calls, admin-configurable)
-// DB keys: tool_price_setting.prices
+// DB key: tool_price_setting.prices
+//
+// Key format:
+//   - "tool_name"              → default price for all models
+//   - "tool_name:model_prefix*" → override for models matching the prefix
+//
+// Lookup order: longest prefix match → default → hardcoded fallback → 0
 // ---------------------------------------------------------------------------
 
 var defaultToolPrices = map[string]float64{
-	"web_search":        10.0,
-	"web_search_high":   25.0,
-	"claude_web_search": 10.0,
-	"file_search":       2.5,
+	"web_search":         10.0, // OpenAI web search (all models) / Claude web search
+	"web_search_preview": 10.0, // OpenAI web search preview (default: reasoning models)
+	"file_search":        2.5,  // OpenAI file search (Responses API)
+	"google_search":      14.0, // Gemini Grounding with Google Search
+}
+
+var defaultToolPriceOverrides = map[string]float64{
+	"web_search_preview:gpt-4o*":       25.0, // non-reasoning models
+	"web_search_preview:gpt-4.1*":      25.0,
+	"web_search_preview:gpt-4o-mini*":  25.0,
+	"web_search_preview:gpt-4.1-mini*": 25.0,
 }
 
 // ToolPriceSetting is managed by config.GlobalConfig.Register.
@@ -25,30 +40,110 @@ type ToolPriceSetting struct {
 
 var toolPriceSetting = ToolPriceSetting{
 	Prices: func() map[string]float64 {
-		m := make(map[string]float64, len(defaultToolPrices))
+		m := make(map[string]float64, len(defaultToolPrices)+len(defaultToolPriceOverrides))
 		for k, v := range defaultToolPrices {
 			m[k] = v
 		}
+		for k, v := range defaultToolPriceOverrides {
+			m[k] = v
+		}
 		return m
 	}(),
 }
 
 func init() {
 	config.GlobalConfig.Register("tool_price_setting", &toolPriceSetting)
+	RebuildToolPriceIndex()
+}
+
+// ---------------------------------------------------------------------------
+// Precomputed price index (atomic, lock-free on read path)
+// ---------------------------------------------------------------------------
+
+type prefixEntry struct {
+	prefix string
+	price  float64
+}
+
+type toolPriceIndex struct {
+	defaults map[string]float64
+	prefixes map[string][]prefixEntry
+}
+
+var currentIndex atomic.Pointer[toolPriceIndex]
+
+// RebuildToolPriceIndex rebuilds the lookup index from the current config.
+// Called on init and after config updates. Not on the billing hot path.
+func RebuildToolPriceIndex() {
+	merged := make(map[string]float64, len(defaultToolPrices)+len(defaultToolPriceOverrides)+len(toolPriceSetting.Prices))
+	for k, v := range defaultToolPrices {
+		merged[k] = v
+	}
+	for k, v := range defaultToolPriceOverrides {
+		merged[k] = v
+	}
+	for k, v := range toolPriceSetting.Prices {
+		merged[k] = v
+	}
+
+	idx := &toolPriceIndex{
+		defaults: make(map[string]float64),
+		prefixes: make(map[string][]prefixEntry),
+	}
+
+	for key, price := range merged {
+		colonIdx := strings.IndexByte(key, ':')
+		if colonIdx < 0 {
+			idx.defaults[key] = price
+			continue
+		}
+		toolName := key[:colonIdx]
+		modelPart := key[colonIdx+1:]
+		prefix := strings.TrimSuffix(modelPart, "*")
+		idx.prefixes[toolName] = append(idx.prefixes[toolName], prefixEntry{prefix: prefix, price: price})
+	}
+
+	for tool := range idx.prefixes {
+		entries := idx.prefixes[tool]
+		sort.Slice(entries, func(i, j int) bool {
+			return len(entries[i].prefix) > len(entries[j].prefix)
+		})
+		idx.prefixes[tool] = entries
+	}
+
+	currentIndex.Store(idx)
 }
 
-// GetToolPrice returns the configured price for a tool key ($/1K calls),
-// falling back to hardcoded default if not overridden.
-func GetToolPrice(key string) float64 {
-	if v, ok := toolPriceSetting.Prices[key]; ok {
-		return v
+// GetToolPriceForModel returns the price ($/1K calls) for a tool given a model name.
+// Lookup: longest prefix match → tool default → 0.
+func GetToolPriceForModel(toolName, modelName string) float64 {
+	idx := currentIndex.Load()
+	if idx == nil {
+		if v, ok := defaultToolPrices[toolName]; ok {
+			return v
+		}
+		return 0
+	}
+
+	if entries, ok := idx.prefixes[toolName]; ok && modelName != "" {
+		for _, e := range entries {
+			if strings.HasPrefix(modelName, e.prefix) {
+				return e.price
+			}
+		}
 	}
-	if v, ok := defaultToolPrices[key]; ok {
-		return v
+
+	if p, ok := idx.defaults[toolName]; ok {
+		return p
 	}
 	return 0
 }
 
+// GetToolPrice is a convenience wrapper when no model name is needed.
+func GetToolPrice(toolName string) float64 {
+	return GetToolPriceForModel(toolName, "")
+}
+
 // ---------------------------------------------------------------------------
 // GPT Image 1 per-call pricing (special: depends on quality + size)
 // ---------------------------------------------------------------------------

+ 4 - 0
web/src/components/settings/RatioSetting.jsx

@@ -26,6 +26,7 @@ import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings';
 import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor';
 import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor';
 import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync';
+import ToolPriceSettings from '../../pages/Setting/Ratio/ToolPriceSettings';
 
 import { API, showError, toBoolean } from '../../helpers';
 
@@ -113,6 +114,9 @@ const RatioSetting = () => {
           <Tabs.TabPane tab={t('上游倍率同步')} itemKey='upstream_sync'>
             <UpstreamRatioSync options={inputs} refresh={onRefresh} />
           </Tabs.TabPane>
+          <Tabs.TabPane tab={t('工具调用定价')} itemKey='tool_price'>
+            <ToolPriceSettings options={inputs} />
+          </Tabs.TabPane>
         </Tabs>
       </Card>
     </Spin>

+ 283 - 0
web/src/pages/Setting/Ratio/ToolPriceSettings.jsx

@@ -0,0 +1,283 @@
+/*
+Copyright (C) 2025 QuantumNous
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, either version 3 of the
+License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+For commercial licensing, please contact support@quantumnous.com
+*/
+import React, { useEffect, useMemo, useState } from 'react';
+import {
+  Banner,
+  Button,
+  Input,
+  InputNumber,
+  Radio,
+  RadioGroup,
+  Table,
+  TextArea,
+  Typography,
+} from '@douyinfe/semi-ui';
+import { IconCopy, IconDelete, IconPlus } from '@douyinfe/semi-icons';
+import { useTranslation } from 'react-i18next';
+import { API, copy, showError, showSuccess } from '../../../helpers';
+
+const { Text } = Typography;
+
+const OPTION_KEY = 'tool_price_setting.prices';
+
+const DEFAULT_PRICES = {
+  web_search: 10.0,
+  web_search_preview: 10.0,
+  'web_search_preview:gpt-4o*': 25.0,
+  'web_search_preview:gpt-4.1*': 25.0,
+  'web_search_preview:gpt-4o-mini*': 25.0,
+  'web_search_preview:gpt-4.1-mini*': 25.0,
+  file_search: 2.5,
+  google_search: 14.0,
+};
+
+function rowsToObject(rows) {
+  const prices = {};
+  for (const row of rows) {
+    const k = row.key.trim();
+    if (!k) continue;
+    prices[k] = Number(row.price) || 0;
+  }
+  return prices;
+}
+
+function objectToRows(prices) {
+  return Object.entries(prices).map(([key, price], i) => ({
+    id: i,
+    key,
+    price,
+  }));
+}
+
+export default function ToolPriceSettings({ options }) {
+  const { t } = useTranslation();
+  const [rows, setRows] = useState([]);
+  const [mode, setMode] = useState('visual');
+  const [jsonText, setJsonText] = useState('');
+  const [jsonError, setJsonError] = useState('');
+  const [saving, setSaving] = useState(false);
+
+  useEffect(() => {
+    let prices = {};
+    try {
+      const raw = options?.[OPTION_KEY];
+      if (raw) {
+        prices = typeof raw === 'string' ? JSON.parse(raw) : raw;
+      }
+    } catch {
+      prices = {};
+    }
+
+    if (!prices || Object.keys(prices).length === 0) {
+      prices = { ...DEFAULT_PRICES };
+    }
+
+    setRows(objectToRows(prices));
+    setJsonText(JSON.stringify(prices, null, 2));
+  }, [options]);
+
+  const syncToJson = (nextRows) => {
+    setRows(nextRows);
+    setJsonText(JSON.stringify(rowsToObject(nextRows), null, 2));
+    setJsonError('');
+  };
+
+  const syncToVisual = (text) => {
+    setJsonText(text);
+    try {
+      const parsed = JSON.parse(text);
+      if (typeof parsed !== 'object' || Array.isArray(parsed)) {
+        setJsonError(t('JSON 必须是对象'));
+        return;
+      }
+      setRows(objectToRows(parsed));
+      setJsonError('');
+    } catch (e) {
+      setJsonError(e.message);
+    }
+  };
+
+  const updateRow = (id, field, value) => {
+    syncToJson(rows.map((r) => (r.id === id ? { ...r, [field]: value } : r)));
+  };
+
+  const addRow = () => {
+    syncToJson([...rows, { id: Date.now(), key: '', price: 0 }]);
+  };
+
+  const removeRow = (id) => {
+    syncToJson(rows.filter((r) => r.id !== id));
+  };
+
+  const resetToDefault = () => {
+    syncToJson(objectToRows(DEFAULT_PRICES));
+  };
+
+  const currentPrices = useMemo(() => rowsToObject(rows), [rows]);
+
+  const handleSave = async () => {
+    setSaving(true);
+    try {
+      const res = await API.put('/api/option/', {
+        key: OPTION_KEY,
+        value: JSON.stringify(currentPrices),
+      });
+      if (res.data.success) {
+        showSuccess(t('保存成功'));
+      } else {
+        showError(res.data.message || t('保存失败'));
+      }
+    } catch (e) {
+      showError(e.message);
+    } finally {
+      setSaving(false);
+    }
+  };
+
+  const columns = [
+    {
+      title: t('工具标识'),
+      dataIndex: 'key',
+      render: (text, record) => (
+        <Input
+          value={text}
+          placeholder='web_search_preview:gpt-4o*'
+          onChange={(val) => updateRow(record.id, 'key', val)}
+          style={{ width: '100%' }}
+        />
+      ),
+    },
+    {
+      title: t('价格') + ' ($/1K' + t('次') + ')',
+      dataIndex: 'price',
+      width: 160,
+      render: (val, record) => (
+        <InputNumber
+          value={val}
+          min={0}
+          step={0.5}
+          onChange={(v) => updateRow(record.id, 'price', v ?? 0)}
+          style={{ width: '100%' }}
+        />
+      ),
+    },
+    {
+      title: t('操作'),
+      width: 60,
+      render: (_, record) => (
+        <Button
+          icon={<IconDelete />}
+          type='danger'
+          theme='borderless'
+          size='small'
+          onClick={() => removeRow(record.id)}
+        />
+      ),
+    },
+  ];
+
+  return (
+    <div style={{ maxWidth: 700 }}>
+      <Banner
+        type='info'
+        description={
+          <>
+            <div>{t('配置各工具的调用价格($/1K次调用)。按次计费模型不额外收取工具费用。')}</div>
+            <div style={{ marginTop: 4 }}>
+              <Text strong>{t('格式')}:</Text>
+              <code>web_search_preview</code> {t('为默认价格')},
+              <code>web_search_preview:gpt-4o*</code> {t('为模型前缀覆盖')}
+            </div>
+          </>
+        }
+        style={{ marginBottom: 16 }}
+      />
+
+      <RadioGroup
+        type='button'
+        size='small'
+        value={mode}
+        onChange={(e) => setMode(e.target.value)}
+        style={{ marginBottom: 12 }}
+      >
+        <Radio value='visual'>{t('可视化')}</Radio>
+        <Radio value='json'>JSON</Radio>
+      </RadioGroup>
+
+      {mode === 'visual' ? (
+        <>
+          <Table
+            dataSource={rows}
+            columns={columns}
+            pagination={false}
+            size='small'
+            rowKey='id'
+          />
+          <div style={{ display: 'flex', gap: 8, marginTop: 12 }}>
+            <Button icon={<IconPlus />} onClick={addRow}>
+              {t('添加')}
+            </Button>
+            <Button theme='borderless' onClick={resetToDefault}>
+              {t('恢复默认')}
+            </Button>
+          </div>
+        </>
+      ) : (
+        <>
+          <TextArea
+            value={jsonText}
+            onChange={syncToVisual}
+            autosize={{ minRows: 8, maxRows: 20 }}
+            style={{ fontFamily: 'monospace', fontSize: 13 }}
+          />
+          {jsonError && (
+            <Text type='danger' size='small' style={{ display: 'block', marginTop: 4 }}>
+              {jsonError}
+            </Text>
+          )}
+          <div style={{ display: 'flex', gap: 8, marginTop: 8 }}>
+            <Button
+              icon={<IconCopy />}
+              size='small'
+              theme='borderless'
+              onClick={() => { copy(jsonText, t('JSON')); }}
+            >
+              {t('复制')}
+            </Button>
+            <Button size='small' theme='borderless' onClick={resetToDefault}>
+              {t('恢复默认')}
+            </Button>
+          </div>
+        </>
+      )}
+
+      <div style={{ display: 'flex', justifyContent: 'flex-end', marginTop: 16 }}>
+        <Button
+          theme='solid'
+          type='primary'
+          loading={saving}
+          disabled={mode === 'json' && !!jsonError}
+          onClick={handleSave}
+        >
+          {t('保存')}
+        </Button>
+      </div>
+    </div>
+  );
+}