# encoding:utf-8
import json
import logging
import time

import faiss
import numpy as np
import pandas as pd


class EmbeddingManager(object):
    """Load embeddings from a CSV file into a dict and a faiss inner-product index.

    The CSV is expected to contain a key column (convertible to int64; used as the
    faiss id) and a value column holding each embedding as a bracketed,
    whitespace-separated string such as ``"[0.1 0.2 0.3]"``.
    """

    def __init__(self, fpath, key_name, value_name):
        """Read ``fpath`` and build both the dict lookup and the faiss index.

        Args:
            fpath: path (or anything ``pd.read_csv`` accepts) of the embedding CSV.
            key_name: name of the id column.
            value_name: name of the embedding-string column.
        """
        begin_time = time.time()
        self.df = pd.read_csv(fpath)
        read_time = time.time()
        logging.info("read csv embedding file cost time is: %s", read_time - begin_time)

        # Keep every raw embedding string in memory for key -> embedding lookups.
        self.dict_embedding = self.load_embedding_to_dict(key_name, value_name)
        emb_time = time.time()
        logging.info("load embedding to dict cost time is: %s", emb_time - read_time)

        # Build the faiss index used for nearest-neighbour search.
        self.faiss_index = self.load_embedding_to_faiss(key_name, value_name)
        logging.info("load embedding to faiss cost time is: %s", time.time() - emb_time)

    def get_embedding(self, key):
        """Return the raw embedding string stored for ``key``, or ``""`` if absent.

        ``key`` is normalised with ``str()`` so int and str ids hit the same entry.
        """
        return self.dict_embedding.get(str(key), "")

    def load_embedding_to_dict(self, key_name, value_name):
        """Map ``str(key)`` to the raw value-column entry for every row.

        The two columns are zipped directly instead of using ``DataFrame.iterrows``.
        ``iterrows`` builds a Series per row, which is slow, and upcasts the row to
        a common dtype, which would turn an int key into a float (``"1.0"``)
        whenever the value column is numeric. Later rows win on duplicate keys.
        """
        return {
            str(key): value
            for key, value in zip(self.df[key_name], self.df[value_name])
        }

    def load_embedding_to_faiss(self, key_name, value_name):
        """Build a faiss ``IndexIDMap(IndexFlatIP)`` over the embedding column.

        Returns:
            The index, with each row's vector added under its int64 id from
            ``key_name``.

        Raises:
            ValueError: if no rows / no 2-D embedding matrix could be built.
        """
        # faiss requires int64 ids.
        ids = self.df[key_name].values.astype(np.int64)
        # Previously print()-ed to stdout right after an empty log line; log instead.
        logging.info("ids is: %s", ids)

        # Each value looks like "[0.1 0.2 ...]": drop the brackets, split on whitespace.
        vectors = np.array(
            [x[1:-1].strip("\n").split() for x in self.df[value_name]]
        ).astype(np.float32)
        if vectors.ndim != 2 or vectors.shape[0] == 0:
            raise ValueError(
                "no embeddings could be built from column %r (shape %s)"
                % (value_name, vectors.shape)
            )

        dimension = vectors.shape[1]
        # Inner product; on L2-normalised vectors this equals cosine similarity
        # (larger is closer).
        # NOTE(review): vectors are not normalised here; confirm the CSV stores
        # unit vectors.
        flat_index = faiss.IndexFlatIP(dimension)
        id_index = faiss.IndexIDMap(flat_index)
        id_index.add_with_ids(vectors, ids)
        return id_index

    def search_ids_by_embedding(self, embedding_str, topk):
        """Return the ids of the ``topk`` nearest neighbours of one embedding.

        Args:
            embedding_str: JSON array of numbers, e.g. ``"[0.1, 0.2]"``.
            topk: number of neighbours to return.

        Returns:
            List of ``topk`` int64 ids in the order faiss returns them (best
            first). Per the faiss docs, missing results are filled with -1.
        """
        begin_time = time.time()
        # NOTE(review): get_embedding() appears to return the CSV's
        # whitespace-separated form ("[0.1 0.2]"), which json.loads rejects;
        # confirm callers pass JSON here.
        query = np.expand_dims(np.array(json.loads(embedding_str)), axis=0).astype(np.float32)
        _, found_ids = self.faiss_index.search(query, topk)
        logging.info("search ids by vid embedding cost time is: %s", time.time() - begin_time)
        return list(found_ids[0])

    def search_ids_by_embedding_list(self, embedding_str_list, topk):
        """Batch version of ``search_ids_by_embedding``.

        Args:
            embedding_str_list: sequence of equal-length numeric sequences, one per
                query. Despite the name, these are not JSON strings; they are passed
                straight to ``np.array``.
            topk: number of neighbours to return per query.

        Returns:
            One list of ``topk`` ids per query, in input order; ``[]`` for an empty
            batch.
        """
        begin_time = time.time()
        logging.info("embedding_str_list len is: %s", len(embedding_str_list))
        if len(embedding_str_list) == 0:
            # An empty batch would reach faiss as a 1-D array and fail there.
            return []
        queries = np.array(embedding_str_list).astype(np.float32)
        _, found_ids = self.faiss_index.search(queries, topk)
        res_list = [list(row) for row in found_ids]
        logging.info("search ids by vid list embedding cost time is: %s", time.time() - begin_time)
        return res_list