# embedding_manager.py
# encoding:utf-8
import json
import logging
import time

import faiss
import numpy as np
import pandas as pd
  8. class EmbeddingManager(object):
  9. def __init__(self, fpath, key_name, value_name):
  10. begin_time = time.time()
  11. # pandas.dataframe
  12. self.df = pd.read_csv(fpath)
  13. read_time = time.time()
  14. logging.info("read csv embedding file cost time is: " + str(read_time - begin_time))
  15. # 将文件中的embedding加载到内存
  16. self.dict_embedding = self.load_embedding_to_dict(key_name, value_name)
  17. emb_time = time.time()
  18. logging.info("load embedding to dict cost time is: " + str(emb_time - read_time))
  19. # 在faiss建立索引
  20. self.faiss_index = self.load_embedding_to_faiss(key_name, value_name)
  21. logging.info("load embedding to faiss cost time is: " + str(time.time() - emb_time))
  22. def get_embedding(self, key):
  23. if str(key) in self.dict_embedding.keys():
  24. return self.dict_embedding[str(key)]
  25. else:
  26. return ""
  27. def load_embedding_to_dict(self, key_name, value_name):
  28. return {
  29. str(row[key_name]): row[value_name]
  30. for index, row in self.df.iterrows()
  31. }
  32. def load_embedding_to_faiss(self, key_name, value_name):
  33. # id列表
  34. ids = self.df[key_name].values.astype(np.int64)
  35. logging.info("ids is: ")
  36. print(ids)
  37. # 二维embedding
  38. # datas = [json.loads(x[1:-1].strip('\n').split()) for x in self.df[value_name]]
  39. datas = [x[1:-1].strip('\n').split() for x in self.df[value_name]]
  40. datas = np.array(datas).astype(np.float32)
  41. # 维度
  42. dimension = datas.shape[1]
  43. # 创建faiss索引
  44. # index = faiss.IndexFlatL2(dimension)
  45. index = faiss.IndexFlatIP(dimension) # 点乘,归一化的向量点乘即cosine相似度(越大越好)
  46. index2 = faiss.IndexIDMap(index)
  47. index2.add_with_ids(datas, ids)
  48. return index2
  49. def search_ids_by_embedding(self, embedding_str, topk):
  50. """实现近邻搜索"""
  51. begin_time = time.time()
  52. input = np.array(json.loads(embedding_str))
  53. input = np.expand_dims(input, axis=0).astype(np.float32)
  54. D, I = self.faiss_index.search(input, topk)
  55. logging.info("search ids by vid embedding cost time is: " + str(time.time() - begin_time))
  56. return list(I[0])
  57. def search_ids_by_embedding_list(self, embedding_str_list, topk):
  58. """实现近邻搜索"""
  59. begin_time = time.time()
  60. logging.info("embedding_str_list len is: " + str(len(embedding_str_list)))
  61. # input = np.array(json.loads(embedding_str))
  62. input = np.array(embedding_str_list).astype(np.float32)
  63. # input = np.expand_dims(input, axis=0).astype(np.float32)
  64. D, I = self.faiss_index.search(input, topk)
  65. res_list = list()
  66. for arr in I:
  67. res_list.append(list(arr))
  68. logging.info("search ids by vid list embedding cost time is: " + str(time.time() - begin_time))
  69. return res_list