embedding_manager.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # encoding: utf-8
  2. import pandas as pd
  3. import json
  4. import numpy as np
  5. import faiss
  6. import time
  7. import logging
  8. class EmbeddingManager(object):
  9. def __init__(self, fpath, key_name, value_name):
  10. begin_time = time.time()
  11. # pandas.dataframe
  12. self.df = pd.read_csv(fpath)
  13. read_time = time.time()
  14. logging.info("read csv embedding file cost time is: " + str(read_time - begin_time))
  15. # 将文件中的embedding加载到内存
  16. self.dict_embedding = self.load_embedding_to_dict(key_name, value_name)
  17. emb_time = time.time()
  18. logging.info("load embedding to dict cost time is: " + str(emb_time - read_time))
  19. # 在faiss建立索引
  20. self.faiss_index = self.load_embedding_to_faiss(key_name, value_name)
  21. logging.info("load embedding to faiss cost time is: " + str(time.time() - emb_time))
  22. def get_embedding(self, key):
  23. if str(key) in self.dict_embedding.keys():
  24. return self.dict_embedding[str(key)]
  25. else:
  26. return ""
  27. def load_embedding_to_dict(self, key_name, value_name):
  28. # for index, row in self.df.iterrows():
  29. # print("row-- " + key_name + " : " + str(row[key_name]))
  30. # print("row-- " + value_name + " : " + str(row[value_name]))
  31. return {
  32. str(row[key_name]): row[value_name]
  33. for index, row in self.df.iterrows()
  34. }
  35. def load_embedding_to_faiss(self, key_name, value_name):
  36. # id列表
  37. ids = self.df[key_name].values.astype(np.int64)
  38. logging.info("ids is: ")
  39. print(ids)
  40. # 二维embedding
  41. datas = [json.loads(x) for x in self.df[value_name]]
  42. datas = np.array(datas).astype(np.float32)
  43. # 维度
  44. dimension = datas.shape[1]
  45. # 创建faiss索引
  46. # index = faiss.IndexFlatL2(dimension)
  47. index = faiss.IndexFlatIP(dimension) # 点乘,归一化的向量点乘即cosine相似度(越大越好)
  48. index2 = faiss.IndexIDMap(index)
  49. index2.add_with_ids(datas, ids)
  50. return index2
  51. def search_ids_by_embedding(self, embedding_str, topk):
  52. """实现近邻搜索"""
  53. begin_time = time.time()
  54. input = np.array(json.loads(embedding_str))
  55. input = np.expand_dims(input, axis=0).astype(np.float32)
  56. D, I = self.faiss_index.search(input, topk)
  57. logging.info("search ids by vid embedding cost time is: " + str(time.time() - begin_time))
  58. return list(I[0])
  59. def search_ids_by_embedding_list(self, embedding_str_list, topk):
  60. """实现近邻搜索"""
  61. begin_time = time.time()
  62. logging.info("embedding_str_list len is: " + str(len(embedding_str_list)))
  63. # print(embedding_str_list)
  64. # input = np.array(json.loads(embedding_str))
  65. input = np.array(embedding_str_list).astype(np.float32)
  66. # input = np.expand_dims(input, axis=0).astype(np.float32)
  67. D, I = self.faiss_index.search(input, topk)
  68. res_list = list()
  69. for arr in I:
  70. res_list.append(list(arr))
  71. logging.info("search ids by vid list embedding cost time is: " + str(time.time() - begin_time))
  72. return res_list