# embedding_manager.py
import multiprocessing
import os
import threading
from time import sleep

import numpy as np
  6. class EmbeddingManager:
  7. def __init__(self, model, cache_file="embedding_cache", save_interval=600):
  8. self.model = model
  9. self.cache_file = cache_file
  10. self.cache_key_file = f'{self.cache_file}.keys'
  11. self.save_interval = save_interval
  12. self.cache = multiprocessing.Manager().dict() # Shared dictionary for multiprocess use
  13. self.lock = threading.Lock() # Thread-safe lock
  14. npy_filename = f'{self.cache_file}.npy'
  15. # Load cache from file if it exists
  16. if os.path.exists(npy_filename):
  17. embedding_data = np.load(npy_filename)
  18. embedding_keys = open(self.cache_key_file, "r").readlines()
  19. embedding_keys = [key.strip("\n") for key in embedding_keys]
  20. for idx, key in enumerate(embedding_keys):
  21. self.cache[key] = embedding_data[idx]
  22. print("cache loaded:")
  23. print(self.cache)
  24. # Start the periodic saving thread
  25. self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
  26. self.saving_thread.start()
  27. def get_embedding(self, text_list):
  28. """
  29. Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
  30. """
  31. if not isinstance(text_list, list):
  32. raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
  33. embedding_list = np.zeros((len(text_list), 1024), np.float32)
  34. if not text_list:
  35. return embedding_list
  36. new_texts = []
  37. new_texts_ori_idx = []
  38. for idx, text in enumerate(text_list):
  39. if text in self.cache:
  40. print(f"find {text} in cache")
  41. embedding_list[idx] = self.cache[text]
  42. else:
  43. new_texts.append(text)
  44. new_texts_ori_idx.append(idx)
  45. new_embeddings = self.model.get_embedding(new_texts)
  46. # Generate embedding if not found in cache
  47. with self.lock: # Ensure thread-safe access
  48. for idx, text in enumerate(new_texts):
  49. if text not in self.cache: # Re-check in case another thread added it
  50. self.cache[text] = new_embeddings[idx]
  51. embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
  52. return embedding_list
  53. def _periodic_save(self):
  54. """Periodically save the cache to disk."""
  55. while True:
  56. sleep(self.save_interval)
  57. self.save_now()
  58. def save_now(self):
  59. """Manually trigger a save to disk."""
  60. # TODO: double-buffer
  61. with self.lock: # Ensure thread-safe access
  62. keys = self.cache.keys()
  63. cache_to_save = np.zeros((len(keys), 1024), np.float32)
  64. for idx, key in enumerate(keys):
  65. cache_to_save[idx] = self.cache[key]
  66. np.save(self.cache_file, cache_to_save)
  67. with open(self.cache_key_file, 'w') as fp:
  68. fp.write('\n'.join(keys))
  69. # Only for testing
  70. class DummyModel:
  71. def padding_text(self, text):
  72. padding_factor = 1024 // len(text)
  73. text = text * padding_factor
  74. text += text[:1024 - len(text)]
  75. return text
  76. def get_embedding(self, text_list):
  77. embeddings = np.zeros((len(text_list), 1024), np.float32)
  78. for idx, text in enumerate(text_list):
  79. text = self.padding_text(text)
  80. embedding = np.array([ord(c) for c in text], np.float32)
  81. embeddings[idx] = embedding
  82. return embeddings
  83. if __name__ == "__main__":
  84. model = DummyModel()
  85. manager = EmbeddingManager(model)
  86. print(manager.get_embedding(["hello"]))
  87. print(manager.get_embedding(["world"]))
  88. print(manager.get_embedding(["hello world"]))
  89. manager.save_now()