123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- #
- # Copyright © 2024 StrayWarrior <i@straywarrior.com>
- import sys
- import os
- sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
- import time
- import json
- from datetime import datetime, timedelta
- import pandas as pd
- from argparse import ArgumentParser
- from long_articles.category_models import CategoryRegressionV1
- from common.database import MySQLManager
- from common import db_operation
- from common.logging import LOG
- from config.dev import Config
def prepare_raw_data(dt_begin, dt_end):
    """Load article statistics for a ``dt`` range into a DataFrame.

    Selects rows from ``datastat_score`` whose ``dt`` is between ``dt_begin``
    and ``dt_end`` (SQL ``BETWEEN``) and that also satisfy the fixed filters
    in the query: ``similarity > 0``, non-NULL ``category``,
    ``read_avg > 500``, ``read_avg_rate`` between 0 and 3 and ``index = 1``.
    Rows repeating the same ``(dt, gh_id, title)`` are then dropped.

    Args:
        dt_begin: Lower bound for ``dt``; interpolated verbatim into the SQL.
        dt_end: Upper bound for ``dt``; interpolated verbatim into the SQL.

    Returns:
        A ``pandas.DataFrame`` with one column per selected field.
    """
    columns = [
        'dt', 'gh_id', 'account_name', 'title', 'similarity',
        'view_count_rate', 'category', 'read_avg',
        'read_avg_rate',
    ]
    select_list = ','.join(columns)
    connection = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    query = f"""
        SELECT {select_list} FROM datastat_score WHERE dt BETWEEN {dt_begin} AND {dt_end}
        AND similarity > 0 AND category IS NOT NULL AND read_avg > 500
        AND read_avg_rate BETWEEN 0 AND 3
        AND `index` = 1
    """
    frame = pd.DataFrame(connection.select(query), columns=columns)
    return frame.drop_duplicates(['dt', 'gh_id', 'title'])
def clear_old_version(db_manager, dt):
    """Deactivate active ``account_category`` rows older than ``dt``.

    Issues a single UPDATE that sets ``status = 0`` and stamps
    ``update_timestamp`` with the current Unix time (whole seconds) on every
    row where ``dt < dt`` and ``status = 1``, then prints whatever
    ``db_manager.execute`` returned.

    Args:
        db_manager: Object exposing ``execute(sql)``.
        dt: Version boundary; interpolated verbatim into the SQL.
    """
    now_ts = int(time.time())
    statement = f"""
        UPDATE account_category
        SET status = 0, update_timestamp = {now_ts}
        WHERE dt < {dt} and status = 1
    """
    affected = db_manager.execute(statement)
    print(f"updated rows: {affected}")
def main():
    """Fit per-account category regressions and publish them as a new version.

    Steps:
      1. Parse ``--dry-run`` / ``--run-at`` and derive the data range. The
         range ends on the day before the run date (today unless ``--run-at``
         is given), and that end date is also used as the ``dt`` version of
         the generated records.
      2. Load and preprocess the raw data, then run the model's OLS
         regression once per account; accounts with fewer than 5 samples
         are skipped.
      3. For each account keep the categories with ``abs(param) > 0.1`` and
         ``p_value < 0.1``, clamped to ``[-0.3, 0.25]``, as its
         ``category_map``. Accounts with an empty map produce no record.
      4. With ``--dry-run`` print the records and stop; otherwise insert
         them into ``account_category`` and deactivate older versions.
    """
    parser = ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true', help='do not update database')
    parser.add_argument('--run-at', help='dt, also for version')
    args = parser.parse_args()

    run_date = datetime.today()
    if args.run_at:
        run_date = datetime.strptime(args.run_at, "%Y%m%d")
    # Fixed start of the training window.
    begin_dt = 20240914
    end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
    dt_version = end_dt
    LOG.info(f"data range: {begin_dt} - {end_dt}")

    raw_df = prepare_raw_data(begin_dt, end_dt)
    cate_model = CategoryRegressionV1()
    df = cate_model.preprocess_data(raw_df)

    # All records of one run share the same creation/update time.
    create_timestamp = int(time.time())
    update_timestamp = create_timestamp

    records_to_save = []
    param_to_category_map = cate_model.reverse_category_name_map
    account_ids = df['gh_id'].unique()
    account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
        .set_index('gh_id')['account_name'].to_dict()
    account_negative_cates = {k: [] for k in account_ids}

    for account_id in account_ids:
        sub_df = df[df['gh_id'] == account_id]
        account_name = account_id_map[account_id]
        # Too few samples for a regression on this account.
        if len(sub_df) < 5:
            continue

        params, _t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
        category_map = {}
        param_names = cate_model.get_param_names()
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name)
            # Parameters that do not map to a category are ignored.
            if cate_name is None:
                continue
            # Category correlation used for ranking.
            if abs(param) > 0.1 and p_value < 0.1:
                print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
                category_map[cate_name] = round(max(min(param, 0.25), -0.3), 6)
            # Negative categories used for cold-start article allocation.
            if param < -0.1 and p_value < 0.3:
                account_negative_cates[account_id].append(cate_name)

        if not category_map:
            continue
        records_to_save.append({
            'dt': dt_version,
            'gh_id': account_id,
            'category_map': json.dumps(category_map, ensure_ascii=False),
            'status': 1,
            'create_timestamp': create_timestamp,
            'update_timestamp': update_timestamp,
        })

    if args.dry_run:
        for record in records_to_save:
            print(record)
        return

    if records_to_save:
        db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
        db_manager.batch_insert('account_category', records_to_save)
        clear_old_version(db_manager, dt_version)
    else:
        # Deactivating older versions without inserting a replacement would
        # leave the table with no active records, so skip the write entirely.
        LOG.info("no category records generated; keeping existing versions active")

    # Drop accounts that ended up with no negative categories.
    # NOTE(review): the filtered mapping is not consumed anywhere in this
    # function yet.
    account_negative_cates = {
        account_id: cates
        for account_id, cates in account_negative_cates.items()
        if cates
    }
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|