# get_batch_sim_k.py (2.8 KB)
# -*- coding: utf-8 -*-
  2. import sys
  3. import pandas as pd
  4. import numpy as np
  5. import faiss
  6. import time
  7. def gen_i2i(index_item, embeddings,i2i):
  8. fw=open(i2i,"w")
  9. #print(i2i)
  10. start_time = time.time()
  11. #xb = embeddings
  12. xb=np.array(embeddings).astype('float32')
  13. #print(xb)
  14. #index.add(xb)
  15. dim, measure = 64, faiss.METRIC_L2
  16. param = 'IVF100,PQ16'
  17. index = faiss.index_factory(dim, param, measure)
  18. #print(index.is_trained) # 此时输出为False,因为倒排索引需要训练k-means,
  19. index.train(xb)
  20. end_time = time.time()
  21. print("time:", (end_time-start_time))
  22. #index=faiss.IndexFlatL2(100)
  23. #index.add(embed_matrix)
  24. #the candicate matrix is embed_matrix,but the search matrix is the same.
  25. #if the search vector is in the candicate matrix, the return idx>> the first is the search vector itself
  26. #if the search vector is not in the candicate matrix, the return idx>>the first is the index of the candicate
  27. batch = 10000
  28. num = len(embeddings)
  29. per_rounds = int(num/batch)+1
  30. #index=faiss.IndexFlatL2(64)
  31. index.add(xb)
  32. print("cost time:", (end_time-start_time))
  33. #distence_matrix,recall_list=index.search(xb, 20)
  34. #print(distence_matrix)
  35. #print(recall_list)
  36. for i in range(per_rounds):
  37. per_embedding = xb[i:(i+1)*batch]
  38. #print(per_embedding)
  39. #print(len(per_embedding))
  40. distence_matrix,recall_list=index.search(per_embedding, 20)
  41. #print("distence_matrix:", distence_matrix)
  42. #print("recall_list:", recall_list)
  43. for idx,rec_arr in enumerate(recall_list):
  44. #print("idx:", idx)
  45. orgin_item=str(index_item[idx])
  46. #print("orgin_item:", orgin_item)
  47. #print("rec_arr:", rec_arr)
  48. recall_str=""
  49. for re_id in rec_arr[1:]:
  50. if re_id in index_item:
  51. recall_idstr=str(index_item[re_id])
  52. recall_str=recall_str+","+recall_idstr
  53. fw.write(orgin_item+"\t"+recall_str[1:]+"\n")
  54. if __name__ == '__main__':
  55. f = open(sys.argv[1])
  56. index = 0
  57. index_dict = {}
  58. index_arr = []
  59. while True:
  60. line = f.readline()
  61. if not line:
  62. break
  63. items = line.strip().split(" ")
  64. try:
  65. vid = int(items[0])
  66. vid_vec = eval(" ".join(items[1:]))
  67. vid_vec=np.array(vid_vec)
  68. float_arr = vid_vec.astype(np.float64).tolist()
  69. #print(float_arr)
  70. index_arr.append(float_arr)
  71. #index +=1
  72. index_dict[index] = vid
  73. index +=1
  74. #break
  75. #print(index_arr)
  76. except:
  77. #break
  78. continue
  79. f.close()
  80. #print(index_arr)
  81. gen_i2i(index_dict, index_arr, "i2i_result")