# run_category_model_v1.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2024 StrayWarrior <i@straywarrior.com>
import sys
import os
# Make the sibling ``src`` directory importable before the project imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
import time
import json
from datetime import datetime, timedelta
import pandas as pd
from argparse import ArgumentParser
from long_articles.category_models import CategoryRegressionV1
from long_articles.consts import reverse_category_name_map
from common.database import MySQLManager
from common import db_operation
from common.logging import LOG
from config.dev import Config
  20. def prepare_raw_data(dt_begin, dt_end):
  21. data_fields = ['dt', 'gh_id', 'account_name', 'title', 'similarity',
  22. 'view_count_rate', 'category', 'read_avg',
  23. 'read_avg_rate']
  24. fields_str = ','.join(data_fields)
  25. db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
  26. sql = f"""
  27. SELECT {fields_str} FROM datastat_score WHERE dt BETWEEN {dt_begin} AND {dt_end}
  28. AND similarity > 0 AND category IS NOT NULL AND read_avg > 500
  29. AND read_avg_rate BETWEEN 0 AND 3
  30. AND `index` = 1
  31. """
  32. rows = db_manager.select(sql)
  33. df = pd.DataFrame(rows, columns=data_fields)
  34. df = df.drop_duplicates(['dt', 'gh_id', 'title'])
  35. return df
  36. def clear_old_version(db_manager, dt):
  37. update_timestamp = int(time.time())
  38. sql = f"""
  39. UPDATE account_category
  40. SET status = 0, update_timestamp = {update_timestamp}
  41. WHERE dt < {dt} and status = 1
  42. """
  43. rows = db_manager.execute(sql)
  44. print(f"updated rows: {rows}")
  45. def get_last_version(db_manager, dt):
  46. sql = f"""
  47. SELECT gh_id, category_map
  48. FROM account_category
  49. WHERE dt = (SELECT max(dt) FROM account_category WHERE dt < {dt})
  50. """
  51. data = db_manager.select(sql)
  52. return data
  53. def compare_version(db_manager, dt_version, new_version, account_id_map):
  54. last_version = get_last_version(db_manager, dt_version)
  55. last_version = { entry[0]: json.loads(entry[1]) for entry in last_version }
  56. new_version = { entry['gh_id']: json.loads(entry['category_map']) for entry in new_version }
  57. # new record
  58. all_gh_ids = set(list(new_version.keys()) + list(last_version.keys()))
  59. for gh_id in all_gh_ids:
  60. account_name = account_id_map.get(gh_id, None)
  61. if gh_id not in last_version:
  62. print(f"new account {account_name}: {new_version[gh_id]}")
  63. elif gh_id not in new_version:
  64. print(f"old account {account_name}: {last_version[gh_id]}")
  65. else:
  66. new_cates = new_version[gh_id]
  67. old_cates = last_version[gh_id]
  68. for cate in new_cates:
  69. if cate not in old_cates:
  70. print(f"account {account_name} new cate: {cate} {new_cates[cate]}")
  71. for cate in old_cates:
  72. if cate not in new_cates:
  73. print(f"account {account_name} old cate: {cate} {old_cates[cate]}")
def main():
    """Fit a per-account OLS category regression and persist the resulting
    category weight maps into the ``account_category`` table.

    CLI flags:
      -n/--dry-run      : compute and diff against the last version, no DB writes
      --run-at YYYYMMDD : pretend today is this date (also fixes the version dt)
      --print-matrix    : with --dry-run, print the design matrix and exit
      --print-residual  : forward residual printing into the regression call
    """
    parser = ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true', help='do not update database')
    parser.add_argument('--run-at', help='dt, also for version')
    parser.add_argument('--print-matrix', action='store_true')
    parser.add_argument('--print-residual', action='store_true')
    args = parser.parse_args()
    run_date = datetime.today()
    if args.run_at:
        run_date = datetime.strptime(args.run_at, "%Y%m%d")
    # Fixed start of the training window; the end is the day before run_date
    # and doubles as the version identifier (dt_version).
    begin_dt = 20240914
    end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
    dt_version = end_dt
    LOG.info(f"data range: {begin_dt} - {end_dt}")
    raw_df = prepare_raw_data(begin_dt, end_dt)
    cate_model = CategoryRegressionV1()
    df = cate_model.preprocess_data(raw_df)
    if args.dry_run and args.print_matrix:
        cate_model.build_and_print_matrix(df)
        return
    create_timestamp = int(time.time())
    update_timestamp = create_timestamp
    records_to_save = []
    param_to_category_map = reverse_category_name_map
    account_ids = df['gh_id'].unique()
    account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
        .set_index('gh_id')['account_name'].to_dict()
    account_negative_cates = {k: [] for k in account_ids}
    P_VALUE_THRESHOLD = 0.15
    for account_id in account_ids:
        sub_df = df[df['gh_id'] == account_id]
        account_name = account_id_map[account_id]
        sample_count = len(sub_df)
        # Too few samples make the per-account regression meaningless; skip.
        if sample_count < 5:
            continue
        params, t_stats, p_values = cate_model.run_ols_linear_regression(
            sub_df, args.print_residual, P_VALUE_THRESHOLD)
        current_record = {}
        current_record['dt'] = dt_version
        current_record['gh_id'] = account_id
        current_record['category_map'] = {}
        param_names = cate_model.get_param_names()
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name, None)
            # Category relevance used for ranking.
            if abs(param) > 0.1 and p_value < P_VALUE_THRESHOLD and cate_name is not None:
                # Shrink weights with weaker significance (p-value above 0.1).
                scale_factor = min(0.1 / p_value, 1)
                print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
                # Clamp to [-0.3, 0.25] before scaling, then round for storage.
                truncate_param = round(max(min(param, 0.25), -0.3) * scale_factor, 6)
                current_record['category_map'][cate_name] = truncate_param
            # Negative categories used for cold-start article assignment.
            if param < -0.1 and cate_name is not None and p_value < P_VALUE_THRESHOLD:
                account_negative_cates[account_id].append(cate_name)
                # print((account_name, cate_name, param, p_value))
        # Accounts with no significant category produce no record at all.
        if not current_record['category_map']:
            continue
        current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
        current_record['status'] = 1
        current_record['create_timestamp'] = create_timestamp
        current_record['update_timestamp'] = update_timestamp
        records_to_save.append(current_record)
    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    if args.dry_run:
        compare_version(db_manager, dt_version, records_to_save, account_id_map)
        return
    db_manager.batch_insert('account_category', records_to_save)
    clear_old_version(db_manager, dt_version)
    # Filter out accounts whose negative-category list is empty.
    # NOTE(review): account_negative_cates is computed but not persisted here —
    # presumably consumed by a later stage; confirm before removing.
    for account_id in [*account_negative_cates.keys()]:
        if not account_negative_cates[account_id]:
            account_negative_cates.pop(account_id)
    # print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
# Script entry point.
if __name__ == '__main__':
    main()