|
@@ -6,8 +6,9 @@ import numpy as np
|
|
|
|
|
|
|
|
|
|
class EmbeddingManager:
|
|
class EmbeddingManager:
|
|
- def __init__(self, model, cache_file="embedding_cache", save_interval=600):
|
|
|
|
|
|
+ def __init__(self, model, emb_size = 1024, cache_file="embedding_cache", save_interval=600):
|
|
self.model = model
|
|
self.model = model
|
|
|
|
+ self.emb_size = emb_size
|
|
self.cache_file = cache_file
|
|
self.cache_file = cache_file
|
|
self.cache_key_file = f'{self.cache_file}.keys'
|
|
self.cache_key_file = f'{self.cache_file}.keys'
|
|
self.save_interval = save_interval
|
|
self.save_interval = save_interval
|
|
@@ -35,7 +36,7 @@ class EmbeddingManager:
|
|
"""
|
|
"""
|
|
if not isinstance(text_list, list):
|
|
if not isinstance(text_list, list):
|
|
raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
|
|
raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
|
|
- embedding_list = np.zeros((len(text_list), 1024), np.float32)
|
|
|
|
|
|
+ embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
|
|
if not text_list:
|
|
if not text_list:
|
|
return embedding_list
|
|
return embedding_list
|
|
new_texts = []
|
|
new_texts = []
|
|
@@ -70,14 +71,16 @@ class EmbeddingManager:
|
|
tmp_cache_key_file = self.cache_key_file + ".tmp"
|
|
tmp_cache_key_file = self.cache_key_file + ".tmp"
|
|
with self.lock: # Ensure thread-safe access
|
|
with self.lock: # Ensure thread-safe access
|
|
keys = self.cache.keys()
|
|
keys = self.cache.keys()
|
|
- cache_to_save = np.zeros((len(keys), 1024), np.float32)
|
|
|
|
|
|
+ cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
|
|
for idx, key in enumerate(keys):
|
|
for idx, key in enumerate(keys):
|
|
cache_to_save[idx] = self.cache[key]
|
|
cache_to_save[idx] = self.cache[key]
|
|
np.save(tmp_cache_file, cache_to_save)
|
|
np.save(tmp_cache_file, cache_to_save)
|
|
with open(tmp_cache_key_file, 'w') as fp:
|
|
with open(tmp_cache_key_file, 'w') as fp:
|
|
fp.write('\n'.join(keys))
|
|
fp.write('\n'.join(keys))
|
|
- os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
|
|
|
|
- os.rename(self.cache_key_file, self.cache_key_file + ".bak")
|
|
|
|
|
|
+ if os.path.exists(self.cache_file + ".npy"):
|
|
|
|
+ os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
|
|
|
|
+ if os.path.exists(self.cache_key_file):
|
|
|
|
+ os.rename(self.cache_key_file, self.cache_key_file + ".bak")
|
|
os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
|
|
os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
|
|
os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
|
|
|
|
@@ -100,7 +103,7 @@ class DummyModel:
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
model = DummyModel()
|
|
model = DummyModel()
|
|
- manager = EmbeddingManager(model)
|
|
|
|
|
|
+ manager = EmbeddingManager(model, 1024)
|
|
print(manager.get_embedding(["hello"]))
|
|
print(manager.get_embedding(["hello"]))
|
|
print(manager.get_embedding(["world"]))
|
|
print(manager.get_embedding(["world"]))
|
|
print(manager.get_embedding(["hello world"]))
|
|
print(manager.get_embedding(["hello world"]))
|