get_sim_k.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # coding: utf-8
  2. import sys
  3. import pandas as pd
  4. import numpy as np
  5. import faiss
  6. def gen_i2i(index_item, embeddings,i2i):
  7. fw=open(i2i,"w")
  8. #print(i2i)
  9. embed_matrix=np.array(embeddings).astype('float32')
  10. #print(embed_matrix)
  11. index=faiss.IndexFlatL2(100)
  12. index.add(embed_matrix)
  13. #the candicate matrix is embed_matrix,but the search matrix is the same.
  14. #if the search vector is in the candicate matrix, the return idx>> the first is the search vector itself
  15. #if the search vector is not in the candicate matrix, the return idx>>the first is the index of the candicate
  16. distence_matrix,recall_list=index.search(embed_matrix, 20)
  17. for idx,rec_arr in enumerate(recall_list):
  18. #print("idx:", idx)
  19. orgin_item=str(index_item[idx])
  20. recall_str=""
  21. #rec_arr=[0 6 3 8 7 1]
  22. for re_id in rec_arr[1:]:
  23. recall_idstr=str(index_item[re_id])
  24. #print(recall_idstr)
  25. recall_str=recall_str+","+recall_idstr
  26. fw.write(orgin_item+"\t"+recall_str[1:]+"\n")
  27. if __name__ == '__main__':
  28. f = open(sys.argv[1])
  29. index = 0
  30. index_dict = {}
  31. index_arr = []
  32. while True:
  33. line = f.readline()
  34. if not line:
  35. break
  36. line = line.strip()
  37. #print(eval(line))
  38. items = line.split(" ")
  39. try:
  40. vid = int(items[0])
  41. vid_vec = eval(" ".join(items[1:]))
  42. index_arr.append(vid_vec)
  43. #index +=1
  44. index_dict[index] = vid
  45. index +=1
  46. #print(index_arr)
  47. except:
  48. continue
  49. f.close()
  50. print(len(index_arr))
  51. gen_i2i(index_dict, index_arr, "i2i_result")