recall.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "reflect"
  6. "runtime/debug"
  7. "strings"
  8. "sync"
  9. "time"
  10. "gitlab.alibaba-inc.com/pai_biz_arch/pairec/context"
  11. "gitlab.alibaba-inc.com/pai_biz_arch/pairec/log"
  12. "gitlab.alibaba-inc.com/pai_biz_arch/pairec/module"
  13. "gitlab.alibaba-inc.com/pai_biz_arch/pairec/service/recall"
  14. "gitlab.alibaba-inc.com/pai_biz_arch/pairec/utils"
  15. )
  16. var mutex sync.Mutex
  17. type RecallService struct {
  18. }
  19. func getRecallFromABConfig(recallConfig map[string]interface{}, name, recallNewName string) recall.Recall {
  20. oldRecall, err := recall.GetRecall(name)
  21. if err != nil {
  22. return nil
  23. }
  24. mutex.Lock()
  25. if newRecall, err := recall.GetRecall(recallNewName); err == nil {
  26. mutex.Unlock()
  27. return newRecall
  28. }
  29. if m := reflect.ValueOf(oldRecall).MethodByName("CloneWithConfig"); m.IsValid() {
  30. if callValues := m.Call([]reflect.Value{reflect.ValueOf(recallConfig)}); len(callValues) == 1 {
  31. i := callValues[0].Interface()
  32. if newRecall, ok := i.(recall.Recall); ok {
  33. recall.RegisterRecall(recallNewName, newRecall)
  34. log.Info("register recall :" + recallNewName)
  35. mutex.Unlock()
  36. return newRecall
  37. }
  38. }
  39. }
  40. mutex.Unlock()
  41. return nil
  42. }
  43. func (s *RecallService) GetItems(user *module.User, context *context.RecommendContext) (ret []*module.Item) {
  44. sceneName := context.GetParameter("scene")
  45. categoryName := context.GetParameter("category").(string)
  46. var recalls []recall.Recall
  47. var recallNames []string
  48. if context.ExperimentResult != nil {
  49. names := context.ExperimentResult.GetExperimentParams().Get(categoryName+".RecallNames", nil)
  50. if names != nil {
  51. if values, ok := names.([]interface{}); ok {
  52. for _, v := range values {
  53. if r, ok := v.(string); ok {
  54. recallNames = append(recallNames, r)
  55. }
  56. }
  57. }
  58. }
  59. }
  60. if len(recallNames) == 0 {
  61. scene, err1 := GetSence(sceneName.(string))
  62. if err1 != nil {
  63. log.Error(fmt.Sprintf("requestId=%s\tmodule=recall\terror=%v", context.RecommendId, err1))
  64. return
  65. }
  66. category, err2 := scene.GetCategory(categoryName)
  67. if err2 != nil {
  68. log.Error(fmt.Sprintf("requestId=%s\tmodule=recall\terror=%v", context.RecommendId, err2))
  69. return
  70. }
  71. // recalls = category.GetRecalls()
  72. recallNames = category.GetRecallNames()
  73. }
  74. for _, name := range recallNames {
  75. if context.ExperimentResult != nil {
  76. recallConfig := context.ExperimentResult.GetExperimentParams().Get("recall."+name, nil)
  77. if recallConfig == nil {
  78. if recall, err := recall.GetRecall(name); err == nil {
  79. recalls = append(recalls, recall)
  80. }
  81. } else {
  82. d, _ := json.Marshal(recallConfig)
  83. recallName := name + "#" + utils.Md5(string(d))
  84. // find new recall by the new recall name
  85. if recall, err := recall.GetRecall(recallName); err == nil {
  86. recalls = append(recalls, recall)
  87. } else {
  88. if params, ok := recallConfig.(map[string]interface{}); ok {
  89. if r := getRecallFromABConfig(params, name, recallName); r != nil {
  90. recalls = append(recalls, r)
  91. }
  92. }
  93. }
  94. }
  95. } else {
  96. // not find abtest config
  97. if recall, err := recall.GetRecall(name); err == nil {
  98. recalls = append(recalls, recall)
  99. }
  100. }
  101. }
  102. ch := make(chan []*module.Item, len(recalls))
  103. start := time.Now()
  104. for i := 0; i < len(recalls); i++ {
  105. go func(ch chan<- []*module.Item, recall recall.Recall) {
  106. // when recall is panic, can recover it
  107. defer func() {
  108. if err := recover(); err != nil {
  109. stack := string(debug.Stack())
  110. log.Error(fmt.Sprintf("error=%v, stack=%s", err, strings.ReplaceAll(stack, "\n", "\t")))
  111. var tmp []*module.Item
  112. ch <- tmp
  113. }
  114. }()
  115. items := recall.GetCandidateItems(user, context)
  116. ch <- items
  117. }(ch, recalls[i])
  118. }
  119. for i := 0; i < len(recalls); i++ {
  120. items := <-ch
  121. ret = append(ret, items...)
  122. }
  123. close(ch)
  124. log.Info(fmt.Sprintf("requestId=%s\tmodule=recall\tcost=%d", context.RecommendId, utils.CostTime(start)))
  125. return
  126. }