Przeglądaj źródła

feat: 添加流模式下的SSE保活机制 #945

CaIon 10 miesięcy temu
rodzic
commit
2f3acd9d22

+ 0 - 1
relay/channel/openai/relay-openai.go

@@ -141,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			if err != nil {
 				common.SysError("error handling stream format: " + err.Error())
 			}
-			info.SetFirstResponseTime()
 		}
 		lastStreamData = data
 		streamItems = append(streamItems, data)

+ 9 - 0
relay/common/relay_info.go

@@ -6,6 +6,7 @@ import (
 	"one-api/dto"
 	relayconstant "one-api/relay/constant"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -54,6 +55,7 @@ type RelayInfo struct {
 	StartTime         time.Time
 	FirstResponseTime time.Time
 	isFirstResponse   bool
+	responseMutex     sync.Mutex // Add mutex for protecting concurrent access
 	//SendLastReasoningResponse bool
 	ApiType           int
 	IsStream          bool
@@ -212,12 +214,19 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
 }
 
 func (info *RelayInfo) SetFirstResponseTime() {
+	info.responseMutex.Lock()
+	defer info.responseMutex.Unlock()
+
 	if info.isFirstResponse {
 		info.FirstResponseTime = time.Now()
 		info.isFirstResponse = false
 	}
 }
 
+func (info *RelayInfo) HasSendResponse() bool {
+	return info.FirstResponseTime.After(info.StartTime)
+}
+
 type TaskRelayInfo struct {
 	*RelayInfo
 	Action       string

+ 10 - 0
relay/helper/common.go

@@ -55,6 +55,16 @@ func StringData(c *gin.Context, str string) error {
 	return nil
 }
 
+func PingData(c *gin.Context) error {
+	c.Writer.Write([]byte(": PING\n\n"))
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	} else {
+		return errors.New("streaming error: flusher not found")
+	}
+	return nil
+}
+
 func ObjectData(c *gin.Context, object interface{}) error {
 	if object == nil {
 		return errors.New("object is nil")

+ 56 - 4
relay/helper/stream_scanner.go

@@ -3,12 +3,15 @@ package helper
 import (
 	"bufio"
 	"context"
+	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
 	relaycommon "one-api/relay/common"
+	"one-api/setting/operation_setting"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -17,11 +20,12 @@ import (
 const (
 	InitialScannerBufferSize = 1 << 20  // 1MB (1*1024*1024)
 	MaxScannerBufferSize     = 10 << 20 // 10MB (10*1024*1024)
+	DefaultPingInterval      = 10 * time.Second
 )
 
 func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
 
-	if resp == nil {
+	if resp == nil || dataHandler == nil {
 		return
 	}
 
@@ -34,13 +38,29 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 	}
 
 	var (
-		stopChan = make(chan bool, 2)
-		scanner  = bufio.NewScanner(resp.Body)
-		ticker   = time.NewTicker(streamingTimeout)
+		stopChan   = make(chan bool, 2)
+		scanner    = bufio.NewScanner(resp.Body)
+		ticker     = time.NewTicker(streamingTimeout)
+		pingTicker *time.Ticker
+		writeMutex sync.Mutex // Mutex to protect concurrent writes
 	)
 
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+	if pingInterval <= 0 {
+		pingInterval = DefaultPingInterval
+	}
+
+	if pingEnabled {
+		pingTicker = time.NewTicker(pingInterval)
+	}
+
 	defer func() {
 		ticker.Stop()
+		if pingTicker != nil {
+			pingTicker.Stop()
+		}
 		close(stopChan)
 	}()
 	scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -51,6 +71,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 	defer cancel()
 
 	ctx = context.WithValue(ctx, "stop_chan", stopChan)
+
+	// Handle ping data sending
+	if pingEnabled && pingTicker != nil {
+		gopool.Go(func() {
+			for {
+				select {
+				case <-pingTicker.C:
+					writeMutex.Lock() // Lock before writing
+					err := PingData(c)
+					writeMutex.Unlock() // Unlock after writing
+					if err != nil {
+						common.LogError(c, "ping data error: "+err.Error())
+						common.SafeSendBool(stopChan, true)
+						return
+					}
+					if common.DebugEnabled {
+						println("ping data sent")
+					}
+				case <-ctx.Done():
+					if common.DebugEnabled {
+						println("ping data goroutine stopped")
+					}
+					return
+				}
+			}
+		})
+	}
+
 	common.RelayCtxGo(ctx, func() {
 		for scanner.Scan() {
 			ticker.Reset(streamingTimeout)
@@ -70,7 +118,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 			data = strings.TrimSuffix(data, "\"")
 			if !strings.HasPrefix(data, "[DONE]") {
 				info.SetFirstResponseTime()
+				writeMutex.Lock() // Lock before writing
 				success := dataHandler(data)
+				writeMutex.Unlock() // Unlock after writing
 				if !success {
 					break
 				}
@@ -90,7 +140,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 	case <-ticker.C:
 		// 超时处理逻辑
 		common.LogError(c, "streaming timeout")
+		common.SafeSendBool(stopChan, true)
 	case <-stopChan:
 		// 正常结束
+		common.LogInfo(c, "streaming finished")
 	}
 }

+ 6 - 2
setting/operation_setting/general_setting.go

@@ -3,12 +3,16 @@ package operation_setting
 import "one-api/setting/config"
 
 type GeneralSetting struct {
-	DocsLink string `json:"docs_link"`
+	DocsLink            string `json:"docs_link"`
+	PingIntervalEnabled bool   `json:"ping_interval_enabled"`
+	PingIntervalSeconds int    `json:"ping_interval_seconds"`
 }
 
 // 默认配置
 var generalSetting = GeneralSetting{
-	DocsLink: "https://docs.newapi.pro",
+	DocsLink:            "https://docs.newapi.pro",
+	PingIntervalEnabled: false,
+	PingIntervalSeconds: 60,
 }
 
 func init() {

+ 2 - 0
web/src/components/ModelSetting.js

@@ -18,6 +18,8 @@ const ModelSetting = () => {
     'claude.default_max_tokens': '',
     'claude.thinking_adapter_budget_tokens_percentage': 0.8,
     'global.pass_through_request_enabled': false,
+    'general_setting.ping_interval_enabled': false,
+    'general_setting.ping_interval_seconds': 60,
   });
 
   let [loading, setLoading] = useState(false);

+ 18 - 17
web/src/components/PersonalSetting.js

@@ -793,23 +793,7 @@ const PersonalSetting = () => {
               </div>
             </Card>
             <Card style={{ marginTop: 10 }}>
-              <Tabs type="line" defaultActiveKey="price">
-                <TabPane tab={t('价格设置')} itemKey="price">
-                  <div style={{ marginTop: 20 }}>
-                    <Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
-                    <div style={{ marginTop: 10 }}>
-                      <Checkbox
-                        checked={notificationSettings.acceptUnsetModelRatioModel}
-                        onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
-                      >
-                        {t('接受未设置价格模型')}
-                      </Checkbox>
-                      <Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
-                        {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
-                      </Typography.Text>
-                    </div>
-                  </div>
-                </TabPane>
+              <Tabs type="line" defaultActiveKey="notification">
                 <TabPane tab={t('通知设置')} itemKey="notification">
                   <div style={{ marginTop: 20 }}>
                     <Typography.Text strong>{t('通知方式')}</Typography.Text>
@@ -923,6 +907,23 @@ const PersonalSetting = () => {
                     </Typography.Text>
                   </div>
                 </TabPane>
+                <TabPane tab={t('价格设置')} itemKey="price">
+                  <div style={{ marginTop: 20 }}>
+                    <Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text>
+                    <div style={{ marginTop: 10 }}>
+                      <Checkbox
+                        checked={notificationSettings.acceptUnsetModelRatioModel}
+                        onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)}
+                      >
+                        {t('接受未设置价格模型')}
+                      </Checkbox>
+                      <Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}>
+                        {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
+                      </Typography.Text>
+                    </div>
+                  </div>
+                </TabPane>
+                
               </Tabs>
               <div style={{ marginTop: 20 }}>
                 <Button type="primary" onClick={saveNotificationSettings}>

+ 35 - 7
web/src/pages/Setting/Model/SettingGlobalModel.js

@@ -1,5 +1,5 @@
 import React, { useEffect, useState, useRef } from 'react';
-import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
+import { Button, Col, Form, Row, Spin, Banner } from '@douyinfe/semi-ui';
 import {
   compareObjects,
   API,
@@ -15,6 +15,8 @@ export default function SettingGlobalModel(props) {
   const [loading, setLoading] = useState(false);
   const [inputs, setInputs] = useState({
     'global.pass_through_request_enabled': false,
+    'general_setting.ping_interval_enabled': false,
+    'general_setting.ping_interval_seconds': 60,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -23,12 +25,8 @@ export default function SettingGlobalModel(props) {
     const updateArray = compareObjects(inputs, inputsRow);
     if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
     const requestQueue = updateArray.map((item) => {
-      let value = '';
-      if (typeof inputs[item.key] === 'boolean') {
-        value = String(inputs[item.key]);
-      } else {
-        value = inputs[item.key];
-      }
+      let value = String(inputs[item.key]);
+
       return API.put('/api/option/', {
         key: item.key,
         value,
@@ -84,6 +82,36 @@ export default function SettingGlobalModel(props) {
                 />
               </Col>
             </Row>
+            
+            <Form.Section text={t('连接保活设置')}>
+            <Row style={{ marginTop: 10 }}>
+                  <Col span={24}>
+                    <Banner 
+                      type="warning"
+                      description="警告:启用保活后,如果已经写入保活数据后渠道出错,系统无法重试,如果必须开启,推荐设置尽可能大的Ping间隔"
+                    />
+                  </Col>
+                </Row>
+              <Row>
+                <Col xs={24} sm={12} md={8} lg={8} xl={8}>
+                  <Form.Switch
+                    label={t('启用Ping间隔')}
+                    field={'general_setting.ping_interval_enabled'}
+                    onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_enabled': value })}
+                    extraText={'开启后,将定期发送ping数据保持连接活跃'}
+                  />
+                </Col>
+                <Col xs={24} sm={12} md={8} lg={8} xl={8}>
+                  <Form.InputNumber
+                    label={t('Ping间隔(秒)')}
+                    field={'general_setting.ping_interval_seconds'}
+                    onChange={(value) => setInputs({ ...inputs, 'general_setting.ping_interval_seconds': value })}
+                    min={1}
+                    disabled={!inputs['general_setting.ping_interval_enabled']}
+                  />
+                </Col>
+              </Row>
+            </Form.Section>
 
             <Row>
               <Button size='default' onClick={onSubmit}>