embedding_manager.py

import os
import threading
import multiprocessing
from time import sleep

import numpy as np


class EmbeddingManager:
    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
        self.model = None
        self.emb_size = emb_size
        self.cache_file = cache_file
        self.cache_key_file = f"{self.cache_file}.keys"
        self.save_interval = save_interval
        # Shared dictionary for multiprocess use (proxied through a Manager server)
        self.cache = multiprocessing.Manager().dict()
        # Thread-safe lock; it only guards threads within a single process
        self.lock = threading.Lock()
        npy_filename = f"{self.cache_file}.npy"
        # Load cache from file if it exists
        if os.path.exists(npy_filename):
            embedding_data = np.load(npy_filename)
            with open(self.cache_key_file, "r") as fp:
                embedding_keys = [key.strip("\n") for key in fp]
            for idx, key in enumerate(embedding_keys):
                self.cache[key] = embedding_data[idx]
            print(f"cache loaded: {len(self.cache)} entries")
        # Start the periodic saving thread
        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
        self.saving_thread.start()
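
    # Note: the saving thread is a daemon, so it never blocks interpreter exit;
    # entries added after the last periodic save are lost unless save_now() is
    # called explicitly (as the __main__ block below does before exiting).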

    def set_model(self, model):
        """Since each process has its own model, the model must be set after the process is forked."""
        self.model = model
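
    # A minimal sketch of the intended per-process setup, assuming a fork start
    # method and a hypothetical load_model() helper (not part of this module):
    #
    #     manager = EmbeddingManager()
    #
    #     def _init_worker():
    #         manager.set_model(load_model())  # each worker loads its own model
    #
    #     def embed(texts):
    #         return manager.get_embedding(texts)
    #
    #     with multiprocessing.Pool(4, initializer=_init_worker) as pool:
    #         results = pool.map(embed, [["hello"], ["world"]])
    #
    # The forked workers inherit `manager`, so they share one cache dict while
    # each holds its own model instance.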

    def get_embedding(self, text_list):
        """
        Look up the embedding for each text. Texts missing from the cache are
        embedded with the model, saved to the cache, and returned.
        """
        if not isinstance(text_list, list):
            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
        if not text_list:
            return embedding_list
        new_texts = []
        new_texts_ori_idx = []
        for idx, text in enumerate(text_list):
            if text in self.cache:
                print(f"found {text} in cache")
                embedding_list[idx] = self.cache[text]
            else:
                new_texts.append(text)
                new_texts_ori_idx.append(idx)
        if not new_texts:
            return embedding_list
        # Generate embeddings for the texts not found in the cache
        new_embeddings = self.model.get_embedding(new_texts)
        with self.lock:  # Ensure thread-safe access
            for idx, text in enumerate(new_texts):
                if text not in self.cache:  # Re-check in case another thread added it
                    self.cache[text] = new_embeddings[idx]
                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
        return embedding_list
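
    # Note: embedding_list rows follow the order of text_list. Cache hits are
    # filled in the first pass; misses are written back through
    # new_texts_ori_idx, so mixed batches keep their original ordering.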

    def _periodic_save(self):
        """Periodically save the cache to disk."""
        while True:
            sleep(self.save_interval)
            self.save_now()

    def save_now(self):
        """Manually trigger a save to disk."""
        if os.path.dirname(self.cache_file):
            os.makedirs(os.path.dirname(self.cache_file), mode=0o755, exist_ok=True)
        tmp_cache_file = self.cache_file + ".tmp"
        tmp_cache_key_file = self.cache_key_file + ".tmp"
        with self.lock:  # Ensure thread-safe access
            keys = list(self.cache.keys())
            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
            for idx, key in enumerate(keys):
                cache_to_save[idx] = self.cache[key]
            np.save(tmp_cache_file, cache_to_save)  # np.save appends ".npy"
            with open(tmp_cache_key_file, "w") as fp:
                fp.write("\n".join(keys))
            # Keep the previous snapshot as a backup, then swap in the new files
            if os.path.exists(self.cache_file + ".npy"):
                os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
            if os.path.exists(self.cache_key_file):
                os.rename(self.cache_key_file, self.cache_key_file + ".bak")
            os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
            os.rename(tmp_cache_key_file, self.cache_key_file)


# Only for testing
class DummyModel:
    def padding_text(self, text):
        # Truncate, then repeat the text so it is exactly 1024 characters long
        text = text[:1024]
        padding_factor = 1024 // len(text)
        text = text * padding_factor
        text += text[:1024 - len(text)]
        return text
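
    # Example: for "abc", padding_factor is 1024 // 3 = 341, repetition gives
    # 1023 characters, and the final slice appends "a" to reach exactly 1024.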

    def get_embedding(self, text_list):
        embeddings = np.zeros((len(text_list), 1024), np.float32)
        for idx, text in enumerate(text_list):
            text = self.padding_text(text)
            embeddings[idx] = np.array([ord(c) for c in text], np.float32)
        return embeddings


if __name__ == "__main__":
    model = DummyModel()
    manager = EmbeddingManager()
    manager.set_model(model)
    print(manager.get_embedding(["hello"]))
    print(manager.get_embedding(["world"]))
    print(manager.get_embedding(["hello world"]))
    manager.save_now()
    print(manager.get_embedding(["new", "word"]))
    manager.save_now()
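
# Running this script twice demonstrates the cache: the second run loads
# cache/embedding_cache.npy and cache/embedding_cache.keys written by the
# first, so each lookup prints "found ... in cache" instead of re-embedding.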