#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2024 StrayWarrior

"""Fit per-account category regressions and store them in ``account_category``.

For every account (``gh_id``) with enough rows in ``datastat_score``, an OLS
regression from ``CategoryRegressionV1`` is run. Categories whose coefficient
passes the magnitude/p-value filters are saved as a JSON ``category_map``
under version ``dt`` (the day before the run date), after which older active
versions are deactivated. ``--dry-run`` only prints a diff against the
previous version.
"""

import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

import time
import json
from datetime import datetime, timedelta
import pandas as pd
from argparse import ArgumentParser

from long_articles.category_models import CategoryRegressionV1
from long_articles.consts import reverse_category_name_map
from common.database import MySQLManager
from common import db_operation
from common.logging import LOG
from config.dev import Config

# First dt (inclusive) of the training window.
DEFAULT_BEGIN_DT = 20240914
# Accounts with fewer rows than this are not regressed.
MIN_SAMPLE_COUNT = 5
# A category is stored only if |coefficient| exceeds this ...
CATEGORY_PARAM_THRESHOLD = 0.1
# ... and its p-value is below this.
CATEGORY_P_VALUE_THRESHOLD = 0.1
# Stored coefficients are clipped to [CATEGORY_PARAM_MIN, CATEGORY_PARAM_MAX].
CATEGORY_PARAM_MIN = -0.3
CATEGORY_PARAM_MAX = 0.25
# A category counts as "negative" when its coefficient is below this ...
NEGATIVE_PARAM_THRESHOLD = -0.1
# ... and its p-value is below this (looser than for stored categories).
NEGATIVE_P_VALUE_THRESHOLD = 0.3


def prepare_raw_data(dt_begin, dt_end):
    """Load scored article rows for ``dt_begin <= dt <= dt_end``.

    Only rows with positive ``similarity``, a non-null ``category``,
    ``read_avg > 500``, ``read_avg_rate`` in [0, 3] and ``index = 1`` are
    selected; duplicate ``(dt, gh_id, title)`` rows are dropped.

    Args:
        dt_begin: first dt of the range (inclusive), e.g. ``20240914``.
        dt_end: last dt of the range (inclusive), e.g. ``'20241001'``.

    Returns:
        pandas.DataFrame whose columns are ``data_fields`` below.
    """
    data_fields = ['dt', 'gh_id', 'account_name', 'title', 'similarity',
                   'view_count_rate', 'category', 'read_avg',
                   'read_avg_rate']
    fields_str = ','.join(data_fields)
    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    # The dt values are interpolated unquoted; within this script they are
    # always digit-only (a constant, or a strptime/strftime round-trip).
    sql = f"""
        SELECT {fields_str}
        FROM datastat_score
        WHERE dt BETWEEN {dt_begin} AND {dt_end}
            AND similarity > 0
            AND category IS NOT NULL
            AND read_avg > 500
            AND read_avg_rate BETWEEN 0 AND 3
            AND `index` = 1
    """
    rows = db_manager.select(sql)
    df = pd.DataFrame(rows, columns=data_fields)
    df = df.drop_duplicates(['dt', 'gh_id', 'title'])
    return df


def clear_old_version(db_manager, dt):
    """Deactivate (``status = 0``) every active row with a dt older than ``dt``.

    Args:
        db_manager: database handle exposing ``execute(sql)``.
        dt: the newly written version; rows with ``dt < dt`` are retired.
    """
    update_timestamp = int(time.time())
    sql = f"""
        UPDATE account_category
        SET status = 0, update_timestamp = {update_timestamp}
        WHERE dt < {dt} and status = 1
    """
    rows = db_manager.execute(sql)
    LOG.info(f"updated rows: {rows}")


def get_last_version(db_manager, dt):
    """Return ``(gh_id, category_map)`` rows of the newest version before ``dt``.

    Returns an empty result when no earlier version exists (the sub-select
    yields NULL, which matches nothing).
    """
    sql = f"""
        SELECT gh_id, category_map
        FROM account_category
        WHERE dt = (SELECT max(dt) FROM account_category WHERE dt < {dt})
    """
    return db_manager.select(sql)


