# process_data.py
"""
process the data to satisfy the lightgbm
"""
import argparse
import asyncio
import json
import os
import sys
from concurrent.futures.thread import ThreadPoolExecutor

import jieba.analyse
from tqdm import tqdm

# Make the repo-local `functions` module importable when run as a script.
sys.path.append(os.getcwd())
from functions import generate_label_date, generate_daily_strings, MysqlClient, MySQLClientSpider
  14. class DataProcessor(object):
  15. """
  16. Insert some information to lightgbm_data
  17. """
  18. def __init__(self, ):
  19. self.client = MysqlClient()
  20. self.client_spider = MySQLClientSpider()
  21. self.label_data = {}
  22. def generate_train_label(self, item, y_ori_data, cate):
  23. """
  24. 生成训练数据,用 np.array矩阵的方式返回,
  25. :return: x_train, 训练数据, y_train, 训练 label
  26. """
  27. video_id = item["video_id"]
  28. dt = item["dt"]
  29. useful_features = [
  30. "uid",
  31. "type",
  32. "channel",
  33. "fans",
  34. "view_count_user_30days",
  35. "share_count_user_30days",
  36. "return_count_user_30days",
  37. "rov_user",
  38. "str_user",
  39. "out_user_id",
  40. "mode",
  41. "out_play_cnt",
  42. "out_like_cnt",
  43. "out_share_cnt",
  44. "out_collection_cnt",
  45. ]
  46. spider_features = [
  47. "channel",
  48. "out_user_id",
  49. "mode",
  50. "out_play_cnt",
  51. "out_like_cnt",
  52. "out_share_cnt"
  53. ]
  54. user_features = [
  55. "uid",
  56. "channel",
  57. "fans",
  58. "view_count_user_30days",
  59. "share_count_user_30days",
  60. "return_count_user_30days",
  61. "rov_user",
  62. "str_user"
  63. ]
  64. match self.ll:
  65. case "all":
  66. item_features = [item[i] for i in useful_features]
  67. case "user":
  68. if item['type'] == "userupload":
  69. item_features = [item[i] for i in user_features]
  70. else:
  71. return None, None
  72. case "spider":
  73. if item['type'] == "spider":
  74. item_features = [item[i] for i in spider_features]
  75. lop, duration = self.cal_lop(video_id)
  76. item_features.append(lop)
  77. item_features.append(duration)
  78. else:
  79. return None, None
  80. keywords_textrank = self.title_processor(video_id)
  81. if keywords_textrank:
  82. for i in range(3):
  83. try:
  84. item_features.append(keywords_textrank[i])
  85. except:
  86. item_features.append(None)
  87. else:
  88. item_features.append(None)
  89. item_features.append(None)
  90. item_features.append(None)
  91. label_dt = generate_label_date(dt)
  92. label_obj = y_ori_data.get(label_dt, {}).get(video_id)
  93. if label_obj:
  94. label = int(label_obj[cate]) if label_obj[cate] else 0
  95. else:
  96. label = 0
  97. return label, item_features
  98. def producer(self):
  99. """
  100. 生成数据
  101. :return:none
  102. """
  103. # 把 label, video_title, daily_dt_str, 存储到 mysql 数据库中去
  104. label_path = "data/train_data/daily-label-20240101-20240325.json"
  105. with open(label_path, encoding="utf-8") as f:
  106. self.label_data = json.loads(f.read())
  107. def read_title(client, video_id):
  108. """
  109. read_title_from mysql
  110. """
  111. sql = f"""SELECT title from wx_video where id = {video_id};"""
  112. # print("title", sql)
  113. try:
  114. title = client.select(sql)[0][0]
  115. return title
  116. except Exception as e:
  117. print(video_id, "\t", e)
  118. return ""
  119. def generate_label(video_id, hourly_dt_str, label_info):
  120. """
  121. generate label daily_dt_str for mysql
  122. :param label_info:
  123. :param video_id:
  124. :param hourly_dt_str:
  125. :return: label, daily_dt_str
  126. """
  127. label_dt = generate_label_date(hourly_dt_str)
  128. label_obj = label_info.get(label_dt, {}).get(video_id)
  129. if label_obj:
  130. label = int(label_obj["total_return"]) if label_obj["total_return"] else 0
  131. print(label)
  132. else:
  133. label = 0
  134. return label, label_dt
  135. def process_info(item_):
  136. """
  137. Insert data into MySql
  138. :param item_:
  139. """
  140. video_id, hour_dt = item_
  141. label_info = self.label_data
  142. if not label_info:
  143. print(label_info)
  144. # print(len(label_info))
  145. title = read_title(client=self.client, video_id=video_id)
  146. label, dt_daily = generate_label(video_id, hour_dt, label_info)
  147. insert_sql = f"""UPDATE lightgbm_data
  148. set video_title = '{title}', label = '{label}', daily_dt_str = '{dt_daily}'
  149. where video_id = '{video_id}'
  150. ;"""
  151. self.client_spider.update(insert_sql)
  152. select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where label is NULL and hour_dt_str < '20240327';"
  153. init_data_tuple = self.client_spider.select(select_sql)
  154. init_list = list(init_data_tuple)
  155. for item in tqdm(init_list):
  156. # print(item)
  157. process_info(item)
  158. # with ThreadPoolExecutor(max_workers=10) as Pool:
  159. # Pool.map(process_info, init_list)
  160. class SpiderProcess(object):
  161. """
  162. Spider Data Process and Process data for lightgbm training
  163. """
  164. def __init__(self):
  165. self.client_spider = MySQLClientSpider()
  166. def spider_lop(self, video_id):
  167. """
  168. Spider lop = like / play
  169. :param video_id:
  170. :return:
  171. """
  172. sql = f"""SELECT like_cnt, play_cnt, duration from crawler_video where video_id = '{video_id}';"""
  173. try:
  174. like_cnt, play_cnt, duration = self.client_spider.select(sql)[0]
  175. lop = (like_cnt + 700) / (play_cnt + 18000)
  176. return lop, duration
  177. except Exception as e:
  178. print(video_id, "\t", e)
  179. return 0, 0
  180. def spider_data_produce(self):
  181. """
  182. 把 spider_duration 存储到数据库中
  183. :return:
  184. """
  185. return
  186. class UserProcess(object):
  187. """
  188. User Data Process
  189. """
  190. def __init__(self):
  191. self.client = MysqlClient()
  192. self.user_features = [
  193. "uid",
  194. "channel",
  195. "user_fans",
  196. "user_view_30",
  197. "user_share_30",
  198. "user_return_30",
  199. "user_rov",
  200. "user_str",
  201. "user_return_videos_30",
  202. "user_return_videos_3",
  203. "user_return_3",
  204. "user_view_3",
  205. "user_share_3",
  206. "address"
  207. ]
  208. def title_processor(self, video_id):
  209. """
  210. 通过 video_id 去获取title, 然后通过 title 再分词,把关键词作为 feature
  211. :param video_id: the video id
  212. :return: tag_list [tag, tag, tag, tag......]
  213. """
  214. sql = f"""SELECT title from wx_video where id = {video_id};"""
  215. try:
  216. title = self.client.select(sql)[0][0]
  217. keywords_textrank = jieba.analyse.textrank(title, topK=3)
  218. return list(keywords_textrank)
  219. except Exception as e:
  220. print(video_id, "\t", e)
  221. return []
  222. def user_data_process(self):
  223. """
  224. 把 user_return_3, user_view_3, user_share_3
  225. user_return_videos_3, user_return_videos_30
  226. address 存储到 mysql 数据库中
  227. :return:
  228. """
  229. user_path = '/data'
  230. if __name__ == "__main__":
  231. # D = DataProcessor()
  232. # D.producer()
  233. # parser = argparse.ArgumentParser() # 新建参数解释器对象
  234. # parser.add_argument("--mode")
  235. # parser.add_argument("--category")
  236. # parser.add_argument("--dtype", default="whole")
  237. # args = parser.parse_args()
  238. # mode = args.mode
  239. # category = args.category
  240. # dtype = args.dtype
  241. D = DataProcessor()
  242. D.producer()
  243. # if mode == "train":
  244. # print("Loading data and process for training.....")
  245. # D = DataProcessor(flag="train", ll=category)
  246. # D.producer("whole")
  247. # elif mode == "predict":
  248. # print("Loading data and process for prediction for each day......")
  249. # D = DataProcessor(flag="predict", ll=category)
  250. # if dtype == "single":
  251. # date_str = str(input("Please enter the date of the prediction"))
  252. # D.producer(date_str)
  253. # elif dtype == "days":
  254. # start_date_str = str(input("Please enter the start date of the prediction"))
  255. # end_date_str = str(input("Please enter the end date of the prediction"))
  256. # dt_list = generate_daily_strings(start_date=start_date_str, end_date=end_date_str)
  257. # for d in dt_list:
  258. # D.producer()