get_sim_k.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
# coding: utf-8
  2. import sys
  3. import pandas as pd
  4. import numpy as np
  5. import faiss
  6. import time
  def gen_i2i(index_item, embeddings, i2i, top_k=20):
      """Write the nearest neighbours of every item to a tab-separated file.

      A flat (exact) L2 faiss index is built over ``embeddings`` and then
      queried with those same vectors.  Each output line holds the item id,
      a TAB, and the comma-separated ids of its neighbours, nearest first.

      Args:
          index_item: Mapping (or sequence) from a row position in
              ``embeddings`` to the item id that is written to the file.
          embeddings: Equal-length numeric vectors, one per item.
          i2i: Path of the output file; it is created or overwritten.
          top_k: Number of hits requested per query.  The query item itself
              is normally one of them, so at most ``top_k - 1`` neighbours
              are written per line.

      Raises:
          ValueError: If ``embeddings`` is non-empty but not two-dimensional.
      """
      embed_matrix = np.array(embeddings).astype('float32')
      if embed_matrix.size and embed_matrix.ndim != 2:
          raise ValueError("embeddings must be a 2-D collection of vectors")

      # The context manager guarantees the rows are flushed and the handle is
      # released even if faiss or the id lookup raises.
      with open(i2i, "w") as fw:
          if embed_matrix.size == 0:
              # Nothing to index; leave an empty output file behind.
              return

          # Take the dimensionality from the data instead of assuming 100.
          index = faiss.IndexFlatL2(embed_matrix.shape[1])
          index.add(embed_matrix)

          # The candidate set and the query set are the same matrix, so each
          # query usually gets itself back as its closest hit (distance 0).
          _, recall_list = index.search(embed_matrix, top_k)

          for idx, rec_arr in enumerate(recall_list):
              # Drop the query item by id rather than by position (duplicate
              # vectors can push it out of first place) and drop the -1
              # padding faiss returns when fewer than top_k vectors exist.
              neighbours = [
                  str(index_item[int(re_id)])
                  for re_id in rec_arr
                  if re_id != -1 and re_id != idx
              ][:max(top_k - 1, 0)]
              fw.write(str(index_item[idx]) + "\t" + ",".join(neighbours) + "\n")
  def load_embeddings(path):
      """Parse an embedding dump into a position->id map and a vector list.

      Each useful line looks like ``<int id> <comma separated numbers>``,
      optionally wrapped in square brackets, which are removed before
      parsing.  Lines with fewer than two space-separated fields are ignored;
      lines whose id or vector cannot be parsed are skipped and counted.

      Args:
          path: Path of the text file to read.

      Returns:
          A ``(index_dict, index_arr)`` tuple where ``index_arr[i]`` is the
          parsed vector of the i-th accepted line and ``index_dict[i]`` is
          its integer id.
      """
      index_dict = {}
      index_arr = []
      skipped = 0
      with open(path) as f:
          for line in f:
              line = line.strip().replace("[", "").replace("]", "")
              items = line.split(" ")
              if len(items) < 2:
                  continue
              try:
                  vid = int(items[0])
                  # Progress/debug echo kept from the original script.
                  print(items)
                  # NOTE(security): eval executes arbitrary expressions found
                  # in the file; only run this on trusted dumps.  Consider
                  # ast.literal_eval or explicit float parsing instead.
                  vid_vec = eval(" ".join(items[1:]))
              except Exception:
                  # Best-effort parsing: a malformed line is skipped, but
                  # KeyboardInterrupt/SystemExit are no longer swallowed.
                  skipped += 1
                  continue
              index_dict[len(index_arr)] = vid
              index_arr.append(vid_vec)
      if skipped:
          sys.stderr.write("skipped %d malformed line(s)\n" % skipped)
      return index_dict, index_arr


  if __name__ == '__main__':
      if len(sys.argv) < 2:
          sys.exit("usage: %s <embedding_file>" % sys.argv[0])
      start_time = time.time()
      index_dict, index_arr = load_embeddings(sys.argv[1])
      print(len(index_arr))
      end_time = time.time()
      print("time:", (end_time - start_time))
      # Neighbour export is disabled in this script; uncomment to write it.
      #gen_i2i(index_dict, index_arr, "i2i_result")