123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # encoding:utf-8
- import pandas as pd
- import json
- import numpy as np
- import faiss
- import time
- import logging
class EmbeddingManager(object):
    """Load embeddings from a CSV file and serve lookups and neighbour search.

    On construction the CSV is read into ``self.df``, a ``str(key) -> raw
    embedding cell`` dictionary is built in ``self.dict_embedding``, and an
    inner-product faiss index keyed by the integer ids is built in
    ``self.faiss_index``.
    """

    def __init__(self, fpath, key_name, value_name):
        """Read the CSV at *fpath* and build the dictionary and faiss index.

        Args:
            fpath: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
            key_name: Name of the column holding the ids.
            value_name: Name of the column holding the embedding text.
        """
        begin_time = time.time()
        # Keep the whole table around as a pandas DataFrame.
        self.df = pd.read_csv(fpath)
        read_time = time.time()
        logging.info(
            "read csv embedding file cost time is: %s", read_time - begin_time
        )

        # Load the embeddings from the file into an in-memory dict.
        self.dict_embedding = self.load_embedding_to_dict(key_name, value_name)
        emb_time = time.time()
        logging.info(
            "load embedding to dict cost time is: %s", emb_time - read_time
        )

        # Build the faiss index.
        self.faiss_index = self.load_embedding_to_faiss(key_name, value_name)
        logging.info(
            "load embedding to faiss cost time is: %s", time.time() - emb_time
        )

    def get_embedding(self, key):
        """Return the stored embedding cell for *key*, or ``""`` if unknown.

        The key is converted with ``str()`` before the lookup, so ``1`` and
        ``"1"`` address the same entry.
        """
        return self.dict_embedding.get(str(key), "")

    def load_embedding_to_dict(self, key_name, value_name):
        """Build a ``{str(key): value}`` mapping from two columns of ``self.df``.

        Values are stored exactly as they appear in the DataFrame (no
        parsing). When a key occurs more than once the last row wins.
        """
        # Iterating the two columns directly avoids the per-row Series that
        # DataFrame.iterrows() would construct.
        return {
            str(key): value
            for key, value in zip(self.df[key_name], self.df[value_name])
        }

    def load_embedding_to_faiss(self, key_name, value_name):
        """Build an id-mapped inner-product faiss index from ``self.df``.

        Each cell of *value_name* is parsed by dropping its first and last
        character and splitting the rest on whitespace (presumably a
        bracketed, space-separated vector such as ``"[0.1 0.2]"`` -- confirm
        against the CSV producer). The *key_name* column is cast to int64 and
        used as the vector ids.

        Returns:
            A ``faiss.IndexIDMap`` wrapping a ``faiss.IndexFlatIP``.
        """
        # List of ids.
        ids = self.df[key_name].values.astype(np.int64)
        logging.info("ids is: %s", ids)

        # 2-D embedding matrix. str.split() with no argument already splits
        # on (and discards) newlines, so no separate newline strip is needed.
        datas = [x[1:-1].split() for x in self.df[value_name]]
        datas = np.array(datas).astype(np.float32)

        # Embedding dimension.
        dimension = datas.shape[1]

        # Inner-product index: for normalised vectors the dot product equals
        # cosine similarity (larger is better). NOTE: the vectors are not
        # normalised here, so that only holds if the CSV already stores
        # unit-length embeddings.
        flat_index = faiss.IndexFlatIP(dimension)
        id_index = faiss.IndexIDMap(flat_index)
        id_index.add_with_ids(datas, ids)
        return id_index

    def search_ids_by_embedding(self, embedding_str, topk):
        """Nearest-neighbour search for a single JSON-encoded embedding.

        Args:
            embedding_str: JSON text of one vector, e.g. ``"[0.1, 0.2]"``.
            topk: Number of neighbours requested from the index.

        Returns:
            A list with the ids faiss reports for the query, in the order
            faiss returns them.
        """
        begin_time = time.time()
        query = np.array(json.loads(embedding_str))
        # faiss searches a batch of float32 row vectors; add the batch axis.
        query = np.expand_dims(query, axis=0).astype(np.float32)
        _, neighbour_ids = self.faiss_index.search(query, topk)
        logging.info(
            "search ids by vid embedding cost time is: %s",
            time.time() - begin_time,
        )
        return list(neighbour_ids[0])

    def search_ids_by_embedding_list(self, embedding_str_list, topk):
        """Nearest-neighbour search for a batch of embeddings.

        Args:
            embedding_str_list: Sequence of vectors (each convertible to
                float32), one per query.
            topk: Number of neighbours requested from the index per query.

        Returns:
            One list of ids per query, in input order. An empty batch yields
            an empty list without touching the index.
        """
        begin_time = time.time()
        logging.info("embedding_str_list len is: %s", len(embedding_str_list))

        # An empty batch would become a 1-D array, which is not a valid
        # (n, d) query matrix; there is nothing to search for anyway.
        if len(embedding_str_list) == 0:
            return []

        queries = np.array(embedding_str_list).astype(np.float32)
        _, neighbour_ids = self.faiss_index.search(queries, topk)
        res_list = [list(row) for row in neighbour_ids]
        logging.info(
            "search ids by vid list embedding cost time is: %s",
            time.time() - begin_time,
        )
        return res_list
|