# coding: utf-8
"""Build an item-to-item (i2i) nearest-neighbour table from item embeddings.

Usage:
    python <this_script> <embedding_file>

Every line of the embedding file is expected to hold an integer item id,
a space, and a Python list literal with the item's vector, e.g.::

    1234 [0.12, -0.5, 0.33]

The result is written to ``i2i_result`` with one line per item:
``<item_id>\\t<neighbour_id>,<neighbour_id>,...``.
"""
import ast
import sys
import time

import faiss
import numpy as np
import pandas as pd


def _format_row(index_item, query_idx, rec_arr, limit):
    """Return one output line for the query at row ``query_idx``.

    ``rec_arr`` holds the row positions returned by the index search.  The
    query itself and positions unknown to ``index_item`` (FAISS pads missing
    neighbours with -1) are dropped; at most ``limit`` neighbours are kept.
    """
    neighbours = []
    for re_id in rec_arr:
        re_id = int(re_id)
        # Exclude the query by id rather than by rank: with an approximate
        # index (or duplicate vectors) the query is not guaranteed to come
        # back in first place.
        if re_id == query_idx or re_id not in index_item:
            continue
        neighbours.append(str(index_item[re_id]))
    return str(index_item[query_idx]) + "\t" + ",".join(neighbours[:limit]) + "\n"


def gen_i2i(index_item, embeddings, i2i, top_k=20, batch=10000,
            index_spec='IVF100,PQ16'):
    """Train a FAISS index on ``embeddings`` and write each item's neighbours.

    Args:
        index_item: mapping from row position in ``embeddings`` to item id.
        embeddings: sequence of equal-length numeric vectors, one per item.
        i2i: path of the output file (overwritten).
        top_k: number of results requested per query.  The query itself is
            removed from its own result, so at most ``top_k - 1`` neighbours
            are written per item.
        batch: number of query vectors sent to the index per search call.
        index_spec: FAISS ``index_factory`` description string.

    Raises:
        ValueError: if ``embeddings`` is empty or not a 2-D matrix, or if
            ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be a positive integer")

    start_time = time.time()
    # FAISS expects a C-contiguous float32 matrix.
    xb = np.ascontiguousarray(embeddings, dtype='float32')
    if xb.ndim != 2 or xb.shape[0] == 0:
        raise ValueError(
            "embeddings must be a non-empty 2-D matrix, got shape %s" % (xb.shape,))

    # Take the dimensionality from the data instead of hard-coding 64.
    num, dim = xb.shape
    index = faiss.index_factory(dim, index_spec, faiss.METRIC_L2)
    # An inverted-file index has to be trained (k-means) before vectors are added.
    index.train(xb)
    print("time:", (time.time() - start_time))

    index.add(xb)
    print("cost time:", (time.time() - start_time))

    limit = max(top_k - 1, 0)
    with open(i2i, "w") as fw:
        # The candidate matrix is also the query matrix; search it in
        # consecutive, non-overlapping batches.
        for start in range(0, num, batch):
            _, recall_list = index.search(xb[start:start + batch], top_k)
            for offset, rec_arr in enumerate(recall_list):
                # Rows of recall_list are batch-local; convert to a global row.
                fw.write(_format_row(index_item, start + offset, rec_arr, limit))


def _load_embeddings(path):
    """Parse the embedding file at ``path``.

    Returns:
        ``(index_dict, index_arr, skipped)`` where ``index_arr`` is the list
        of vectors (lists of floats), ``index_dict`` maps each vector's
        position in ``index_arr`` to its item id, and ``skipped`` is the
        number of lines that could not be parsed.
    """
    index_dict = {}
    index_arr = []
    skipped = 0
    with open(path) as f:
        for line in f:
            items = line.strip().split(" ")
            try:
                vid = int(items[0])
                # NOTE: the original code used eval() here; literal_eval only
                # accepts plain Python literals, so file content can no longer
                # execute arbitrary code.
                vid_vec = np.array(ast.literal_eval(" ".join(items[1:])))
                if vid_vec.ndim != 1 or vid_vec.size == 0:
                    raise ValueError("vector must be a non-empty flat list")
                float_arr = vid_vec.astype(np.float64).tolist()
            except (ValueError, SyntaxError, TypeError):
                # Malformed line: skip it, but keep count so it is not silent.
                skipped += 1
                continue
            index_dict[len(index_arr)] = vid
            index_arr.append(float_arr)
    return index_dict, index_arr, skipped


def main(argv=None):
    """Command-line entry point: read embeddings and write ``i2i_result``."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: %s <embedding_file>" % argv[0])

    index_dict, index_arr, skipped = _load_embeddings(argv[1])
    if skipped:
        sys.stderr.write("skipped %d unparsable line(s)\n" % skipped)
    gen_i2i(index_dict, index_arr, "i2i_result")


if __name__ == '__main__':
    main()