embedding_manager.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
import json
import logging
import multiprocessing
import os
import threading
from time import sleep

import numpy as np
  6. class EmbeddingManager:
  7. def __init__(self, model, emb_size = 1024, cache_file="embedding_cache", save_interval=600):
  8. self.model = model
  9. self.emb_size = emb_size
  10. self.cache_file = cache_file
  11. self.cache_key_file = f'{self.cache_file}.keys'
  12. self.save_interval = save_interval
  13. self.cache = multiprocessing.Manager().dict() # Shared dictionary for multiprocess use
  14. self.lock = threading.Lock() # Thread-safe lock
  15. npy_filename = f'{self.cache_file}.npy'
  16. # Load cache from file if it exists
  17. if os.path.exists(npy_filename):
  18. embedding_data = np.load(npy_filename)
  19. embedding_keys = open(self.cache_key_file, "r").readlines()
  20. embedding_keys = [key.strip("\n") for key in embedding_keys]
  21. for idx, key in enumerate(embedding_keys):
  22. self.cache[key] = embedding_data[idx]
  23. print("cache loaded:")
  24. print(self.cache)
  25. # Start the periodic saving thread
  26. self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
  27. self.saving_thread.start()
  28. def get_embedding(self, text_list):
  29. """
  30. Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
  31. """
  32. if not isinstance(text_list, list):
  33. raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
  34. embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
  35. if not text_list:
  36. return embedding_list
  37. new_texts = []
  38. new_texts_ori_idx = []
  39. for idx, text in enumerate(text_list):
  40. if text in self.cache:
  41. print(f"find {text} in cache")
  42. embedding_list[idx] = self.cache[text]
  43. else:
  44. new_texts.append(text)
  45. new_texts_ori_idx.append(idx)
  46. new_embeddings = self.model.get_embedding(new_texts)
  47. # Generate embedding if not found in cache
  48. with self.lock: # Ensure thread-safe access
  49. for idx, text in enumerate(new_texts):
  50. if text not in self.cache: # Re-check in case another thread added it
  51. self.cache[text] = new_embeddings[idx]
  52. embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
  53. return embedding_list
  54. def _periodic_save(self):
  55. """Periodically save the cache to disk."""
  56. while True:
  57. sleep(self.save_interval)
  58. self.save_now()
  59. def save_now(self):
  60. """Manually trigger a save to disk."""
  61. tmp_cache_file = self.cache_file + ".tmp"
  62. tmp_cache_key_file = self.cache_key_file + ".tmp"
  63. with self.lock: # Ensure thread-safe access
  64. keys = self.cache.keys()
  65. cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
  66. for idx, key in enumerate(keys):
  67. cache_to_save[idx] = self.cache[key]
  68. np.save(tmp_cache_file, cache_to_save)
  69. with open(tmp_cache_key_file, 'w') as fp:
  70. fp.write('\n'.join(keys))
  71. if os.path.exists(self.cache_file + ".npy"):
  72. os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
  73. if os.path.exists(self.cache_key_file):
  74. os.rename(self.cache_key_file, self.cache_key_file + ".bak")
  75. os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
  76. os.rename(tmp_cache_key_file, self.cache_key_file)
  77. # Only for testing
  78. class DummyModel:
  79. def padding_text(self, text):
  80. padding_factor = 1024 // len(text)
  81. text = text * padding_factor
  82. text += text[:1024 - len(text)]
  83. return text
  84. def get_embedding(self, text_list):
  85. embeddings = np.zeros((len(text_list), 1024), np.float32)
  86. for idx, text in enumerate(text_list):
  87. text = self.padding_text(text)
  88. embedding = np.array([ord(c) for c in text], np.float32)
  89. embeddings[idx] = embedding
  90. return embeddings
  91. if __name__ == "__main__":
  92. model = DummyModel()
  93. manager = EmbeddingManager(model, 1024)
  94. print(manager.get_embedding(["hello"]))
  95. print(manager.get_embedding(["world"]))
  96. print(manager.get_embedding(["hello world"]))
  97. manager.save_now()
  98. print(manager.get_embedding(["new", "word"]))
  99. manager.save_now()