rerank.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package common_handler
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. "one-api/relay/channel/xinference"
  9. relaycommon "one-api/relay/common"
  10. "one-api/service"
  11. )
  12. func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  13. responseBody, err := io.ReadAll(resp.Body)
  14. if err != nil {
  15. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  16. }
  17. common.CloseResponseBodyGracefully(resp)
  18. if common.DebugEnabled {
  19. println("reranker response body: ", string(responseBody))
  20. }
  21. var jinaResp dto.RerankResponse
  22. if info.ChannelType == common.ChannelTypeXinference {
  23. var xinRerankResponse xinference.XinRerankResponse
  24. err = common.UnmarshalJson(responseBody, &xinRerankResponse)
  25. if err != nil {
  26. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  27. }
  28. jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
  29. for i, result := range xinRerankResponse.Results {
  30. respResult := dto.RerankResponseResult{
  31. Index: result.Index,
  32. RelevanceScore: result.RelevanceScore,
  33. }
  34. if info.ReturnDocuments {
  35. var document any
  36. if result.Document != nil {
  37. if doc, ok := result.Document.(string); ok {
  38. if doc == "" {
  39. document = info.Documents[result.Index]
  40. } else {
  41. document = doc
  42. }
  43. } else {
  44. document = result.Document
  45. }
  46. }
  47. respResult.Document = document
  48. }
  49. jinaRespResults[i] = respResult
  50. }
  51. jinaResp = dto.RerankResponse{
  52. Results: jinaRespResults,
  53. Usage: dto.Usage{
  54. PromptTokens: info.PromptTokens,
  55. TotalTokens: info.PromptTokens,
  56. },
  57. }
  58. } else {
  59. err = common.UnmarshalJson(responseBody, &jinaResp)
  60. if err != nil {
  61. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  62. }
  63. jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
  64. }
  65. c.Writer.Header().Set("Content-Type", "application/json")
  66. c.JSON(http.StatusOK, jinaResp)
  67. return nil, &jinaResp.Usage
  68. }