test.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import sys
  2. import os
  3. import json
  4. from tqdm import tqdm
  5. sys.path.append(os.getcwd())
  6. from functions.generate_data import generate_label_date
  7. def generate_train_label(item, y_ori_data):
  8. """
  9. 生成训练数据,用 np.array矩阵的方式返回,
  10. :return: x_train, 训练数据, y_train, 训练 label
  11. """
  12. video_id = item['video_id']
  13. dt = item['dt']
  14. userful_features = [
  15. "uid",
  16. "type",
  17. "channel",
  18. "fans",
  19. "view_count_user_30days",
  20. "share_count_user_30days",
  21. "return_count_user_30days",
  22. "rov_user",
  23. "str_user",
  24. "out_user_id",
  25. "mode",
  26. "out_play_cnt",
  27. "out_like_cnt",
  28. "out_share_cnt",
  29. "out_collection_cnt"
  30. ]
  31. item_features = [item[i] for i in userful_features]
  32. label_dt = generate_label_date(dt)
  33. label_obj = y_ori_data.get(label_dt, {}).get(video_id)
  34. if label_obj:
  35. label = int(label_obj['total_return']) if label_obj['total_return'] else 0
  36. else:
  37. label = 0
  38. return label, item_features
  39. if __name__ == '__main__':
  40. x_path = 'data/train_january.json'
  41. y_path = 'data/jan_feb_label.json'
  42. with open(x_path) as f:
  43. x_data = json.loads(f.read())
  44. with open(y_path) as f:
  45. y_data = json.loads(f.read())
  46. x_list = []
  47. y_list = []
  48. for video_obj in tqdm(x_data):
  49. print(video_obj)
  50. our_label, features = generate_train_label(video_obj, y_data)
  51. # if our_label:
  52. x_list.append(features)
  53. y_list.append(our_label)
  54. # print(len(y_list))
  55. with open("whole_data/x_data.json", "w") as f1:
  56. f1.write(json.dumps(x_list, ensure_ascii=False))
  57. with open("whole_data/y_data.json", "w") as f2:
  58. f2.write(json.dumps(y_list, ensure_ascii=False))