run_category_model_v1.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. # Copyright © 2024 StrayWarrior <i@straywarrior.com>
  6. import sys
  7. import os
  8. sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
  9. import time
  10. import json
  11. from datetime import datetime, timedelta
  12. import pandas as pd
  13. from argparse import ArgumentParser
  14. from long_articles.category_models import CategoryRegressionV1
  15. from long_articles.consts import category_feature_v2, category_name_map_v2, reverse_category_name_map_v2
  16. from common.database import MySQLManager
  17. from common import db_operation
  18. from common.logging import LOG
  19. from config.dev import Config
  20. NIGHT_ACCOUNTS = ('gh_12523d39d809','gh_df4a630c04db','gh_f67df16f4670','gh_ca44517edda9','gh_a66c1316fd5e','gh_4242c478bbba','gh_60b0c23fcc7c','gh_33b3470784fc','gh_ec1bcb283daf','gh_234ab9ff490d','gh_7715a626a4c6','gh_1bfe1d257728','gh_9db5e3ac2c93','gh_9d1ae5f9ceac','gh_7208b813f16d','gh_e56ddf195d91','gh_a43aecffe81b','gh_d4a7d2ce54fd','gh_c2b458818b09','gh_349a57ef1c44','gh_89bfe54ad90f','gh_b929ed680b62','gh_f8e8a931ff56','gh_916f4fad5ce0','gh_0d7c5f4c38a9','gh_bceef3f747c2','gh_706456719017','gh_fd51a5e33fc6','gh_5372093f5fb0','gh_957ff8e08e1b','gh_64fc629d3ec2','gh_c8b69797912a','gh_6909b38ad95f','gh_1e69a1b4dc1a','gh_0763523103e4','gh_9b83a9ad7da0','gh_82b416f27698','gh_a60647e98cd9','gh_3ce2fa1956ea','gh_44127c197525','gh_06834aba13a5','gh_c33809af68bc','gh_82cf39ef616e','gh_a342ef23c48e','gh_c9cc1471af7d','gh_291ec369f017','gh_810a439f320a','gh_00f942061a0d','gh_7662653b0e77','gh_d192d757b606','gh_391702d26b3b','gh_3e90f421c974','gh_30d189fe56c7','gh_7ebfbbf675ee','gh_3f84c2b9a1a2','gh_bccbe3681e22','gh_005fc1cb4b73','gh_21d120007b64','gh_3d5f24fd3311','gh_3621aaa6c4a0','gh_aee2dca32701','gh_c25c6040c4b2','gh_641019d44876','gh_95ba63e5cf18','gh_efd90dcf48ac','gh_5e1464b76ff6','gh_5765f834684c','gh_81bec2f4f577','gh_401396006e13','gh_7c33726c5147','gh_bbd8a52ba98b','gh_f74ca3104604'
  21. )
  22. def prepare_raw_data(dt_begin, dt_end):
  23. data_fields = ['dt', 'gh_id', 'account_name', 'title', 'similarity',
  24. 'view_count_rate', 'category', 'read_avg',
  25. 'read_avg_rate', 'first_pub_interval', '`index`']
  26. fields_str = ','.join(data_fields)
  27. db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
  28. night_accounts_condition = str(NIGHT_ACCOUNTS)
  29. sql = f"""
  30. SELECT {fields_str} FROM datastat_score WHERE dt BETWEEN {dt_begin} AND {dt_end}
  31. AND similarity > 0 AND category IS NOT NULL AND read_avg > 500
  32. AND read_avg_rate BETWEEN 0.3 AND 3 AND view_count_rate > 0
  33. AND `index` in (1, 2)
  34. AND (FROM_UNIXTIME(coalesce(publish_timestamp, 0), '%H') < '15'
  35. OR gh_id in {night_accounts_condition})
  36. AND dt NOT BETWEEN 20250105 AND 20250215
  37. """
  38. rows = db_manager.select(sql)
  39. df = pd.DataFrame(rows, columns=data_fields)
  40. df.rename(columns={'`index`': 'index'}, inplace=True)
  41. df = df.drop_duplicates(['dt', 'gh_id', 'title'])
  42. return df
  43. def clear_old_version(db_manager, dt):
  44. update_timestamp = int(time.time())
  45. sql = f"""
  46. UPDATE account_category
  47. SET status = 0, update_timestamp = {update_timestamp}
  48. WHERE dt < {dt} and status = 1 and version = 2
  49. """
  50. rows = db_manager.execute(sql)
  51. print(f"updated rows for clear: {rows}")
  52. def get_last_version(db_manager, dt):
  53. sql = f"""
  54. SELECT gh_id, category_map
  55. FROM account_category
  56. WHERE dt = (SELECT max(dt) FROM account_category WHERE dt < {dt} AND
  57. status = 1)
  58. """
  59. data = db_manager.select(sql)
  60. return data
  61. def compare_version(db_manager, dt_version, new_version, account_id_map):
  62. last_version = get_last_version(db_manager, dt_version)
  63. last_version = { entry[0]: json.loads(entry[1]) for entry in last_version }
  64. new_version = { entry['gh_id']: json.loads(entry['category_map']) for entry in new_version }
  65. # new record
  66. all_gh_ids = set(list(new_version.keys()) + list(last_version.keys()))
  67. for gh_id in all_gh_ids:
  68. account_name = account_id_map.get(gh_id, None)
  69. if gh_id not in last_version:
  70. print(f"new account {account_name}: {new_version[gh_id]}")
  71. elif gh_id not in new_version:
  72. print(f"old account {account_name}: {last_version[gh_id]}")
  73. else:
  74. new_cates = new_version[gh_id]
  75. old_cates = last_version[gh_id]
  76. for cate in new_cates:
  77. if cate not in old_cates:
  78. print(f"account {account_name} new cate: {cate} {new_cates[cate]}")
  79. for cate in old_cates:
  80. if cate not in new_cates:
  81. print(f"account {account_name} old cate: {cate} {old_cates[cate]}")
  82. def main():
  83. parser = ArgumentParser()
  84. parser.add_argument('-n', '--dry-run', action='store_true', help='do not update database')
  85. parser.add_argument('--run-at', help='dt, also for version')
  86. parser.add_argument('--print-matrix', action='store_true')
  87. parser.add_argument('--print-residual', action='store_true')
  88. args = parser.parse_args()
  89. run_date = datetime.today()
  90. if args.run_at:
  91. run_date = datetime.strptime(args.run_at, "%Y%m%d")
  92. begin_dt = 20240914
  93. end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
  94. dt_version = end_dt
  95. LOG.info(f"data range: {begin_dt} - {end_dt}")
  96. raw_df = prepare_raw_data(begin_dt, end_dt)
  97. cate_model = CategoryRegressionV1(category_feature_v2, category_name_map_v2)
  98. df = cate_model.preprocess_data(raw_df)
  99. if args.dry_run and args.print_matrix:
  100. cate_model.build_and_print_matrix(df)
  101. return
  102. create_timestamp = int(time.time())
  103. update_timestamp = create_timestamp
  104. records_to_save = []
  105. param_to_category_map = reverse_category_name_map_v2
  106. account_ids = df['gh_id'].unique()
  107. account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
  108. .set_index('gh_id')['account_name'].to_dict()
  109. account_negative_cates = {k: [] for k in account_ids}
  110. P_VALUE_THRESHOLD = 0.15
  111. for account_id in account_ids:
  112. sub_df = df[df['gh_id'] == account_id]
  113. account_name = account_id_map[account_id]
  114. sample_count = len(sub_df)
  115. if sample_count < 5:
  116. continue
  117. params, t_stats, p_values = cate_model.run_ols_linear_regression(
  118. sub_df, args.print_residual, P_VALUE_THRESHOLD)
  119. current_record = {}
  120. current_record['dt'] = dt_version
  121. current_record['gh_id'] = account_id
  122. current_record['category_map'] = {}
  123. param_names = cate_model.get_param_names()
  124. for name, param, p_value in zip(param_names, params, p_values):
  125. cate_name = param_to_category_map.get(name, None)
  126. # 用于排序的品类相关性
  127. if abs(param) > 0.1 and p_value < P_VALUE_THRESHOLD and cate_name is not None:
  128. scale_factor = min(0.1 / p_value, 1)
  129. print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
  130. truncate_param = round(max(min(param, 0.25), -0.25) * scale_factor, 6)
  131. current_record['category_map'][cate_name] = truncate_param
  132. # 用于冷启文章分配的负向品类
  133. if param < -0.1 and cate_name is not None and p_value < P_VALUE_THRESHOLD:
  134. account_negative_cates[account_id].append(cate_name)
  135. # print((account_name, cate_name, param, p_value))
  136. if not current_record['category_map']:
  137. continue
  138. current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
  139. current_record['status'] = 1
  140. current_record['version'] = 2
  141. current_record['create_timestamp'] = create_timestamp
  142. current_record['update_timestamp'] = update_timestamp
  143. records_to_save.append(current_record)
  144. db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
  145. if args.dry_run:
  146. compare_version(db_manager, dt_version, records_to_save, account_id_map)
  147. return
  148. rows = db_manager.batch_insert('account_category', records_to_save, ignore=True)
  149. if rows != len(records_to_save):
  150. for record in records_to_save:
  151. sql = f"""
  152. UPDATE account_category
  153. SET category_map = '{record['category_map']}',
  154. update_timestamp = {record['update_timestamp']}
  155. WHERE dt = {record['dt']} AND gh_id = '{record['gh_id']}'
  156. AND category_map != '{record['category_map']}'
  157. AND version = 2
  158. """
  159. update_rows = db_manager.execute(sql)
  160. print(f"updated rows: {update_rows}, {record['gh_id']}")
  161. clear_old_version(db_manager, dt_version)
  162. # 过滤空账号
  163. for account_id in [*account_negative_cates.keys()]:
  164. if not account_negative_cates[account_id]:
  165. account_negative_cates.pop(account_id)
  166. # print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
  167. if __name__ == '__main__':
  168. pd.set_option('display.max_columns', None)
  169. pd.set_option('display.max_rows', None)
  170. main()