# embedding_manager.py — thread-safe and process-safe embedding cache manager.
  1. import os
  2. import threading
  3. from filelock import FileLock
  4. from time import sleep
  5. import numpy as np
  6. import random
  7. class EmbeddingManager:
  8. def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
  9. self.model = model
  10. self.emb_size = emb_size
  11. self.cache_file = cache_file
  12. self.cache_file_real = self.cache_file + ".npy"
  13. self.cache_key_file = f'{self.cache_file}.keys'
  14. # avoid multiple process read and write at same time and wait for filelock
  15. self.save_interval = save_interval + random.randint(0, save_interval)
  16. self.cache = {}
  17. self.lock = threading.Lock() # Thread-safe lock
  18. self.filelock = FileLock(self.cache_file + ".lock")
  19. self.load_cache()
  20. # Start the periodic saving thread
  21. self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
  22. self.saving_thread.start()
  23. def _load_cache_unsafe(self):
  24. """inter-thread and inter-process safety must be guaranteed by caller"""
  25. embedding_data = np.load(self.cache_file_real)
  26. embedding_keys = open(self.cache_key_file, "r").readlines()
  27. embedding_keys = [key.strip("\n") for key in embedding_keys]
  28. for idx, key in enumerate(embedding_keys):
  29. self.cache[key] = embedding_data[idx]
  30. def load_cache(self):
  31. with self.lock:
  32. if os.path.exists(self.cache_file_real):
  33. with self.filelock:
  34. self._load_cache_unsafe()
  35. print("[EmbeddingManager]cache loaded")
  36. def dump_cache(self):
  37. if os.path.dirname(self.cache_file):
  38. os.makedirs(os.path.dirname(self.cache_file), 0o755, True)
  39. tmp_cache_file = self.cache_file + ".tmp"
  40. tmp_cache_key_file = self.cache_key_file + ".tmp"
  41. with self.lock: # Ensure thread-safe access firstly
  42. with self.filelock: # Ensure inter-process safety secondly
  43. if os.path.exists(self.cache_file_real):
  44. self._load_cache_unsafe()
  45. keys = self.cache.keys()
  46. cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
  47. for idx, key in enumerate(keys):
  48. cache_to_save[idx] = self.cache[key]
  49. np.save(tmp_cache_file, cache_to_save)
  50. with open(tmp_cache_key_file, 'w') as fp:
  51. fp.write('\n'.join(keys))
  52. if os.path.exists(self.cache_file + ".npy"):
  53. os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
  54. if os.path.exists(self.cache_key_file):
  55. os.rename(self.cache_key_file, self.cache_key_file + ".bak")
  56. os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
  57. os.rename(tmp_cache_key_file, self.cache_key_file)
  58. print("[EmbeddingManager]cache dumped")
  59. def get_embedding(self, text_list):
  60. """
  61. Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
  62. """
  63. if not isinstance(text_list, list):
  64. raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
  65. embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
  66. if not text_list:
  67. return embedding_list
  68. new_texts = []
  69. new_texts_ori_idx = []
  70. with self.lock:
  71. for idx, text in enumerate(text_list):
  72. if text in self.cache:
  73. print(f"find {text} in cache")
  74. embedding_list[idx] = self.cache[text]
  75. else:
  76. new_texts.append(text)
  77. new_texts_ori_idx.append(idx)
  78. new_embeddings = self.model.get_embedding(new_texts)
  79. # Generate embedding if not found in cache
  80. with self.lock: # Ensure thread-safe access
  81. for idx, text in enumerate(new_texts):
  82. if text not in self.cache: # Re-check in case another thread added it
  83. self.cache[text] = new_embeddings[idx]
  84. embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
  85. return embedding_list
  86. def _periodic_save(self):
  87. """Periodically save the cache to disk."""
  88. while True:
  89. sleep(self.save_interval)
  90. self.dump_cache()
  91. # Only for testing
  92. class DummyModel:
  93. def padding_text(self, text):
  94. padding_factor = 1024 // len(text)
  95. text = text * padding_factor
  96. text += text[:1024 - len(text)]
  97. return text
  98. def get_embedding(self, text_list):
  99. embeddings = np.zeros((len(text_list), 1024), np.float32)
  100. for idx, text in enumerate(text_list):
  101. text = self.padding_text(text)
  102. embedding = np.array([ord(c) for c in text], np.float32)
  103. embeddings[idx] = embedding
  104. return embeddings
  105. if __name__ == "__main__":
  106. model = DummyModel()
  107. manager = EmbeddingManager(model)
  108. print(manager.get_embedding(["hello"]))
  109. print(manager.get_embedding(["world"]))
  110. print(manager.get_embedding(["hello world"]))
  111. manager.dump_cache()
  112. print(manager.get_embedding(["new", "word"]))
  113. manager.dump_cache()