# run_category_model_v1.py
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. # Copyright © 2024 StrayWarrior <i@straywarrior.com>
  6. import sys
  7. import os
  8. sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
  9. import time
  10. import json
  11. from datetime import datetime, timedelta
  12. import pandas as pd
  13. from argparse import ArgumentParser
  14. from long_articles.category_models import CategoryRegressionV1
  15. from common.database import MySQLManager
  16. from common import db_operation
  17. from common.logging import LOG
  18. from config.dev import Config
  19. def prepare_raw_data(dt_begin, dt_end):
  20. data_fields = ['dt', 'gh_id', 'account_name', 'title', 'similarity',
  21. 'view_count_rate', 'category', 'read_avg',
  22. 'read_avg_rate']
  23. fields_str = ','.join(data_fields)
  24. db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
  25. sql = f"""
  26. SELECT {fields_str} FROM datastat_score WHERE dt BETWEEN {dt_begin} AND {dt_end}
  27. AND similarity > 0 AND category IS NOT NULL AND read_avg > 500
  28. AND read_avg_rate BETWEEN 0 AND 3
  29. AND `index` = 1
  30. """
  31. rows = db_manager.select(sql)
  32. df = pd.DataFrame(rows, columns=data_fields)
  33. df = df.drop_duplicates(['dt', 'gh_id', 'title'])
  34. return df
def run_once(dt):
    """One-off/experimental run: fit per-account category regressions from a
    local Excel sample dump instead of the database.

    Builds `account_category` records tagged with *dt* (the batch insert is
    currently commented out, so nothing is persisted) and prints the accounts
    that show negatively correlated categories.
    """
    # Samples exported 2024-11-01; Chinese columns are remapped to the
    # canonical names the model expects.
    df = pd.read_excel('src/long_articles/20241101_read_rate_samples.xlsx')
    df['read_avg'] = df['阅读均值']        # read average
    df['read_avg_rate'] = df['阅读倍数']   # read multiplier
    df['dt'] = df['日期']                  # date
    df['similarity'] = df['Similarity']
    # Mirrors the quality filters applied by the DB query in
    # prepare_raw_data, plus a fixed start date.
    filter_condition = 'read_avg > 500 ' \
                       'and read_avg_rate > 0 and read_avg_rate < 3 ' \
                       'and dt > 20240914 and similarity > 0'
    df = df.query(filter_condition).copy()
    #df = pd.read_excel('20241112-new-account-samples.xlsx')
    cate_model = CategoryRegressionV1()
    create_timestamp = int(time.time())
    update_timestamp = create_timestamp
    records_to_save = []
    df = cate_model.preprocess_data(df)
    param_to_category_map = cate_model.reverse_category_name_map
    account_ids = df['ghID'].unique()
    # gh_id -> account display name (账号名称).
    account_id_map = df[['账号名称', 'ghID']].drop_duplicates().set_index('ghID')['账号名称'].to_dict()
    account_negative_cates = {k: [] for k in account_ids}
    for account_id in account_ids:
        sub_df = df[df['ghID'] == account_id]
        account_name = account_id_map[account_id]
        sample_count = len(sub_df)
        # Too few samples for a meaningful OLS fit.
        if sample_count < 5:
            continue
        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
        current_record = {}
        current_record['dt'] = dt
        current_record['gh_id'] = account_id
        current_record['category_map'] = {}
        param_names = cate_model.get_param_names()
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name, None)
            # Significant category coefficients, kept for ranking.
            if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
                #print(f"{account_id} {cate_name} {param:.3f} {p_value:.3f}")
                current_record['category_map'][cate_name] = round(param, 6)
            # Looser p-value threshold for negative categories (used for
            # cold-start article exclusion).
            if param < -0.1 and cate_name is not None and p_value < 0.3:
                account_negative_cates[account_id].append(cate_name)
                print((account_name, cate_name, param, p_value))
        current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
        current_record['status'] = 1
        current_record['create_timestamp'] = create_timestamp
        current_record['update_timestamp'] = update_timestamp
        records_to_save.append(current_record)
    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    #db_manager.batch_insert('account_category', records_to_save)
    # Drop accounts with no negative categories; iterate over a snapshot of
    # the keys so popping while looping is safe.
    for account_id in [*account_negative_cates.keys()]:
        if not account_negative_cates[account_id]:
            account_negative_cates.pop(account_id)
    print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
    for k, v in account_negative_cates.items():
        print('{}\t{}'.format(k, ','.join(v)))
  88. def main():
  89. parser = ArgumentParser()
  90. parser.add_argument('-n', '--dry-run', action='store_true', help='do not update database')
  91. parser.add_argument('--run-at', help='dt, also for version')
  92. args = parser.parse_args()
  93. run_date = datetime.today()
  94. if args.run_at:
  95. run_date = datetime.strptime(args.run_at, "%Y%m%d")
  96. begin_dt = 20240914
  97. end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
  98. dt_version = end_dt
  99. LOG.info(f"data range: {begin_dt} - {end_dt}")
  100. raw_df = prepare_raw_data(begin_dt, end_dt)
  101. cate_model = CategoryRegressionV1()
  102. df = cate_model.preprocess_data(raw_df)
  103. if args.dry_run:
  104. cate_model.build(df)
  105. create_timestamp = int(time.time())
  106. update_timestamp = create_timestamp
  107. records_to_save = []
  108. param_to_category_map = cate_model.reverse_category_name_map
  109. account_ids = df['gh_id'].unique()
  110. account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
  111. .set_index('gh_id')['account_name'].to_dict()
  112. account_negative_cates = {k: [] for k in account_ids}
  113. for account_id in account_ids:
  114. sub_df = df[df['gh_id'] == account_id]
  115. account_name = account_id_map[account_id]
  116. sample_count = len(sub_df)
  117. if sample_count < 5:
  118. continue
  119. params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
  120. current_record = {}
  121. current_record['dt'] = dt_version
  122. current_record['gh_id'] = account_id
  123. current_record['category_map'] = {}
  124. param_names = cate_model.get_param_names()
  125. for name, param, p_value in zip(param_names, params, p_values):
  126. cate_name = param_to_category_map.get(name, None)
  127. # 用于排序的品类相关性
  128. if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
  129. print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
  130. current_record['category_map'][cate_name] = round(param, 6)
  131. # 用于冷启文章分配的负向品类
  132. if param < -0.1 and cate_name is not None and p_value < 0.3:
  133. account_negative_cates[account_id].append(cate_name)
  134. # print((account_name, cate_name, param, p_value))
  135. if not current_record['category_map']:
  136. continue
  137. current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
  138. current_record['status'] = 1
  139. current_record['create_timestamp'] = create_timestamp
  140. current_record['update_timestamp'] = update_timestamp
  141. records_to_save.append(current_record)
  142. if args.dry_run:
  143. for record in records_to_save:
  144. print(record)
  145. return
  146. db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
  147. db_manager.batch_insert('account_category', records_to_save)
  148. # 过滤空账号
  149. for account_id in [*account_negative_cates.keys()]:
  150. if not account_negative_cates[account_id]:
  151. account_negative_cates.pop(account_id)
  152. # print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
# Script entry point.
if __name__ == '__main__':
    main()