rerank.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package ali
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/service"
  10. "github.com/gin-gonic/gin"
  11. )
  12. func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
  13. returnDocuments := request.ReturnDocuments
  14. if returnDocuments == nil {
  15. t := true
  16. returnDocuments = &t
  17. }
  18. return &AliRerankRequest{
  19. Model: request.Model,
  20. Input: AliRerankInput{
  21. Query: request.Query,
  22. Documents: request.Documents,
  23. },
  24. Parameters: AliRerankParameters{
  25. TopN: &request.TopN,
  26. ReturnDocuments: returnDocuments,
  27. },
  28. }
  29. }
  30. func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  31. responseBody, err := io.ReadAll(resp.Body)
  32. if err != nil {
  33. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  34. }
  35. common.CloseResponseBodyGracefully(resp)
  36. var aliResponse AliRerankResponse
  37. err = json.Unmarshal(responseBody, &aliResponse)
  38. if err != nil {
  39. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  40. }
  41. if aliResponse.Code != "" {
  42. return &dto.OpenAIErrorWithStatusCode{
  43. Error: dto.OpenAIError{
  44. Message: aliResponse.Message,
  45. Type: aliResponse.Code,
  46. Param: aliResponse.RequestId,
  47. Code: aliResponse.Code,
  48. },
  49. StatusCode: resp.StatusCode,
  50. }, nil
  51. }
  52. usage := dto.Usage{
  53. PromptTokens: aliResponse.Usage.TotalTokens,
  54. CompletionTokens: 0,
  55. TotalTokens: aliResponse.Usage.TotalTokens,
  56. }
  57. rerankResponse := dto.RerankResponse{
  58. Results: aliResponse.Output.Results,
  59. Usage: usage,
  60. }
  61. jsonResponse, err := json.Marshal(rerankResponse)
  62. if err != nil {
  63. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  64. }
  65. c.Writer.Header().Set("Content-Type", "application/json")
  66. c.Writer.WriteHeader(resp.StatusCode)
  67. _, err = c.Writer.Write(jsonResponse)
  68. if err != nil {
  69. return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
  70. }
  71. return nil, &usage
  72. }