rerank.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package common_handler
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. "one-api/relay/channel/xinference"
  9. relaycommon "one-api/relay/common"
  10. "one-api/service"
  11. )
  12. func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  13. responseBody, err := io.ReadAll(resp.Body)
  14. if err != nil {
  15. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  16. }
  17. err = resp.Body.Close()
  18. if err != nil {
  19. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  20. }
  21. if common.DebugEnabled {
  22. println("reranker response body: ", string(responseBody))
  23. }
  24. var jinaResp dto.RerankResponse
  25. if info.ChannelType == common.ChannelTypeXinference {
  26. var xinRerankResponse xinference.XinRerankResponse
  27. err = common.DecodeJson(responseBody, &xinRerankResponse)
  28. if err != nil {
  29. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  30. }
  31. jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
  32. for i, result := range xinRerankResponse.Results {
  33. respResult := dto.RerankResponseResult{
  34. Index: result.Index,
  35. RelevanceScore: result.RelevanceScore,
  36. }
  37. if info.ReturnDocuments {
  38. var document any
  39. if result.Document == "" {
  40. document = info.Documents[result.Index]
  41. } else {
  42. document = result.Document
  43. }
  44. respResult.Document = document
  45. }
  46. jinaRespResults[i] = respResult
  47. }
  48. jinaResp = dto.RerankResponse{
  49. Results: jinaRespResults,
  50. Usage: dto.Usage{
  51. PromptTokens: info.PromptTokens,
  52. TotalTokens: info.PromptTokens,
  53. },
  54. }
  55. } else {
  56. err = common.DecodeJson(responseBody, &jinaResp)
  57. if err != nil {
  58. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  59. }
  60. jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
  61. }
  62. c.Writer.Header().Set("Content-Type", "application/json")
  63. c.JSON(http.StatusOK, jinaResp)
  64. return nil, &jinaResp.Usage
  65. }