# load_data.py

'''
Created on Oct 10, 2018
Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:
Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
'''
import numpy as np
import random as rd
import scipy.sparse as sp
from time import time


class Data(object):
    def __init__(self, path, batch_size):
        self.path = path
        self.batch_size = batch_size

        train_file = path + '/train.txt'
        test_file = path + '/test.txt'

        self.n_users, self.n_items = 0, 0
        self.n_train, self.n_test = 0, 0
        self.neg_pools = {}
        self.exist_users = []

        # First pass over the training file: record user ids, count interactions,
        # and track the largest user/item id.
        with open(train_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip('\n').split(' ')
                    # The first and last item ids on a line may be wrapped in double
                    # quotes, so strip quotes from every token before casting to int.
                    items = [int(i.strip('"')) for i in l[1:]]
                    uid = int(l[0])
                    self.exist_users.append(uid)
                    self.n_items = max(self.n_items, max(items))
                    self.n_users = max(self.n_users, uid)
                    self.n_train += len(items)

        # First pass over the test file: count test interactions and track the
        # largest item id; malformed lines are skipped.
        with open(test_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip('\n').split(' ')
                    try:
                        items = [int(i.strip('"')) for i in l[1:]]
                    except Exception:
                        continue
                    self.n_items = max(self.n_items, max(items))
                    self.n_test += len(items)

        # Ids are zero-based, so counts are one larger than the largest observed id.
        self.n_items += 1
        self.n_users += 1

        self.print_statistics()

        # The interaction matrix R is built square (count_size x count_size) so it
        # can be indexed by either user or item ids.
        count_size = max(self.n_items, self.n_users)
        self.R = sp.dok_matrix((count_size, count_size), dtype=np.float32)

        self.train_items, self.test_set = {}, {}
        # Second pass: fill R and record each user's train/test item lists.
        with open(train_file) as f_train:
            with open(test_file) as f_test:
                for l in f_train.readlines():
                    if len(l) == 0:
                        break
                    l = l.strip('\n')
                    items = [int(i.strip('"')) for i in l.split(' ')]
                    uid, train_items = items[0], items[1:]

                    for i in train_items:
                        self.R[uid, i] = 1.
                    self.train_items[uid] = train_items

                for l in f_test.readlines():
                    if len(l) == 0:
                        break
                    l = l.strip('\n')
                    try:
                        items = [int(i.strip('"')) for i in l.split(' ')]
                    except Exception:
                        continue
                    uid, test_items = items[0], items[1:]
                    self.test_set[uid] = test_items
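
    # Expected on-disk format for train.txt / test.txt, inferred from the parsing in
    # __init__ above (the quoting of the first/last item id is an assumption based on
    # the strip('"') calls): each line is a space-separated record
    #   <user_id> <item_id> <item_id> ...
    # with zero-based integer ids.
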
    def get_adj_mat(self):
        # Load cached adjacency matrices if present; otherwise build and cache them.
        try:
            t1 = time()
            adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
            norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
            mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
            print('already load adj matrix', adj_mat.shape, time() - t1)
        except Exception:
            adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
            sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)
            sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)
            sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)

        try:
            pre_adj_mat = sp.load_npz(self.path + '/s_pre_adj_mat.npz')
        except Exception:
            # Symmetrically normalized adjacency: D^-1/2 * A * D^-1/2.
            rowsum = np.array(adj_mat.sum(1))
            d_inv = np.power(rowsum, -0.5).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)

            norm_adj = d_mat_inv.dot(adj_mat)
            norm_adj = norm_adj.dot(d_mat_inv)
            print('generate pre adjacency matrix.')
            pre_adj_mat = norm_adj.tocsr()
            sp.save_npz(self.path + '/s_pre_adj_mat.npz', norm_adj)

        return adj_mat, norm_adj_mat, mean_adj_mat, pre_adj_mat

    def create_adj_mat(self):
        t1 = time()
        # Full (n_users + n_items) x (n_users + n_items) adjacency matrix holding the
        # user-item interaction matrix R in its off-diagonal blocks.
        adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.R.tolil()

        # Copy R block by block to prevent memory from overflowing. Only the first
        # n_items columns of the (square) R hold user-item interactions.
        for i in range(5):
            adj_mat[int(self.n_users * i / 5.0):int(self.n_users * (i + 1.0) / 5), self.n_users:] = \
                R[int(self.n_users * i / 5.0):int(self.n_users * (i + 1.0) / 5), :self.n_items]
            adj_mat[self.n_users:, int(self.n_users * i / 5.0):int(self.n_users * (i + 1.0) / 5)] = \
                R[int(self.n_users * i / 5.0):int(self.n_users * (i + 1.0) / 5), :self.n_items].T
        adj_mat = adj_mat.todok()
        print('already create adjacency matrix', adj_mat.shape, time() - t1)

        t2 = time()

        def normalized_adj_single(adj):
            # Row-normalize: D^-1 * A.
            rowsum = np.array(adj.sum(1))
            d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)

            norm_adj = d_mat_inv.dot(adj)
            print('generate single-normalized adjacency matrix.')
            return norm_adj.tocoo()

        def check_adj_if_equal(adj):
            dense_A = np.array(adj.todense())
            degree = np.sum(dense_A, axis=1, keepdims=False)
            temp = np.dot(np.diag(np.power(degree, -1)), dense_A)
            print('check normalized adjacency matrix whether equal to this laplacian matrix.')
            return temp

        norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))
        mean_adj_mat = normalized_adj_single(adj_mat)

        print('already normalize adjacency matrix', time() - t2)
        return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()
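
    # The matrices produced above, as implemented:
    #   norm_adj_mat = D^-1 * (A + I)  (row-normalized adjacency with self-loops)
    #   mean_adj_mat = D^-1 * A        (row-normalized adjacency without self-loops)
    #   pre_adj_mat  = D^-1/2 * A * D^-1/2, built lazily in get_adj_mat().
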
    def negative_pool(self):
        # For each training user, cache a pool of 100 items the user has not interacted with.
        t1 = time()
        for u in self.train_items.keys():
            neg_items = list(set(range(self.n_items)) - set(self.train_items[u]))
            pools = [rd.choice(neg_items) for _ in range(100)]
            self.neg_pools[u] = pools
        print('refresh negative pools', time() - t1)

    def sample(self):
        # Sample a training batch of users; with replacement if the batch is larger
        # than the number of users.
        if self.batch_size <= self.n_users:
            users = rd.sample(self.exist_users, self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]

        def sample_pos_items_for_u(u, num):
            # Sample `num` distinct items the user interacted with in training.
            pos_items = self.train_items[u]
            n_pos_items = len(pos_items)
            pos_batch = []
            while True:
                if len(pos_batch) == num:
                    break
                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
                pos_i_id = pos_items[pos_id]

                if pos_i_id not in pos_batch:
                    pos_batch.append(pos_i_id)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            # Sample `num` distinct items the user did not interact with in training.
            neg_items = []
            while True:
                if len(neg_items) == num:
                    break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in self.train_items[u] and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        def sample_neg_items_for_u_from_pools(u, num):
            neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
            return rd.sample(neg_items, num)

        pos_items, neg_items = [], []
        for u in users:
            pos_items += sample_pos_items_for_u(u, 1)
            neg_items += sample_neg_items_for_u(u, 1)

        return users, pos_items, neg_items

    def sample_test(self):
        # Sample a batch of test users; with replacement if the batch is larger
        # than the number of users.
        if self.batch_size <= self.n_users:
            users = rd.sample(list(self.test_set.keys()), self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]

        def sample_pos_items_for_u(u, num):
            # Sample `num` distinct items from the user's test set.
            pos_items = self.test_set[u]
            n_pos_items = len(pos_items)
            pos_batch = []
            while True:
                if len(pos_batch) == num:
                    break
                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
                pos_i_id = pos_items[pos_id]

                if pos_i_id not in pos_batch:
                    pos_batch.append(pos_i_id)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            # Sample `num` items that appear in neither the user's train nor test set.
            neg_items = []
            while True:
                if len(neg_items) == num:
                    break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in (self.test_set[u] + self.train_items[u]) and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        def sample_neg_items_for_u_from_pools(u, num):
            neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
            return rd.sample(neg_items, num)

        pos_items, neg_items = [], []
        for u in users:
            pos_items += sample_pos_items_for_u(u, 1)
            neg_items += sample_neg_items_for_u(u, 1)

        return users, pos_items, neg_items

    def get_num_users_items(self):
        return self.n_users, self.n_items

    def print_statistics(self):
        print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
        print('n_interactions=%d' % (self.n_train + self.n_test))
        print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test) / (self.n_users * self.n_items)))

    def get_sparsity_split(self):
        # Load a cached sparsity split if it exists; otherwise create and cache it.
        try:
            split_uids, split_state = [], []
            with open(self.path + '/sparsity.split', 'r') as f:
                lines = f.readlines()

            for idx, line in enumerate(lines):
                if idx % 2 == 0:
                    split_state.append(line.strip())
                    print(line.strip())
                else:
                    split_uids.append([int(uid) for uid in line.strip().split(' ')])
            print('get sparsity split.')

        except Exception:
            split_uids, split_state = self.create_sparsity_split()
            with open(self.path + '/sparsity.split', 'w') as f:
                for idx in range(len(split_state)):
                    f.write(split_state[idx] + '\n')
                    f.write(' '.join([str(uid) for uid in split_uids[idx]]) + '\n')
            print('create sparsity split.')

        return split_uids, split_state

    def create_sparsity_split(self):
        all_users_to_test = list(self.test_set.keys())
        user_n_iid = dict()

        # Build a dictionary mapping n_iids (a user's total number of interactions)
        # to the list of uids with that many interactions.
        for uid in all_users_to_test:
            train_iids = self.train_items[uid]
            test_iids = self.test_set[uid]

            n_iids = len(train_iids) + len(test_iids)

            if n_iids not in user_n_iid.keys():
                user_n_iid[n_iids] = [uid]
            else:
                user_n_iid[n_iids].append(uid)

        split_uids = list()

        # Split the whole user set into four subsets of roughly equal interaction volume.
        temp = []
        count = 1
        fold = 4
        n_count = (self.n_train + self.n_test)
        n_rates = 0
        split_state = []

        for idx, n_iids in enumerate(sorted(user_n_iid)):
            temp += user_n_iid[n_iids]
            n_rates += n_iids * len(user_n_iid[n_iids])
            n_count -= n_iids * len(user_n_iid[n_iids])

            if n_rates >= count * 0.25 * (self.n_train + self.n_test):
                split_uids.append(temp)

                state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' % (n_iids, len(temp), n_rates)
                split_state.append(state)
                print(state)

                temp = []
                n_rates = 0
                fold -= 1

            if idx == len(user_n_iid.keys()) - 1 or n_count == 0:
                split_uids.append(temp)
                state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' % (n_iids, len(temp), n_rates)
                split_state.append(state)
                print(state)

        return split_uids, split_state
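

# A minimal usage sketch (not part of the original file): it assumes a dataset
# directory containing train.txt and test.txt in the format described above.
# The path and batch size below are hypothetical placeholders.
if __name__ == '__main__':
    data = Data(path='./Data/gowalla', batch_size=1024)
    adj, norm_adj, mean_adj, pre_adj = data.get_adj_mat()
    users, pos_items, neg_items = data.sample()
    print('sampled batch sizes:', len(users), len(pos_items), len(neg_items))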