batch_test.py

'''
Created on Oct 10, 2018
Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:
Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
'''
from utility.parser import parse_args
from utility.load_data import *
from evaluator import eval_score_matrix_foldout
import multiprocessing
import heapq
import numpy as np

cores = multiprocessing.cpu_count() // 2

args = parse_args()

data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
BATCH_SIZE = args.batch_size
def test(sess, model, users_to_test, drop_flag=False, train_set_flag=0):
    # B: batch size
    # N: the number of items
    top_show = np.sort(model.Ks)
    max_top = max(top_show)
    result = {'precision': np.zeros(len(model.Ks)),
              'recall': np.zeros(len(model.Ks)),
              'ndcg': np.zeros(len(model.Ks))}

    u_batch_size = BATCH_SIZE
    test_users = users_to_test
    n_test_users = len(test_users)
    n_user_batchs = n_test_users // u_batch_size + 1
    count = 0
    all_result = []
    item_batch = range(ITEM_NUM)

    for u_batch_id in range(n_user_batchs):
        start = u_batch_id * u_batch_size
        end = (u_batch_id + 1) * u_batch_size
        user_batch = test_users[start: end]

        # Score every item for the users in this batch.
        if drop_flag == False:
            rate_batch = sess.run(model.batch_ratings, {model.users: user_batch,
                                                        model.pos_items: item_batch})
        else:
            # Dropout is disabled (rate 0.) at evaluation time.
            rate_batch = sess.run(model.batch_ratings, {model.users: user_batch,
                                                        model.pos_items: item_batch,
                                                        model.node_dropout: [0.] * len(eval(args.layer_size)),
                                                        model.mess_dropout: [0.] * len(eval(args.layer_size))})
        rate_batch = np.array(rate_batch)  # (B, N)

        test_items = []
        if train_set_flag == 0:
            for user in user_batch:
                test_items.append(data_generator.test_set[user])  # (B, #test_items)

            # set the ranking scores of training items to -inf,
            # then the training items will be sorted at the end of the ranking list.
            for idx, user in enumerate(user_batch):
                train_items_off = data_generator.train_items[user]
                rate_batch[idx][train_items_off] = -np.inf
        else:
            # Evaluate on the training interactions instead (e.g. to monitor fitting).
            for user in user_batch:
                test_items.append(data_generator.train_items[user])

        batch_result = eval_score_matrix_foldout(rate_batch, test_items, max_top)  # (B, k*metric_num), max_top = 20
        count += len(batch_result)
        all_result.append(batch_result)

    assert count == n_test_users
    all_result = np.concatenate(all_result, axis=0)
    final_result = np.mean(all_result, axis=0)  # average over all test users
    # The evaluator returns 5 metrics per cutoff; keep only the requested Ks.
    final_result = np.reshape(final_result, newshape=[5, max_top])
    final_result = final_result[:, top_show - 1]
    final_result = np.reshape(final_result, newshape=[5, len(top_show)])
    result['precision'] += final_result[0]
    result['recall'] += final_result[1]
    result['ndcg'] += final_result[3]
    return result
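

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the training script in this repo
# drives test() itself). `sess` and `model` are assumed to be an open
# tf.Session and a trained model exposing `batch_ratings`, `users`,
# `pos_items` and `Ks`, as the function above requires:
#
#   users_to_test = list(data_generator.test_set.keys())
#   ret = test(sess, model, users_to_test, drop_flag=True)
#   print('recall@Ks:', ret['recall'], 'ndcg@Ks:', ret['ndcg'])
# --------------------------------------------------------------------------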