# process_data.py

  1. """
  2. process the data to satisfy the lightgbm
  3. """
  4. import datetime
  5. import sys
  6. import os
  7. import json
  8. import asyncio
  9. import argparse
  10. import time
  11. import numpy as np
  12. from tqdm import tqdm
  13. import jieba.analyse
  14. import pandas as pd
  15. sys.path.append(os.getcwd())
  16. from functions import generate_label_date, generate_daily_strings, MysqlClient, MySQLClientSpider


class DataProcessor(object):
    """
    Insert label, video_title and daily_dt_str into the lightgbm_data table.
    """

    def __init__(self):
        self.client = MysqlClient()
        self.client_spider = MySQLClientSpider()
        self.label_data = {}

    def producer(self):
        """
        Generate the data: store label, video_title and daily_dt_str
        in the MySQL database.
        :return: None
        """
        label_path = "data/train_data/daily-label-20240326-20240331.json"
        with open(label_path, encoding="utf-8") as f:
            self.label_data = json.loads(f.read())

        def read_title(client, video_id):
            """
            Read a video title from MySQL.
            """
            sql = f"""SELECT title FROM wx_video WHERE id = {video_id};"""
            try:
                title = client.select(sql)[0][0]
                return title.strip()
            except Exception as e:
                print(video_id, "\t", e)
                return ""

        def generate_label(video_id, hourly_dt_str, label_info):
            """
            Generate the label and its daily_dt_str for MySQL.
            :param label_info:
            :param video_id:
            :param hourly_dt_str:
            :return: label, daily_dt_str
            """
            label_dt = generate_label_date(hourly_dt_str)
            label_obj = label_info.get(label_dt, {}).get(video_id)
            if label_obj:
                label = int(label_obj["total_return"]) if label_obj["total_return"] else 0
            else:
                label = 0
            return label, label_dt
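
        # Example label_data shape, inferred from the lookup in generate_label
        # (the exact field set of the label file is an assumption):
        #   {"20240328": {"12345": {"total_return": "18"}, ...}, ...}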

        def process_info(item_):
            """
            Update one row of lightgbm_data in MySQL.
            :param item_: a (video_id, hour_dt_str) tuple
            """
            video_id, hour_dt = item_
            label_info = self.label_data
            title = read_title(client=self.client, video_id=video_id)
            label, dt_daily = generate_label(str(video_id), hour_dt, label_info)
            update_sql = f"""UPDATE lightgbm_data SET video_title = '{title}', label = '{label}', daily_dt_str = '{dt_daily}' WHERE video_id = '{video_id}';"""
            self.client_spider.update(update_sql)
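            # NOTE: the title is interpolated into the SQL string directly, so a
            # title containing a single quote breaks the statement. Assuming the
            # client wraps a DB-API cursor (the wrapper is not shown here), a
            # parameterized form would be safer, e.g.:
            #   cursor.execute(
            #       "UPDATE lightgbm_data SET video_title = %s, label = %s, "
            #       "daily_dt_str = %s WHERE video_id = %s;",
            #       (title, label, dt_daily, video_id),
            #   )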

        select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data WHERE label IS NULL AND hour_dt_str >= '2024032700';"
        init_data_tuple = self.client_spider.select(select_sql)
        init_list = list(init_data_tuple)
        for item in tqdm(init_list):
            try:
                process_info(item)
            except Exception as e:
                print("update failed", e)


class SpiderProcess(object):
    """
    Process spider data into features for LightGBM training.
    """

    def __init__(self):
        self.client_spider = MySQLClientSpider()
        self.spider_features = [
            "channel",
            "out_user_id",
            "mode",
            "out_play_cnt",
            "out_like_cnt",
            "out_share_cnt"
        ]

    def spider_lop(self, video_id):
        """
        Spider lop = like / play (a smoothed like-over-play ratio).
        :param video_id:
        :return: (lop, duration)
        """
        sql = f"""SELECT like_cnt, play_cnt, duration FROM crawler_video WHERE video_id = '{video_id}';"""
        try:
            like_cnt, play_cnt, duration = self.client_spider.select(sql)[0]
            lop = (like_cnt + 700) / (play_cnt + 18000)
            return lop, duration
        except Exception as e:
            print(video_id, "\t", e)
            return 0, 0
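
    # Worked example with hypothetical numbers: a video with 300 likes and
    # 20,000 plays gets lop = (300 + 700) / (20000 + 18000) ≈ 0.026. The
    # +700 / +18000 offsets act as additive smoothing, pulling low-traffic
    # videos toward the prior 700 / 18000 ≈ 0.039 instead of extreme ratios.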

    def spider_data_produce(self, flag, dt_time=None):
        """
        Export the spider features (lop, duration, title tags, ...) to JSON.
        :param flag: "train" or "predict"
        :param dt_time: date string in %Y%m%d format, required when flag == "predict"
        :return:
        """
        if flag == "train":
            select_sql = "SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' ORDER BY daily_dt_str;"
            des_path = "data/train_data/spider_train_{}".format(datetime.datetime.today().strftime("%Y%m%d"))
        elif flag == "predict":
            dt_time = datetime.datetime.strptime(dt_time, "%Y%m%d")
            # read the partition dated four days after the prediction date
            label_dt = dt_time + datetime.timedelta(days=4)
            temp_time = label_dt.strftime("%Y%m%d")
            select_sql = f"""SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' AND daily_dt_str = '{temp_time}';"""
            des_path = "data/predict_data/predict_{}.json".format(dt_time.strftime("%Y%m%d"))
        else:
            return
        data_list = self.client_spider.select(select_sql)
        df = []
        for line in tqdm(data_list):
            try:
                temp = list(line)
                video_id = line[0]
                title = line[1]
                lop, duration = self.spider_lop(video_id)
                title_tags = list(jieba.analyse.textrank(title, topK=3))
                temp.append(lop)
                temp.append(duration)
                # always append exactly three tag slots, padding with None
                for i in range(3):
                    temp.append(title_tags[i] if i < len(title_tags) else None)
                df.append(temp[2:])  # drop video_id and video_title
            except Exception:
                continue
        df = pd.DataFrame(df, columns=['label', 'channel', 'out_user_id', 'mode', 'out_play_cnt', 'out_like_cnt',
                                       'out_share_cnt', 'lop', 'duration', 'tag1', 'tag2', 'tag3'])
        df.to_json(des_path, orient='records')
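

# A minimal sketch of driving SpiderProcess outside the CLI at the bottom of
# this file (assumes the MySQL tables described above are reachable):
#   sp = SpiderProcess()
#   sp.spider_data_produce(flag="predict", dt_time="20240401")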


class UserProcess(object):
    """
    Process user-upload data into features for LightGBM training.
    """

    def __init__(self):
        self.client_spider = MySQLClientSpider()
        self.user_features = [
            "label",
            "uid",
            "channel",
            "user_fans",
            "user_view_30",
            "user_share_30",
            "user_return_30",
            "user_rov",
            "user_str",
            "user_return_videos_30",
            "user_return_videos_3",
            "user_return_3",
            "user_view_3",
            "user_share_3",
            "address",
            "tag1",
            "tag2",
            "tag3"
        ]

    def userinfo_to_mysql(self, start_date, end_date):
        """
        Store user_return_3, user_view_3, user_share_3,
        user_return_videos_3, user_return_videos_30 and
        address in the MySQL database.
        :return:
        """
        user_path = 'data/train_data/daily-user-info-{}-{}.json'.format(start_date, end_date)
        with open(user_path, encoding="utf-8") as f:
            data = json.loads(f.read())
        sql = "SELECT video_id, hour_dt_str FROM lightgbm_data WHERE type = 'userupload' AND address IS NULL;"
        dt_list = self.client_spider.select(sql)
        for item in tqdm(dt_list):
            video_id, dt = item
            dt = dt[:8]  # keep only the daily part of the hourly string
            user_info_obj = data.get(dt, {}).get(str(video_id))
            if user_info_obj:
                try:
                    video_id = user_info_obj['video_id']
                    address = user_info_obj['address']
                    return_3 = user_info_obj['return_3days']
                    view_3 = user_info_obj['view_3days']
                    share_3 = user_info_obj['share_3days']
                    return_videos_3 = user_info_obj['3day_return_500_videos']
                    return_videos_30 = user_info_obj['30day_return_2000_videos']
                    update_sql = f"""UPDATE lightgbm_data SET address='{address}', user_return_3={return_3}, user_view_3={view_3}, user_share_3={share_3}, user_return_videos_3={return_videos_3}, user_return_videos_30={return_videos_30} WHERE video_id = '{video_id}';"""
                    self.client_spider.update(update_sql)
                except Exception as e:
                    print(e)
            else:
                print("No user info")
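
    # Expected shape of the daily-user-info file, inferred from the lookups
    # above (field values are placeholders):
    #   {"20240327": {"12345": {"video_id": "12345", "address": "...",
    #                           "return_3days": 0, "view_3days": 0, "share_3days": 0,
    #                           "3day_return_500_videos": 0, "30day_return_2000_videos": 0}}}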

    def generate_user_data(self, flag, dt_time=None):
        """
        Generate the user training / prediction data.
        :param flag: "train" or "predict"
        :param dt_time: date string in %Y%m%d format, required when flag == "predict"
        :return:
        """
        if flag == "train":
            sql = "SELECT video_title, label, user_id, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address FROM lightgbm_data WHERE type = 'userupload' AND daily_dt_str >= '20240305';"
            des_path = "data/train_data/user_train_{}.json".format(datetime.datetime.today().strftime("%Y%m%d"))
        elif flag == "predict":
            # parse dt_time only on the predict path, where it is required
            dt_time = datetime.datetime.strptime(dt_time, "%Y%m%d")
            temp_time = (dt_time + datetime.timedelta(days=4)).strftime("%Y%m%d")
            sql = f"""SELECT video_title, label, user_id, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address FROM lightgbm_data WHERE type = 'userupload' AND daily_dt_str = '{temp_time}';"""
            des_path = "data/predict_data/user_predict_{}.json".format(dt_time.strftime("%Y%m%d"))
        else:
            return
        dt_list = self.client_spider.select(sql)
        df = []
        for line in tqdm(dt_list):
            title = line[0]
            temp = list(line)
            title_tags = list(jieba.analyse.textrank(title, topK=3))
            # always append exactly three tag slots, padding with None
            for i in range(3):
                temp.append(title_tags[i] if i < len(title_tags) else None)
            df.append(temp[1:])  # drop video_title
        df = pd.DataFrame(df, columns=self.user_features)
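        # Derived ratio features: ros = return / share and rov = return / view,
        # over 30-day and 3-day windows; NaN marks a zero denominator.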
        df['ros_30'] = np.where(df['user_share_30'] != 0, df['user_return_30'] / df['user_share_30'], np.nan)
        df['rov_30'] = np.where(df['user_view_30'] != 0, df['user_return_30'] / df['user_view_30'], np.nan)
        df['ros_3'] = np.where(df['user_share_3'] != 0, df['user_return_3'] / df['user_share_3'], np.nan)
        df['rov_3'] = np.where(df['user_view_3'] != 0, df['user_return_3'] / df['user_view_3'], np.nan)
        df.to_json(des_path, orient='records')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # build the CLI argument parser
    parser.add_argument("--mode")
    parser.add_argument("--de")
    parser.add_argument("--dt")
    args = parser.parse_args()
    mode = args.mode
    de = args.de
    dt = args.dt
    match de:
        case "spider":
            S = SpiderProcess()
            S.spider_data_produce(flag=mode, dt_time=dt)
        case "user":
            U = UserProcess()
            if mode == "generate":
                sd = str(input("Enter the start date (YYYYmmdd): "))
                ed = str(input("Enter the end date (YYYYmmdd): "))
                U.userinfo_to_mysql(start_date=sd, end_date=ed)
            else:
                U.generate_user_data(flag=mode, dt_time=dt)
        case "Data":
            D = DataProcessor()
            D.producer()
            # if mode == "train":
            #     print("Loading data and process for training.....")
            #     D = DataProcessor(flag="train", ll=category)
            #     D.producer("whole")
            # elif mode == "predict":
            #     print("Loading data and process for prediction for each day......")
            #     D = DataProcessor(flag="predict", ll=category)
            #     if dtype == "single":
            #         date_str = str(input("Please enter the date of the prediction"))
            #         D.producer(date_str)
            #     elif dtype == "days":
            #         start_date_str = str(input("Please enter the start date of the prediction"))
            #         end_date_str = str(input("Please enter the end date of the prediction"))
            #         dt_list = generate_daily_strings(start_date=start_date_str, end_date=end_date_str)
            #         for d in dt_list:
            #             D.producer()
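
# Example invocations (table contents and the label / user-info JSON files are
# environment-specific):
#   python process_data.py --de spider --mode train
#   python process_data.py --de spider --mode predict --dt 20240401
#   python process_data.py --de user --mode generate
#   python process_data.py --de user --mode predict --dt 20240401
#   python process_data.py --de Data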