rerank.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. package common_handler
import (
	"fmt"
	"io"
	"net/http"

	"one-api/common"
	"one-api/dto"
	"one-api/relay/channel/xinference"
	relaycommon "one-api/relay/common"
	"one-api/service"

	"github.com/gin-gonic/gin"
)
  12. func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  13. responseBody, err := io.ReadAll(resp.Body)
  14. if err != nil {
  15. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  16. }
  17. err = resp.Body.Close()
  18. if err != nil {
  19. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  20. }
  21. if common.DebugEnabled {
  22. println("reranker response body: ", string(responseBody))
  23. }
  24. var jinaResp dto.RerankResponse
  25. if info.ChannelType == common.ChannelTypeXinference {
  26. var xinRerankResponse xinference.XinRerankResponse
  27. err = common.DecodeJson(responseBody, &xinRerankResponse)
  28. if err != nil {
  29. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  30. }
  31. jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
  32. for i, result := range xinRerankResponse.Results {
  33. var document any
  34. if result.Document == "" {
  35. document = info.Documents[result.Index]
  36. } else {
  37. document = result.Document
  38. }
  39. jinaRespResults[i] = dto.RerankResponseResult{
  40. Index: result.Index,
  41. RelevanceScore: result.RelevanceScore,
  42. Document: dto.RerankDocument{
  43. Text: document,
  44. },
  45. }
  46. }
  47. jinaResp = dto.RerankResponse{
  48. Results: jinaRespResults,
  49. Usage: dto.Usage{
  50. PromptTokens: info.PromptTokens,
  51. TotalTokens: info.PromptTokens,
  52. },
  53. }
  54. } else {
  55. err = common.DecodeJson(responseBody, &jinaResp)
  56. if err != nil {
  57. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  58. }
  59. jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
  60. }
  61. c.Writer.Header().Set("Content-Type", "application/json")
  62. c.JSON(http.StatusOK, jinaResp)
  63. return nil, &jinaResp.Usage
  64. }