update_mysql_data.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
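"""
Back-fill the ``lightgbm_data`` MySQL table with label and user-info columns
read from daily JSON files, so the rows can be used as LightGBM training data.
"""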
  4. import sys
  5. import os
  6. import json
  7. import argparse
  8. from tqdm import tqdm
  9. from concurrent.futures.thread import ThreadPoolExecutor
  10. sys.path.append(os.getcwd())
  11. from functions import generate_label_date, MysqlClient, MySQLClientSpider
  12. class DataProcessor(object):
  13. """
  14. Insert some information to lightgbm_data
  15. """
  16. def __init__(self, ):
  17. self.client = MysqlClient()
  18. self.client_spider = MySQLClientSpider()
  19. self.label_data = {}
  20. def update_label(self, start_date, end_date):
  21. """
  22. 生成数据
  23. :return:none
  24. """
  25. # 把 label, video_title, daily_dt_str, 存储到 mysql 数据库中去
  26. label_path = "/root/luojunhui/alg/data/train_data/daily-label-{}-{}.json".format(start_date, end_date)
  27. with open(label_path, encoding="utf-8") as f:
  28. self.label_data = json.loads(f.read())
  29. def read_title(client, video_id):
  30. """
  31. read_title_from mysql
  32. """
  33. sql = f"""SELECT title from wx_video where id = {video_id};"""
  34. # print("title", sql)
  35. try:
  36. title = client.select(sql)[0][0]
  37. return title.strip()
  38. except Exception as e:
  39. print(video_id, "\t", e)
  40. return ""
  41. def generate_label(video_id, hourly_dt_str, label_info):
  42. """
  43. generate label daily_dt_str for mysql
  44. :param label_info:
  45. :param video_id:
  46. :param hourly_dt_str:
  47. :return: label, daily_dt_str
  48. """
  49. label_dt = generate_label_date(hourly_dt_str)
  50. label_obj = label_info.get(label_dt, {}).get(video_id)
  51. if label_obj:
  52. total_return = label_obj.get('total_return', 0)
  53. total_view = label_obj.get('total_view', 0)
  54. if total_return is not None and total_view is not None:
  55. total_return = float(total_return)
  56. total_view = float(total_view)
  57. if total_view == 0:
  58. label = None
  59. else:
  60. if total_return == 0:
  61. label = None
  62. else:
  63. label = float(total_return) / float(total_view)
  64. elif total_return is None and total_view is not None:
  65. label = 0
  66. else:
  67. label = None
  68. else:
  69. label = None
  70. return label, label_dt
  71. def process_info(item_):
  72. """
  73. Insert data into MySql
  74. :param item_:
  75. """
  76. video_id, hour_dt = item_
  77. # print(type(video_id))
  78. label_info = self.label_data
  79. # title = read_title(client=self.client, video_id=video_id)
  80. label, dt_daily = generate_label(str(video_id), hour_dt, label_info)
  81. insert_sql = f"""UPDATE lightgbm_data set rov_label = '{label}', daily_dt_str = '{dt_daily}' where video_id = '{video_id}';"""
  82. # print(insert_sql)
  83. self.client_spider.update(insert_sql)
  84. select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where rov_label is NULL;"
  85. init_data_tuple = self.client_spider.select(select_sql)
  86. init_list = list(init_data_tuple)
  87. # with ThreadPoolExecutor(max_workers=4) as Pool:
  88. # Pool.map(process_info, init_list)
  89. for item in tqdm(init_list):
  90. try:
  91. process_info(item)
  92. except Exception as e:
  93. print("操作失败", e)
  94. def update_user_info(self, start_date, end_date):
  95. """
  96. 把补充的 user_info更新到 mysql 中、
  97. 把 user_return_3, user_view_3, user_share_3
  98. user_return_videos_3, user_return_videos_30
  99. address 存储到 mysql 数据库中
  100. :return:
  101. """
  102. user_path = '/root/luojunhui/alg/data/train_data/daily-user-info-{}-{}.json'.format(start_date, end_date)
  103. with open(user_path) as f:
  104. data = json.loads(f.read())
  105. sql = "select video_id, hour_dt_str from lightgbm_data where type = 'userupload' and address is NULL;"
  106. dt_list = self.client_spider.select(sql)
  107. for item in tqdm(dt_list):
  108. video_id, dt_info = item
  109. dt_info = dt_info[:8]
  110. user_info_obj = data.get(dt_info, {}).get(str(video_id))
  111. if user_info_obj:
  112. try:
  113. video_id = user_info_obj['video_id']
  114. address = user_info_obj['address']
  115. return_3 = user_info_obj['return_3days']
  116. view_3 = user_info_obj['view_3days']
  117. share_3 = user_info_obj['share_3days']
  118. return_videos_3 = user_info_obj['3day_return_500_videos']
  119. return_videos_30 = user_info_obj['30day_return_2000_videos']
  120. update_sql = f"""UPDATE lightgbm_data set address='{address}', user_return_3={return_3}, user_view_3={view_3}, user_share_3={share_3}, user_return_videos_3={return_videos_3}, user_return_videos_30={return_videos_30} where video_id = '{video_id}';"""
  121. self.client_spider.update(update_sql)
  122. except Exception as e:
  123. print(e)
  124. pass
  125. else:
  126. print("No user info")
  127. if __name__ == "__main__":
  128. parser = argparse.ArgumentParser() # 新建参数解释器对象
  129. parser.add_argument("--param")
  130. args = parser.parse_args()
  131. param = args.param
  132. D = DataProcessor()
  133. match param:
  134. case "label":
  135. sd = str(input("输入label日级表开始日期,格式为 YYYYmmdd"))
  136. ed = str(input("输入label日级表结束日期,格式为 YYYYmmdd"))
  137. D.update_label(start_date=sd, end_date=ed)
  138. case "user_info":
  139. sd = str(input("输入label日级表开始日期,格式为 YYYYmmdd"))
  140. ed = str(input("输入label日级表结束日期,格式为 YYYYmmdd"))
  141. D.update_user_info(start_date=sd, end_date=ed)
  142. # case "spider":
  143. # S = SpiderProcess()
  144. # S.spider_data_produce(flag=mode, dt_time=dt)
  145. # case "user":
  146. # U = UserProcess()
  147. # if mode == "generate":
  148. # sd = str(input("输入开始日期,格式为 YYYYmmdd"))
  149. # ed = str(input("输入结束日期,格式为 YYYYmmdd"))
  150. # U.userinfo_to_mysql(start_date=sd, end_date=ed)
  151. # else:
  152. # U.generate_user_data(flag=mode, dt_time=dt)