sensitive.go 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/constant"
  6. "one-api/dto"
  7. "strings"
  8. )
  9. func CheckSensitiveMessages(messages []dto.Message) error {
  10. for _, message := range messages {
  11. if len(message.Content) > 0 {
  12. if message.IsStringContent() {
  13. stringContent := message.StringContent()
  14. if ok, words := SensitiveWordContains(stringContent); ok {
  15. return errors.New("sensitive words: " + strings.Join(words, ","))
  16. }
  17. }
  18. } else {
  19. arrayContent := message.ParseContent()
  20. for _, m := range arrayContent {
  21. if m.Type == "image_url" {
  22. // TODO: check image url
  23. } else {
  24. if ok, words := SensitiveWordContains(m.Text); ok {
  25. return errors.New("sensitive words: " + strings.Join(words, ","))
  26. }
  27. }
  28. }
  29. }
  30. }
  31. return nil
  32. }
  33. func CheckSensitiveText(text string) error {
  34. if ok, words := SensitiveWordContains(text); ok {
  35. return errors.New("sensitive words: " + strings.Join(words, ","))
  36. }
  37. return nil
  38. }
  39. func CheckSensitiveInput(input any) error {
  40. switch v := input.(type) {
  41. case string:
  42. return CheckSensitiveText(v)
  43. case []string:
  44. text := ""
  45. for _, s := range v {
  46. text += s
  47. }
  48. return CheckSensitiveText(text)
  49. }
  50. return CheckSensitiveText(fmt.Sprintf("%v", input))
  51. }
  52. // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
  53. func SensitiveWordContains(text string) (bool, []string) {
  54. if len(constant.SensitiveWords) == 0 {
  55. return false, nil
  56. }
  57. checkText := strings.ToLower(text)
  58. // 构建一个AC自动机
  59. m := InitAc()
  60. hits := m.MultiPatternSearch([]rune(checkText), false)
  61. if len(hits) > 0 {
  62. words := make([]string, 0)
  63. for _, hit := range hits {
  64. words = append(words, string(hit.Word))
  65. }
  66. return true, words
  67. }
  68. return false, nil
  69. }
  70. // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
  71. func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
  72. if len(constant.SensitiveWords) == 0 {
  73. return false, nil, text
  74. }
  75. checkText := strings.ToLower(text)
  76. m := InitAc()
  77. hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
  78. if len(hits) > 0 {
  79. words := make([]string, 0)
  80. for _, hit := range hits {
  81. pos := hit.Pos
  82. word := string(hit.Word)
  83. text = text[:pos] + "**###**" + text[pos+len(word):]
  84. words = append(words, word)
  85. }
  86. return true, words, text
  87. }
  88. return false, nil, text
  89. }