# -*- coding: utf-8 -*-
- import sys
- import pandas as pd
- import numpy as np
- import faiss
- import time
def gen_i2i(index_item, embeddings, i2i):
    """Write the nearest-neighbour (item-to-item) list of every embedding.

    An IVF-PQ faiss index ('IVF100,PQ16', L2 metric) is trained on
    ``embeddings``, all vectors are added to it, and the index is then
    queried with those same vectors in batches. For every row one line is
    written to ``i2i``: the row's item id, a tab, and the comma-separated
    item ids of its neighbours (the row itself is excluded).

    Args:
        index_item: Mapping from row position in ``embeddings`` to the item
            id that is written to the output.
        embeddings: Sequence of equal-length numeric vectors; converted to a
            2-D ``float32`` array. The vector length must be usable with
            the 'PQ16' product quantizer (a multiple of 16).
        i2i: Path of the output file; it is overwritten.

    Raises:
        ValueError: If ``embeddings`` is empty or not two-dimensional.
    """
    top_k = 20      # hits requested per query, normally including the query itself
    batch = 10000   # number of query rows sent to faiss per search call

    start_time = time.time()
    xb = np.array(embeddings).astype('float32')
    if xb.ndim != 2 or xb.shape[0] == 0:
        raise ValueError("embeddings must be a non-empty 2-D array of vectors")
    num, dim = xb.shape

    # The inverted-file index is untrained after construction: its k-means
    # centroids have to be learned before any vector can be added.
    index = faiss.index_factory(dim, 'IVF100,PQ16', faiss.METRIC_L2)
    index.train(xb)
    print("time:", (time.time() - start_time))

    index.add(xb)
    print("cost time:", (time.time() - start_time))

    with open(i2i, "w") as fw:
        for start in range(0, num, batch):
            _, recall_list = index.search(xb[start:start + batch], top_k)
            for offset, rec_arr in enumerate(recall_list):
                # Position of this query in the full matrix, not in the batch.
                row = start + offset
                # The search is approximate, so the query is not guaranteed
                # to be its own first hit: drop it by id rather than by
                # position. Ids unknown to index_item (including the -1
                # padding faiss returns when it finds fewer than top_k
                # hits) are skipped.
                neighbours = [
                    str(index_item[re_id])
                    for re_id in rec_arr.tolist()
                    if re_id != row and re_id in index_item
                ][:top_k - 1]
                fw.write(str(index_item[row]) + "\t" + ",".join(neighbours) + "\n")
if __name__ == '__main__':
    # Usage: script.py <embedding_file>
    # Each input line is "<int item id> <vector expression>", separated by
    # spaces; the item-to-item result is written to "i2i_result".
    index_dict = {}   # row position in index_arr -> item id
    index_arr = []    # embedding rows, in the same order as index_dict
    with open(sys.argv[1]) as f:
        for line in f:
            items = line.strip().split(" ")
            try:
                vid = int(items[0])
                # SECURITY NOTE(review): eval() executes arbitrary Python
                # taken from the input file. Only run this on trusted data;
                # consider ast.literal_eval or json.loads if the vectors
                # are plain list literals.
                vid_vec = eval(" ".join(items[1:]))
                float_arr = np.array(vid_vec).astype(np.float64).tolist()
            except Exception:
                # Best effort: malformed or blank lines are skipped.
                continue
            # Register the row only after it parsed completely so that
            # positions in index_dict and index_arr always line up.
            index_dict[len(index_arr)] = vid
            index_arr.append(float_arr)
    gen_i2i(index_dict, index_arr, "i2i_result")
|