'''
Created on Oct 10, 2018
Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:
Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.

@author: Xiang Wang (xiangwang@u.nus.edu)
version:
Parallelized sampling on CPU
C++ evaluation for top-k recommendation
'''
import os
import sys
import threading
import tensorflow as tf
from tensorflow.python.client import device_lib
from utility.helper import *
from utility.batch_test import *

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

cpus = [x.name for x in device_lib.list_local_devices() if x.device_type == 'CPU']
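
# NOTE: names such as `args`, `data_generator`, `np`, `sp`, `time`, `test`, `early_stopping`
# and `ensureDir` are expected to be provided by the wildcard imports above
# (utility.helper / utility.batch_test in the original repository); they are used as
# module-level globals throughout this file.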

class LightGCN(object):
    def __init__(self, data_config, pretrain_data):
        # argument settings
        self.model_type = 'LightGCN'
        self.adj_type = args.adj_type
        self.alg_type = args.alg_type
        self.pretrain_data = pretrain_data
        self.n_users = data_config['n_users']
        self.n_items = data_config['n_items']
        self.n_fold = 100
        self.norm_adj = data_config['norm_adj']
        self.n_nonzero_elems = self.norm_adj.count_nonzero()
        self.lr = args.lr
        self.emb_dim = args.embed_size
        self.batch_size = args.batch_size
        self.weight_size = eval(args.layer_size)
        self.n_layers = len(self.weight_size)
        self.regs = eval(args.regs)
        self.decay = self.regs[0]
        self.log_dir = self.create_model_str()
        self.verbose = args.verbose
        self.Ks = eval(args.Ks)

        '''
        *********************************************************
        Create Placeholder for Input Data & Dropout.
        '''
        # placeholder definition
        self.users = tf.placeholder(tf.int32, shape=(None,))
        self.pos_items = tf.placeholder(tf.int32, shape=(None,))
        self.neg_items = tf.placeholder(tf.int32, shape=(None,))

        self.node_dropout_flag = args.node_dropout_flag
        self.node_dropout = tf.placeholder(tf.float32, shape=[None])
        self.mess_dropout = tf.placeholder(tf.float32, shape=[None])
        with tf.name_scope('TRAIN_LOSS'):
            self.train_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('train_loss', self.train_loss)
            self.train_mf_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('train_mf_loss', self.train_mf_loss)
            self.train_emb_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('train_emb_loss', self.train_emb_loss)
            self.train_reg_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('train_reg_loss', self.train_reg_loss)
        self.merged_train_loss = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TRAIN_LOSS'))

        with tf.name_scope('TRAIN_ACC'):
            self.train_rec_first = tf.placeholder(tf.float32)
            # record for top(Ks[0])
            tf.summary.scalar('train_rec_first', self.train_rec_first)
            self.train_rec_last = tf.placeholder(tf.float32)
            # record for top(Ks[-1])
            tf.summary.scalar('train_rec_last', self.train_rec_last)
            self.train_ndcg_first = tf.placeholder(tf.float32)
            tf.summary.scalar('train_ndcg_first', self.train_ndcg_first)
            self.train_ndcg_last = tf.placeholder(tf.float32)
            tf.summary.scalar('train_ndcg_last', self.train_ndcg_last)
        self.merged_train_acc = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TRAIN_ACC'))

        with tf.name_scope('TEST_LOSS'):
            self.test_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('test_loss', self.test_loss)
            self.test_mf_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('test_mf_loss', self.test_mf_loss)
            self.test_emb_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('test_emb_loss', self.test_emb_loss)
            self.test_reg_loss = tf.placeholder(tf.float32)
            tf.summary.scalar('test_reg_loss', self.test_reg_loss)
        self.merged_test_loss = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TEST_LOSS'))

        with tf.name_scope('TEST_ACC'):
            self.test_rec_first = tf.placeholder(tf.float32)
            tf.summary.scalar('test_rec_first', self.test_rec_first)
            self.test_rec_last = tf.placeholder(tf.float32)
            tf.summary.scalar('test_rec_last', self.test_rec_last)
            self.test_ndcg_first = tf.placeholder(tf.float32)
            tf.summary.scalar('test_ndcg_first', self.test_ndcg_first)
            self.test_ndcg_last = tf.placeholder(tf.float32)
            tf.summary.scalar('test_ndcg_last', self.test_ndcg_last)
        self.merged_test_acc = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TEST_ACC'))
  94. """
  95. *********************************************************
  96. Create Model Parameters (i.e., Initialize Weights).
  97. """
  98. # initialization of model parameters
  99. self.weights = self._init_weights()
  100. """
  101. *********************************************************
  102. Compute Graph-based Representations of all users & items via Message-Passing Mechanism of Graph Neural Networks.
  103. Different Convolutional Layers:
  104. 1. ngcf: defined in 'Neural Graph Collaborative Filtering', SIGIR2019;
  105. 2. gcn: defined in 'Semi-Supervised Classification with Graph Convolutional Networks', ICLR2018;
  106. 3. gcmc: defined in 'Graph Convolutional Matrix Completion', KDD2018;
  107. """
  108. if self.alg_type in ['lightgcn']:
  109. self.ua_embeddings, self.ia_embeddings = self._create_lightgcn_embed()
  110. elif self.alg_type in ['ngcf']:
  111. self.ua_embeddings, self.ia_embeddings = self._create_ngcf_embed()
  112. elif self.alg_type in ['gcn']:
  113. self.ua_embeddings, self.ia_embeddings = self._create_gcn_embed()
  114. elif self.alg_type in ['gcmc']:
  115. self.ua_embeddings, self.ia_embeddings = self._create_gcmc_embed()
  116. """
  117. *********************************************************
  118. Establish the final representations for user-item pairs in batch.
  119. """
  120. self.u_g_embeddings = tf.nn.embedding_lookup(self.ua_embeddings, self.users)
  121. self.pos_i_g_embeddings = tf.nn.embedding_lookup(self.ia_embeddings, self.pos_items)
  122. self.neg_i_g_embeddings = tf.nn.embedding_lookup(self.ia_embeddings, self.neg_items)
  123. self.u_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['user_embedding'], self.users)
  124. self.pos_i_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['item_embedding'], self.pos_items)
  125. self.neg_i_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['item_embedding'], self.neg_items)
  126. """
  127. *********************************************************
  128. Inference for the testing phase.
  129. """
  130. self.batch_ratings = tf.matmul(self.u_g_embeddings, self.pos_i_g_embeddings, transpose_a=False, transpose_b=True)
  131. """
  132. *********************************************************
  133. Generate Predictions & Optimize via BPR loss.
  134. """
  135. self.mf_loss, self.emb_loss, self.reg_loss = self.create_bpr_loss(self.u_g_embeddings,
  136. self.pos_i_g_embeddings,
  137. self.neg_i_g_embeddings)
  138. self.loss = self.mf_loss + self.emb_loss
  139. self.opt = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)

    def create_model_str(self):
        log_dir = '/' + self.alg_type + '/layers_' + str(self.n_layers) + '/dim_' + str(self.emb_dim)
        log_dir += '/' + args.dataset + '/lr_' + str(self.lr) + '/reg_' + str(self.decay)
        return log_dir

    def _init_weights(self):
        all_weights = dict()
        initializer = tf.random_normal_initializer(stddev=0.01)  # tf.contrib.layers.xavier_initializer()

        if self.pretrain_data is None:
            all_weights['user_embedding'] = tf.Variable(initializer([self.n_users, self.emb_dim]), name='user_embedding')
            all_weights['item_embedding'] = tf.Variable(initializer([self.n_items, self.emb_dim]), name='item_embedding')
            print('using random initialization')  # print('using xavier initialization')
        else:
            all_weights['user_embedding'] = tf.Variable(initial_value=self.pretrain_data['user_embed'], trainable=True,
                                                        name='user_embedding', dtype=tf.float32)
            all_weights['item_embedding'] = tf.Variable(initial_value=self.pretrain_data['item_embed'], trainable=True,
                                                        name='item_embedding', dtype=tf.float32)
            print('using pretrained initialization')

        self.weight_size_list = [self.emb_dim] + self.weight_size

        for k in range(self.n_layers):
            all_weights['W_gc_%d' % k] = tf.Variable(
                initializer([self.weight_size_list[k], self.weight_size_list[k + 1]]), name='W_gc_%d' % k)
            all_weights['b_gc_%d' % k] = tf.Variable(
                initializer([1, self.weight_size_list[k + 1]]), name='b_gc_%d' % k)

            all_weights['W_bi_%d' % k] = tf.Variable(
                initializer([self.weight_size_list[k], self.weight_size_list[k + 1]]), name='W_bi_%d' % k)
            all_weights['b_bi_%d' % k] = tf.Variable(
                initializer([1, self.weight_size_list[k + 1]]), name='b_bi_%d' % k)

            all_weights['W_mlp_%d' % k] = tf.Variable(
                initializer([self.weight_size_list[k], self.weight_size_list[k + 1]]), name='W_mlp_%d' % k)
            all_weights['b_mlp_%d' % k] = tf.Variable(
                initializer([1, self.weight_size_list[k + 1]]), name='b_mlp_%d' % k)

        return all_weights
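
    # The full (n_users + n_items) x (n_users + n_items) adjacency matrix is split into
    # n_fold row-slices so that each slice can be converted to a tf.SparseTensor of manageable
    # size; the sparse-dense matmuls in the embedding functions are then done fold by fold
    # and concatenated back together.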
    def _split_A_hat(self, X):
        A_fold_hat = []

        fold_len = (self.n_users + self.n_items) // self.n_fold
        for i_fold in range(self.n_fold):
            start = i_fold * fold_len
            if i_fold == self.n_fold - 1:
                end = self.n_users + self.n_items
            else:
                end = (i_fold + 1) * fold_len

            A_fold_hat.append(self._convert_sp_mat_to_sp_tensor(X[start:end]))
        return A_fold_hat

    def _split_A_hat_node_dropout(self, X):
        A_fold_hat = []

        fold_len = (self.n_users + self.n_items) // self.n_fold
        for i_fold in range(self.n_fold):
            start = i_fold * fold_len
            if i_fold == self.n_fold - 1:
                end = self.n_users + self.n_items
            else:
                end = (i_fold + 1) * fold_len

            temp = self._convert_sp_mat_to_sp_tensor(X[start:end])
            n_nonzero_temp = X[start:end].count_nonzero()
            A_fold_hat.append(self._dropout_sparse(temp, 1 - self.node_dropout[0], n_nonzero_temp))

        return A_fold_hat
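
    # LightGCN propagation: each layer is a plain sparse matmul with the normalized adjacency,
    #     E^(k+1) = A_hat @ E^(k),
    # with no feature transformation and no non-linearity; the final representation is the mean
    # of the layer-wise embeddings, E = (E^(0) + E^(1) + ... + E^(K)) / (K + 1), as in the
    # LightGCN paper (He et al., SIGIR 2020).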
    def _create_lightgcn_embed(self):
        if self.node_dropout_flag:
            A_fold_hat = self._split_A_hat_node_dropout(self.norm_adj)
        else:
            A_fold_hat = self._split_A_hat(self.norm_adj)

        ego_embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0)
        all_embeddings = [ego_embeddings]

        for k in range(0, self.n_layers):
            temp_embed = []
            for f in range(self.n_fold):
                temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], ego_embeddings))

            side_embeddings = tf.concat(temp_embed, 0)
            ego_embeddings = side_embeddings
            all_embeddings += [ego_embeddings]

        all_embeddings = tf.stack(all_embeddings, 1)
        all_embeddings = tf.reduce_mean(all_embeddings, axis=1, keepdims=False)
        u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0)
        return u_g_embeddings, i_g_embeddings
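
    # NGCF propagation: unlike LightGCN, each layer applies learned transformations (W_gc, W_bi),
    # a leaky-ReLU non-linearity and L2 normalization, and the final representation concatenates
    # the embeddings of all layers instead of averaging them.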
    def _create_ngcf_embed(self):
        if self.node_dropout_flag:
            A_fold_hat = self._split_A_hat_node_dropout(self.norm_adj)
        else:
            A_fold_hat = self._split_A_hat(self.norm_adj)

        ego_embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0)
        all_embeddings = [ego_embeddings]

        for k in range(0, self.n_layers):
            temp_embed = []
            for f in range(self.n_fold):
                temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], ego_embeddings))

            side_embeddings = tf.concat(temp_embed, 0)
            sum_embeddings = tf.nn.leaky_relu(
                tf.matmul(side_embeddings, self.weights['W_gc_%d' % k]) + self.weights['b_gc_%d' % k])

            # bi messages of neighbors.
            bi_embeddings = tf.multiply(ego_embeddings, side_embeddings)
            # transformed bi messages of neighbors.
            bi_embeddings = tf.nn.leaky_relu(
                tf.matmul(bi_embeddings, self.weights['W_bi_%d' % k]) + self.weights['b_bi_%d' % k])

            # non-linear activation.
            ego_embeddings = sum_embeddings + bi_embeddings

            # message dropout.
            # ego_embeddings = tf.nn.dropout(ego_embeddings, 1 - self.mess_dropout[k])

            # normalize the distribution of embeddings.
            norm_embeddings = tf.nn.l2_normalize(ego_embeddings, axis=1)

            all_embeddings += [norm_embeddings]

        all_embeddings = tf.concat(all_embeddings, 1)
        u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0)
        return u_g_embeddings, i_g_embeddings

    def _create_gcn_embed(self):
        A_fold_hat = self._split_A_hat(self.norm_adj)

        embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0)
        all_embeddings = [embeddings]

        for k in range(0, self.n_layers):
            temp_embed = []
            for f in range(self.n_fold):
                temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], embeddings))

            embeddings = tf.concat(temp_embed, 0)
            embeddings = tf.nn.leaky_relu(
                tf.matmul(embeddings, self.weights['W_gc_%d' % k]) + self.weights['b_gc_%d' % k])
            # embeddings = tf.nn.dropout(embeddings, 1 - self.mess_dropout[k])

            all_embeddings += [embeddings]

        all_embeddings = tf.concat(all_embeddings, 1)
        u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0)
        return u_g_embeddings, i_g_embeddings

    def _create_gcmc_embed(self):
        A_fold_hat = self._split_A_hat(self.norm_adj)

        embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0)

        all_embeddings = []

        for k in range(0, self.n_layers):
            temp_embed = []
            for f in range(self.n_fold):
                temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], embeddings))
            embeddings = tf.concat(temp_embed, 0)

            # convolutional layer.
            embeddings = tf.nn.leaky_relu(
                tf.matmul(embeddings, self.weights['W_gc_%d' % k]) + self.weights['b_gc_%d' % k])
            # dense layer.
            mlp_embeddings = tf.matmul(embeddings, self.weights['W_mlp_%d' % k]) + self.weights['b_mlp_%d' % k]
            # mlp_embeddings = tf.nn.dropout(mlp_embeddings, 1 - self.mess_dropout[k])

            all_embeddings += [mlp_embeddings]

        all_embeddings = tf.concat(all_embeddings, 1)
        u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0)
        return u_g_embeddings, i_g_embeddings
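
    # Bayesian Personalized Ranking (BPR) loss:
    #     mf_loss  = mean( softplus(-(y_ui - y_uj)) )   # softplus(-x) == -log sigmoid(x)
    #     emb_loss = decay * (L2(e_u) + L2(e_i) + L2(e_j)) / batch_size
    # where y_ui / y_uj are the inner-product scores of the positive / negative items, and the
    # L2 regularizer (tf.nn.l2_loss, i.e. sum(x^2)/2) is taken over the ego (layer-0) embeddings
    # of the sampled users and items.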
    def create_bpr_loss(self, users, pos_items, neg_items):
        pos_scores = tf.reduce_sum(tf.multiply(users, pos_items), axis=1)
        neg_scores = tf.reduce_sum(tf.multiply(users, neg_items), axis=1)

        regularizer = tf.nn.l2_loss(self.u_g_embeddings_pre) + tf.nn.l2_loss(
            self.pos_i_g_embeddings_pre) + tf.nn.l2_loss(self.neg_i_g_embeddings_pre)
        regularizer = regularizer / self.batch_size

        mf_loss = tf.reduce_mean(tf.nn.softplus(-(pos_scores - neg_scores)))

        emb_loss = self.decay * regularizer

        reg_loss = tf.constant(0.0, tf.float32, [1])
        return mf_loss, emb_loss, reg_loss

    def _convert_sp_mat_to_sp_tensor(self, X):
        coo = X.tocoo().astype(np.float32)
        indices = np.mat([coo.row, coo.col]).transpose()
        return tf.SparseTensor(indices, coo.data, coo.shape)

    def _dropout_sparse(self, X, keep_prob, n_nonzero_elems):
        """
        Dropout for sparse tensors.
        """
        noise_shape = [n_nonzero_elems]
        random_tensor = keep_prob
        random_tensor += tf.random_uniform(noise_shape)
        dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
        pre_out = tf.sparse_retain(X, dropout_mask)

        return pre_out * tf.div(1., keep_prob)


def load_pretrained_data():
    pretrain_path = '%spretrain/%s/%s.npz' % (args.proj_path, args.dataset, 'embedding')
    try:
        pretrain_data = np.load(pretrain_path)
        print('load the pretrained embeddings.')
    except Exception:
        pretrain_data = None
    return pretrain_data
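

# The thread classes below pipeline data preparation and training: while the GPU runs the
# current training step, a CPU thread is already sampling the (user, pos_item, neg_item)
# triples for the next batch (see the per-epoch loops in __main__).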
# parallelized sampling on CPU
class sample_thread(threading.Thread):
    def __init__(self):
        threading.Thread.__init__(self)

    def run(self):
        with tf.device(cpus[0]):
            self.data = data_generator.sample()


class sample_thread_test(threading.Thread):
    def __init__(self):
        threading.Thread.__init__(self)

    def run(self):
        with tf.device(cpus[0]):
            self.data = data_generator.sample_test()


# training on GPU
class train_thread(threading.Thread):
    def __init__(self, model, sess, sample):
        threading.Thread.__init__(self)
        self.model = model
        self.sess = sess
        self.sample = sample

    def run(self):
        users, pos_items, neg_items = self.sample.data
        self.data = self.sess.run([self.model.opt, self.model.loss, self.model.mf_loss, self.model.emb_loss, self.model.reg_loss],
                                  feed_dict={self.model.users: users, self.model.pos_items: pos_items,
                                             self.model.node_dropout: eval(args.node_dropout),
                                             self.model.mess_dropout: eval(args.mess_dropout),
                                             self.model.neg_items: neg_items})


class train_thread_test(threading.Thread):
    def __init__(self, model, sess, sample):
        threading.Thread.__init__(self)
        self.model = model
        self.sess = sess
        self.sample = sample

    def run(self):
        users, pos_items, neg_items = self.sample.data
        self.data = self.sess.run([self.model.loss, self.model.mf_loss, self.model.emb_loss],
                                  feed_dict={self.model.users: users, self.model.pos_items: pos_items,
                                             self.model.neg_items: neg_items,
                                             self.model.node_dropout: eval(args.node_dropout),
                                             self.model.mess_dropout: eval(args.mess_dropout)})
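

# Illustrative single-step usage, without the threaded pipeline (a minimal sketch that assumes
# the same `sess`, `model`, `args` and `data_generator` objects created in the __main__ block
# below); shown for reference only:
#
#   users, pos_items, neg_items = data_generator.sample()
#   _, batch_loss = sess.run([model.opt, model.loss],
#                            feed_dict={model.users: users,
#                                       model.pos_items: pos_items,
#                                       model.neg_items: neg_items,
#                                       model.node_dropout: eval(args.node_dropout),
#                                       model.mess_dropout: eval(args.mess_dropout)})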


if __name__ == '__main__':
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    f0 = time()

    config = dict()
    config['n_users'] = data_generator.n_users
    config['n_items'] = data_generator.n_items

    """
    *********************************************************
    Generate the Laplacian matrix, where each entry defines the decay factor (e.g., p_ui) between two connected nodes.
    """
    plain_adj, norm_adj, mean_adj, pre_adj = data_generator.get_adj_mat()
    if args.adj_type == 'plain':
        config['norm_adj'] = plain_adj
        print('use the plain adjacency matrix')
    elif args.adj_type == 'norm':
        config['norm_adj'] = norm_adj
        print('use the normalized adjacency matrix')
    elif args.adj_type == 'gcmc':
        config['norm_adj'] = mean_adj
        print('use the gcmc adjacency matrix')
    elif args.adj_type == 'pre':
        config['norm_adj'] = pre_adj
        print('use the pre adjacency matrix')
    else:
        config['norm_adj'] = mean_adj + sp.eye(mean_adj.shape[0])
        print('use the mean adjacency matrix')

    t0 = time()

    if args.pretrain == -1:
        pretrain_data = load_pretrained_data()
    else:
        pretrain_data = None

    model = LightGCN(data_config=config, pretrain_data=pretrain_data)

    """
    *********************************************************
    Save the model parameters.
    """
    saver = tf.train.Saver()

    if args.save_flag == 1:
        layer = '-'.join([str(l) for l in eval(args.layer_size)])
        weights_save_path = '%sweights/%s/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type, layer,
                                                            str(args.lr), '-'.join([str(r) for r in eval(args.regs)]))
        ensureDir(weights_save_path)
        save_saver = tf.train.Saver(max_to_keep=1)

    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
  392. """
  393. *********************************************************
  394. Reload the pretrained model parameters.
  395. """
  396. if args.pretrain == 1:
  397. layer = '-'.join([str(l) for l in eval(args.layer_size)])
  398. pretrain_path = '%sweights/%s/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type, layer,
  399. str(args.lr), '-'.join([str(r) for r in eval(args.regs)]))
  400. ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrain_path + '/checkpoint'))
  401. if ckpt and ckpt.model_checkpoint_path:
  402. sess.run(tf.global_variables_initializer())
  403. saver.restore(sess, ckpt.model_checkpoint_path)
  404. print('load the pretrained model parameters from: ', pretrain_path)
  405. # *********************************************************
  406. # get the performance from pretrained model.
  407. if args.report != 1:
  408. users_to_test = list(data_generator.test_set.keys())
  409. ret = test(sess, model, users_to_test, drop_flag=True)
  410. cur_best_pre_0 = ret['recall'][0]
  411. pretrain_ret = 'pretrained model recall=[%s], precision=[%s], '\
  412. 'ndcg=[%s]' % \
  413. (', '.join(['%.5f' % r for r in ret['recall']]),
  414. ', '.join(['%.5f' % r for r in ret['precision']]),
  415. ', '.join(['%.5f' % r for r in ret['ndcg']]))
  416. print(pretrain_ret)
  417. else:
  418. sess.run(tf.global_variables_initializer())
  419. cur_best_pre_0 = 0.
  420. print('without pretraining.')
  421. else:
  422. sess.run(tf.global_variables_initializer())
  423. cur_best_pre_0 = 0.
  424. print('without pretraining.')
  425. """
  426. *********************************************************
  427. Get the performance w.r.t. different sparsity levels.
  428. """
  429. if args.report == 1:
  430. assert args.test_flag == 'full'
  431. users_to_test_list, split_state = data_generator.get_sparsity_split()
  432. users_to_test_list.append(list(data_generator.test_set.keys()))
  433. split_state.append('all')
  434. report_path = '%sreport/%s/%s.result' % (args.proj_path, args.dataset, model.model_type)
  435. ensureDir(report_path)
  436. f = open(report_path, 'w')
  437. f.write(
  438. 'embed_size=%d, lr=%.4f, layer_size=%s, keep_prob=%s, regs=%s, loss_type=%s, adj_type=%s\n'
  439. % (args.embed_size, args.lr, args.layer_size, args.keep_prob, args.regs, args.loss_type, args.adj_type))
  440. for i, users_to_test in enumerate(users_to_test_list):
  441. ret = test(sess, model, users_to_test, drop_flag=True)
  442. final_perf = "recall=[%s], precision=[%s], ndcg=[%s]" % \
  443. (', '.join(['%.5f' % r for r in ret['recall']]),
  444. ', '.join(['%.5f' % r for r in ret['precision']]),
  445. ', '.join(['%.5f' % r for r in ret['ndcg']]))
  446. f.write('\t%s\n\t%s\n' % (split_state[i], final_perf))
  447. f.close()
  448. exit()
  449. """
  450. *********************************************************
  451. Train.
  452. """
  453. tensorboard_model_path = 'tensorboard/'
  454. if not os.path.exists(tensorboard_model_path):
  455. os.makedirs(tensorboard_model_path)
  456. run_time = 1
  457. while (True):
  458. if os.path.exists(tensorboard_model_path + model.log_dir +'/run_' + str(run_time)):
  459. run_time += 1
  460. else:
  461. break
  462. train_writer = tf.summary.FileWriter(tensorboard_model_path +model.log_dir+ '/run_' + str(run_time), sess.graph)
  463. loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
  464. stopping_step = 0
  465. should_stop = False
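
    # Training loop: each epoch runs n_batch BPR updates with pipelined CPU sampling / GPU training;
    # every 20 epochs the model is additionally evaluated on the training and test sets
    # (recall / precision / ndcg at the cut-offs in args.Ks), the results are written to
    # TensorBoard, and early stopping is checked on recall at the first cut-off.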
    for epoch in range(1, args.epoch + 1):
        t1 = time()
        loss, mf_loss, emb_loss, reg_loss = 0., 0., 0., 0.
        n_batch = data_generator.n_train // args.batch_size + 1
        loss_test, mf_loss_test, emb_loss_test, reg_loss_test = 0., 0., 0., 0.

        '''
        *********************************************************
        parallelized sampling
        '''
        sample_last = sample_thread()
        sample_last.start()
        sample_last.join()
        for idx in range(n_batch):
            train_cur = train_thread(model, sess, sample_last)
            sample_next = sample_thread()

            train_cur.start()
            sample_next.start()

            sample_next.join()
            train_cur.join()

            users, pos_items, neg_items = sample_last.data
            _, batch_loss, batch_mf_loss, batch_emb_loss, batch_reg_loss = train_cur.data
            sample_last = sample_next

            loss += batch_loss / n_batch
            mf_loss += batch_mf_loss / n_batch
            emb_loss += batch_emb_loss / n_batch

        summary_train_loss = sess.run(model.merged_train_loss,
                                      feed_dict={model.train_loss: loss, model.train_mf_loss: mf_loss,
                                                 model.train_emb_loss: emb_loss, model.train_reg_loss: reg_loss})
        train_writer.add_summary(summary_train_loss, epoch)

        if np.isnan(loss):
            print('ERROR: loss is nan.')
            sys.exit()

        if (epoch % 20) != 0:
            if args.verbose > 0 and epoch % args.verbose == 0:
                perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % (
                    epoch, time() - t1, loss, mf_loss, emb_loss)
                print(perf_str)
            continue

        users_to_test = list(data_generator.train_items.keys())
        ret = test(sess, model, users_to_test, drop_flag=True, train_set_flag=1)
        perf_str = 'Epoch %d: train==[%.5f=%.5f + %.5f + %.5f], recall=[%s], precision=[%s], ndcg=[%s]' % \
                   (epoch, loss, mf_loss, emb_loss, reg_loss,
                    ', '.join(['%.5f' % r for r in ret['recall']]),
                    ', '.join(['%.5f' % r for r in ret['precision']]),
                    ', '.join(['%.5f' % r for r in ret['ndcg']]))
        print(perf_str)
        summary_train_acc = sess.run(model.merged_train_acc, feed_dict={model.train_rec_first: ret['recall'][0],
                                                                        model.train_rec_last: ret['recall'][-1],
                                                                        model.train_ndcg_first: ret['ndcg'][0],
                                                                        model.train_ndcg_last: ret['ndcg'][-1]})
        train_writer.add_summary(summary_train_acc, epoch // 20)

        '''
        *********************************************************
        parallelized sampling
        '''
        sample_last = sample_thread_test()
        sample_last.start()
        sample_last.join()
        for idx in range(n_batch):
            train_cur = train_thread_test(model, sess, sample_last)
            sample_next = sample_thread_test()

            train_cur.start()
            sample_next.start()

            sample_next.join()
            train_cur.join()

            users, pos_items, neg_items = sample_last.data
            batch_loss_test, batch_mf_loss_test, batch_emb_loss_test = train_cur.data
            sample_last = sample_next

            loss_test += batch_loss_test / n_batch
            mf_loss_test += batch_mf_loss_test / n_batch
            emb_loss_test += batch_emb_loss_test / n_batch

        summary_test_loss = sess.run(model.merged_test_loss,
                                     feed_dict={model.test_loss: loss_test, model.test_mf_loss: mf_loss_test,
                                                model.test_emb_loss: emb_loss_test, model.test_reg_loss: reg_loss_test})
        train_writer.add_summary(summary_test_loss, epoch // 20)

        t2 = time()
        users_to_test = list(data_generator.test_set.keys())
        ret = test(sess, model, users_to_test, drop_flag=True)
        summary_test_acc = sess.run(model.merged_test_acc,
                                    feed_dict={model.test_rec_first: ret['recall'][0],
                                               model.test_rec_last: ret['recall'][-1],
                                               model.test_ndcg_first: ret['ndcg'][0],
                                               model.test_ndcg_last: ret['ndcg'][-1]})
        train_writer.add_summary(summary_test_acc, epoch // 20)

        t3 = time()

        loss_loger.append(loss)
        rec_loger.append(ret['recall'])
        pre_loger.append(ret['precision'])
        ndcg_loger.append(ret['ndcg'])

        if args.verbose > 0:
            perf_str = 'Epoch %d [%.1fs + %.1fs]: test==[%.5f=%.5f + %.5f + %.5f], recall=[%s], ' \
                       'precision=[%s], ndcg=[%s]' % \
                       (epoch, t2 - t1, t3 - t2, loss_test, mf_loss_test, emb_loss_test, reg_loss_test,
                        ', '.join(['%.5f' % r for r in ret['recall']]),
                        ', '.join(['%.5f' % r for r in ret['precision']]),
                        ', '.join(['%.5f' % r for r in ret['ndcg']]))
            print(perf_str)

        cur_best_pre_0, stopping_step, should_stop = early_stopping(ret['recall'][0], cur_best_pre_0,
                                                                    stopping_step, expected_order='acc', flag_step=5)

        # *********************************************************
        # early stopping when cur_best_pre_0 is decreasing for five successive steps.
        if should_stop:
            break

        # *********************************************************
        # save the user & item embeddings for pretraining.
        if ret['recall'][0] == cur_best_pre_0 and args.save_flag == 1:
            save_saver.save(sess, weights_save_path + '/weights', global_step=epoch)
            print('save the weights in path: ', weights_save_path)

    recs = np.array(rec_loger)
    pres = np.array(pre_loger)
    ndcgs = np.array(ndcg_loger)

    best_rec_0 = max(recs[:, 0])
    idx = list(recs[:, 0]).index(best_rec_0)

    final_perf = "Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], ndcg=[%s]" % \
                 (idx, time() - t0, '\t'.join(['%.5f' % r for r in recs[idx]]),
                  '\t'.join(['%.5f' % r for r in pres[idx]]),
                  '\t'.join(['%.5f' % r for r in ndcgs[idx]]))
    print(final_perf)

    save_path = '%soutput/%s/%s.result' % (args.proj_path, args.dataset, model.model_type)
    ensureDir(save_path)
    f = open(save_path, 'a')

    f.write(
        'embed_size=%d, lr=%.4f, layer_size=%s, node_dropout=%s, mess_dropout=%s, regs=%s, adj_type=%s\n\t%s\n'
        % (args.embed_size, args.lr, args.layer_size, args.node_dropout, args.mess_dropout, args.regs,
           args.adj_type, final_perf))
    f.close()