def compare_version(db_manager, dt_version, new_version, account_id_map):
    """Print how ``new_version`` differs from the previous stored version.

    Args:
        db_manager: database handle used to read the previous version.
        dt_version: dt of the new version; the previous one is the newest
            version strictly before it.
        new_version: records as built by ``main`` (dicts with ``gh_id`` and a
            JSON-encoded ``category_map``).
        account_id_map: ``gh_id -> account_name`` for the current data. An
            account that only exists in the previous version is not in this
            map, so its gh_id is printed instead of a name.
    """
    last_version = {
        gh_id: json.loads(category_map)
        for gh_id, category_map in get_last_version(db_manager, dt_version)
    }
    new_version = {
        entry['gh_id']: json.loads(entry['category_map'])
        for entry in new_version
    }

    # Sorted so repeated dry runs print in a stable order.
    for gh_id in sorted(new_version.keys() | last_version.keys()):
        account_name = account_id_map.get(gh_id, gh_id)
        if gh_id not in last_version:
            print(f"new account {account_name}: {new_version[gh_id]}")
        elif gh_id not in new_version:
            print(f"old account {account_name}: {last_version[gh_id]}")
        else:
            new_cates = new_version[gh_id]
            old_cates = last_version[gh_id]
            for cate in new_cates:
                if cate not in old_cates:
                    print(f"account {account_name} new cate: {cate} {new_cates[cate]}")
            for cate in old_cates:
                if cate not in new_cates:
                    print(f"account {account_name} old cate: {cate} {old_cates[cate]}")


def main():
    """Fit per-account regressions and write (or, with ``-n``, diff) them."""
    parser = ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true',
                        help='do not update database')
    parser.add_argument('--run-at', help='dt, also for version')
    args = parser.parse_args()

    run_date = datetime.today()
    if args.run_at:
        run_date = datetime.strptime(args.run_at, "%Y%m%d")

    begin_dt = DEFAULT_BEGIN_DT
    # Train on data up to yesterday (relative to run date); that dt is also
    # the version stamp of the written rows.
    end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
    dt_version = end_dt
    LOG.info(f"data range: {begin_dt} - {end_dt}")

    raw_df = prepare_raw_data(begin_dt, end_dt)
    cate_model = CategoryRegressionV1()
    df = cate_model.preprocess_data(raw_df)

    create_timestamp = int(time.time())
    update_timestamp = create_timestamp
    records_to_save = []
    param_to_category_map = reverse_category_name_map
    account_ids = df['gh_id'].unique()
    account_id_map = (
        df[['account_name', 'gh_id']]
        .drop_duplicates()
        .set_index('gh_id')['account_name']
        .to_dict()
    )
    account_negative_cates = {k: [] for k in account_ids}

    for account_id in account_ids:
        sub_df = df[df['gh_id'] == account_id]
        account_name = account_id_map[account_id]
        if len(sub_df) < MIN_SAMPLE_COUNT:
            continue

        print_error = False
        params, _t_stats, p_values = cate_model.run_ols_linear_regression(
            sub_df, print_error)
        param_names = cate_model.get_param_names()

        category_map = {}
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name, None)
            if cate_name is None:
                continue
            # Category relevance used for ranking: keep strong, significant
            # coefficients, clipped to a fixed range.
            if (abs(param) > CATEGORY_PARAM_THRESHOLD
                    and p_value < CATEGORY_P_VALUE_THRESHOLD):
                print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
                truncate_param = round(
                    max(min(param, CATEGORY_PARAM_MAX), CATEGORY_PARAM_MIN), 6)
                category_map[cate_name] = truncate_param
            # Negative categories intended for cold-start article allocation.
            if (param < NEGATIVE_PARAM_THRESHOLD
                    and p_value < NEGATIVE_P_VALUE_THRESHOLD):
                account_negative_cates[account_id].append(cate_name)

        if not category_map:
            continue
        # Key order kept identical to the original record layout.
        records_to_save.append({
            'dt': dt_version,
            'gh_id': account_id,
            'category_map': json.dumps(category_map, ensure_ascii=False),
            'status': 1,
            'create_timestamp': create_timestamp,
            'update_timestamp': update_timestamp,
        })

    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    if args.dry_run:
        compare_version(db_manager, dt_version, records_to_save, account_id_map)
        return

    if not records_to_save:
        # Without this guard clear_old_version would deactivate every active
        # row while nothing new replaces it.
        LOG.info(f"no records to save for {dt_version}, keep old version")
        return

    db_manager.batch_insert('account_category', records_to_save)
    clear_old_version(db_manager, dt_version)

    # Drop accounts without negative categories.
    # NOTE(review): this result is currently neither stored nor returned.
    account_negative_cates = {
        k: v for k, v in account_negative_cates.items() if v
    }


if __name__ == '__main__':
    main()