浏览代码

Add category models

StrayWarrior 5 月之前
父节点
当前提交
6445c195c1
共有 5 个文件被更改,包括 327 次插入0 次删除
  1. 181 0
      run_category_model_v1.py
  2. 21 0
      src/common/db_operation.py
  3. 16 0
      src/common/logging.py
  4. 0 0
      src/long_articles/__init__.py
  5. 109 0
      src/long_articles/category_models.py

+ 181 - 0
run_category_model_v1.py

@@ -0,0 +1,181 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2024 StrayWarrior <i@straywarrior.com>
+
+
+import sys
+import os
+sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
+
+import time
+import json
+from datetime import datetime, timedelta
+import pandas as pd
+from argparse import ArgumentParser
+from long_articles.category_models import CategoryRegressionV1
+from common.database import MySQLManager
+from common import db_operation
+from common.logging import LOG
+from config.dev import Config
+
+
def prepare_raw_data(dt_begin, dt_end):
    """Load scored long-article stats from MySQL into a DataFrame.

    Pulls rows from `datastat_score` within [dt_begin, dt_end] that have a
    category, positive similarity, enough reads, and a sane read-rate, then
    drops duplicate (dt, gh_id, title) rows.

    Args:
        dt_begin: inclusive start date key (e.g. 20240914).
        dt_end: inclusive end date key.

    Returns:
        pandas.DataFrame with one row per (dt, gh_id, title).
    """
    columns = ['dt', 'gh_id', 'account_name', 'title', 'similarity',
               'view_count_rate', 'category', 'read_avg',
               'read_avg_rate']
    db = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    query = f"""
        SELECT {','.join(columns)} FROM datastat_score WHERE dt BETWEEN {dt_begin} AND {dt_end}
            AND similarity > 0 AND category IS NOT NULL AND read_avg > 500
            AND read_avg_rate BETWEEN 0 AND 3
            AND `index` = 1
        """
    result = pd.DataFrame(db.select(query), columns=columns)
    return result.drop_duplicates(['dt', 'gh_id', 'title'])
+
def run_once(dt):
    """One-off variant of main(): fit per-account category regressions from a
    local Excel sample dump instead of the database, and print (but do not
    persist) the resulting records and negative-category lists.

    Args:
        dt: date key stamped onto every output record as its `dt` field.
    """
    df = pd.read_excel('src/long_articles/20241101_read_rate_samples.xlsx')
    # Map the Chinese source columns onto the canonical names used below.
    df['read_avg'] = df['阅读均值']
    df['read_avg_rate'] = df['阅读倍数']
    df['dt'] = df['日期']
    df['similarity'] = df['Similarity']
    # Keep sufficiently-read, recent articles with a sane read-rate range.
    filter_condition = 'read_avg > 500 ' \
        'and read_avg_rate > 0 and read_avg_rate < 3 ' \
        'and dt > 20240914 and similarity > 0' 
    df = df.query(filter_condition).copy()
    #df = pd.read_excel('20241112-new-account-samples.xlsx')

    cate_model = CategoryRegressionV1()

    create_timestamp = int(time.time())
    update_timestamp = create_timestamp

    records_to_save = []
    df = cate_model.preprocess_data(df)

    param_to_category_map = cate_model.reverse_category_name_map
    # NOTE(review): this Excel dump uses 'ghID'/'账号名称' column names,
    # unlike the DB path in main() which uses 'gh_id'/'account_name'.
    account_ids = df['ghID'].unique()
    account_id_map = df[['账号名称', 'ghID']].drop_duplicates().set_index('ghID')['账号名称'].to_dict()

    account_negative_cates = {k: [] for k in account_ids}
    for account_id in account_ids:
        sub_df = df[df['ghID'] == account_id]  
        account_name = account_id_map[account_id]
        sample_count = len(sub_df)
        # Too few samples for a meaningful regression.
        if sample_count < 5:
            continue
        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
        current_record = {}
        current_record['dt'] = dt
        current_record['gh_id'] = account_id
        current_record['category_map'] = {}
        param_names = cate_model.get_param_names()
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name, None)
            # Strong, significant effect (either sign): keep for ranking.
            if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
                #print(f"{account_id} {cate_name} {param:.3f} {p_value:.3f}")
                current_record['category_map'][cate_name] = round(param, 6)
            # Looser p-value threshold for negative (cold-start) categories.
            if param < -0.1 and cate_name is not None and p_value < 0.3:
                account_negative_cates[account_id].append(cate_name)
                print((account_name, cate_name, param, p_value))
        current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
        current_record['status'] = 1
        current_record['create_timestamp'] = create_timestamp
        current_record['update_timestamp'] = update_timestamp
        records_to_save.append(current_record) 
    # NOTE(review): the insert is disabled here — db_manager is constructed
    # but records_to_save is never persisted by this function.
    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    #db_manager.batch_insert('account_category', records_to_save)

    # Drop accounts that ended up with no negative categories.
    for account_id in [*account_negative_cates.keys()]:
        if not account_negative_cates[account_id]:
            account_negative_cates.pop(account_id)

    print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
    for k, v in account_negative_cates.items():
        print('{}\t{}'.format(k, ','.join(v)))
+
+
def main():
    """CLI entry point: fit a per-account category regression over data in
    [begin_dt, end_dt] and write significant category coefficients to the
    `account_category` table (printed instead when --dry-run is set).
    """
    parser = ArgumentParser()
    parser.add_argument('-n', '--dry-run', action='store_true', help='do not update database')
    parser.add_argument('--run-at', help='dt, also for version')
    args = parser.parse_args()

    run_date = datetime.today()
    if args.run_at:
        run_date = datetime.strptime(args.run_at, "%Y%m%d")
    # Fixed start of the training window; the end is the day before run_date.
    begin_dt = 20240914
    end_dt = (run_date - timedelta(1)).strftime("%Y%m%d")
    # end_dt doubles as the version stamp (dt) of the saved records.
    dt_version = end_dt
    LOG.info(f"data range: {begin_dt} - {end_dt}")

    raw_df = prepare_raw_data(begin_dt, end_dt)

    cate_model = CategoryRegressionV1()
    df = cate_model.preprocess_data(raw_df)

    # Dry-run also prints the full per-account coefficient report.
    if args.dry_run:
        cate_model.build(df)

    create_timestamp = int(time.time())
    update_timestamp = create_timestamp

    records_to_save = []

    param_to_category_map = cate_model.reverse_category_name_map
    account_ids = df['gh_id'].unique()
    account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
        .set_index('gh_id')['account_name'].to_dict()

    account_negative_cates = {k: [] for k in account_ids}
    for account_id in account_ids:
        sub_df = df[df['gh_id'] == account_id]  
        account_name = account_id_map[account_id]
        sample_count = len(sub_df)
        # Too few samples for a meaningful regression.
        if sample_count < 5:
            continue
        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
        current_record = {}
        current_record['dt'] = dt_version
        current_record['gh_id'] = account_id
        current_record['category_map'] = {}
        param_names = cate_model.get_param_names()
        for name, param, p_value in zip(param_names, params, p_values):
            cate_name = param_to_category_map.get(name, None)
            # Category relevance coefficients used for ranking.
            if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
                print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
                current_record['category_map'][cate_name] = round(param, 6)
            # Negative categories used for cold-start article assignment.
            if param < -0.1 and cate_name is not None and p_value < 0.3:
                account_negative_cates[account_id].append(cate_name)
                # print((account_name, cate_name, param, p_value))
        # Skip accounts with no significant category at all.
        if not current_record['category_map']:
            continue
        current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
        current_record['status'] = 1
        current_record['create_timestamp'] = create_timestamp
        current_record['update_timestamp'] = update_timestamp
        records_to_save.append(current_record) 
    if args.dry_run:
        for record in records_to_save:
            print(record)
        return

    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
    db_manager.batch_insert('account_category', records_to_save)

    # Filter out accounts with no negative categories.
    # NOTE(review): account_negative_cates is computed but unused past this
    # point — the print below is disabled.
    for account_id in [*account_negative_cates.keys()]:
        if not account_negative_cates[account_id]:
            account_negative_cates.pop(account_id)

    # print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))

if __name__ == '__main__':
    main()

+ 21 - 0
src/common/db_operation.py

@@ -0,0 +1,21 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2024 StrayWarrior <i@straywarrior.com>
+
+"""
+Common database operations
+"""
+
def delete_old_version_logically(db_manger, table, dt_version, timestamp_sec,
                                 old_status, new_status):
    """Logically delete stale rows by flipping their status flag.

    Marks every row in `table` whose status equals `old_status`, whose `dt`
    is at or before `dt_version`, and which was created before
    `timestamp_sec`, with `new_status`. Nothing is physically removed.

    Args:
        db_manger: database manager exposing execute(sql).
        table: target table name.
        dt_version: newest dt (inclusive) to retire.
        timestamp_sec: epoch-second cutoff on create_timestamp.
        old_status: status value to match.
        new_status: status value to write.

    Returns:
        Whatever db_manger.execute returns (presumably the affected row
        count — verify against the manager implementation).

    NOTE(review): all values are interpolated directly into the SQL string;
    callers must pass trusted values only — TODO confirm or parameterize.
    """
    sql = f"""
        UPDATE {table} SET status = {new_status}
        WHERE status = {old_status}
            AND dt <= {dt_version}
            AND create_timestamp < {timestamp_sec}
    """
    rows = db_manger.execute(sql)
    return rows
+

+ 16 - 0
src/common/logging.py

@@ -0,0 +1,16 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2024 StrayWarrior <i@straywarrior.com>
+
import logging

# Shared application-wide logger: importing this module configures the root
# logger once with an INFO-level stream handler.
LOG = logging.getLogger()

LOG.setLevel(logging.INFO)
_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
# `Logger` objects have no native `formatter` attribute (formatters belong to
# handlers); the original stored it here ad hoc. Keep the attribute for
# backward compatibility with any caller that reads LOG.formatter.
LOG.formatter = _formatter
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(_formatter)
# Avoid stacking duplicate handlers if this module is imported/reloaded after
# the root logger has already been configured.
if not LOG.handlers:
    LOG.addHandler(stream_handler)

+ 0 - 0
src/long_articles/__init__.py


+ 109 - 0
src/long_articles/category_models.py

@@ -0,0 +1,109 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2024 StrayWarrior <i@straywarrior.com>
+
+"""
+Models for long article categories.
+"""
+
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import LogisticRegression, LinearRegression
+from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
+from sklearn.metrics import mean_squared_error, r2_score
+import statsmodels.api as sm
+
class CategoryRegressionV1:
    """Per-account OLS regression of read rate on article category.

    Each known category becomes a one-hot indicator feature; the fitted
    coefficients estimate how much a category shifts an account's
    `read_avg_rate` relative to its baseline (the intercept).
    """

    def __init__(self):
        # Feature columns fed to the regression (one-hot category flags).
        self.features = [
            #'ViewCountRate',
            'CateOddities', 'CateFamily', 'CateHeartwarm',
            'CateHistory', 'CateHealth', 'CateLifeKnowledge', 'CateGossip',
            'CatePolitics', 'CateMilitary'
        ]
        # Chinese category label -> feature column name.
        self.category_name_map = {
            '奇闻趣事': 'CateOddities',
            '历史人物': 'CateHistory',
            '家长里短': 'CateFamily',
            '温情故事': 'CateHeartwarm',
            '健康养生': 'CateHealth',
            '生活知识': 'CateLifeKnowledge',
            '名人八卦': 'CateGossip',
            '政治新闻': 'CatePolitics',
            '军事新闻': 'CateMilitary',
        }
        # Feature column name -> Chinese category label.
        self.reverse_category_name_map = {
            v: k for k, v in self.category_name_map.items()
        }

    def preprocess_data(self, df):
        """Add one-hot category columns and model targets, in place.

        Args:
            df: DataFrame with at least `category` and `read_avg_rate`.

        Returns:
            The same DataFrame, extended with a 0/1 column per known
            category, a boolean `ClassY` (read rate above 1x baseline) and
            a numeric `RegressionY` target.
        """
        for cate, colname in self.category_name_map.items():
            df[colname] = (df['category'] == cate).astype(int)

        df['ClassY'] = df['read_avg_rate'] > 1
        df['RegressionY'] = df['read_avg_rate']
        return df

    def build_and_print(self, df, account_name):
        """Fit one account (or the whole df when account_name is None) and
        print a tab-separated row of coefficient / p-value pairs.

        Accounts with fewer than 5 samples are skipped silently.
        """
        if account_name is not None:
            sub_df = df[df['account_name'] == account_name]
        else:
            sub_df = df
        if len(sub_df) < 5:
            return
        sample_count = len(sub_df)
        params, t_stats, p_values = self.run_ols_linear_regression(sub_df)
        row = f'{account_name}\t{sample_count}'
        for param, p_value in zip(params, p_values):
            row += f'\t{param:.3f}\t{p_value:.3f}'
        print(row)

    def build(self, df):
        """Print a coefficient report: header, global fit, then one row per
        account found in df['account_name']."""
        p_value_column_names = '\t'.join([name + "\tp-" + name for name in
                                          ['bias'] + self.features])
        print('account\tsamples\t{}'.format(p_value_column_names))
        self.build_and_print(df, None)
        for account_name in df['account_name'].unique():
            self.build_and_print(df, account_name)

    def get_param_names(self):
        """Return coefficient names aligned with the fitted parameter
        vector: the intercept ('bias') first, then the features."""
        return ['bias'] + self.features

    def run_ols_linear_regression(self, df):
        """Fit OLS of RegressionY on the category features.

        Returns:
            Tuple (params, t_stats, p_values) of statsmodels Series, indexed
            by the constant term followed by the feature names.
        """
        X = sm.add_constant(df[self.features])  # features + intercept column
        y = df['RegressionY']

        model = sm.OLS(y, X).fit()
        # The original also computed model.conf_int() into an unused local;
        # dropped as dead code.
        return model.params, model.tvalues, model.pvalues
+
def main():
    """Ad-hoc entry point: fit per-account category models on a local
    sample workbook and print the coefficient report."""
    samples = pd.read_excel('20241101_read_rate_samples.xlsx')
    # Map the Chinese source columns onto the names the model expects.
    samples['read_avg'] = samples['阅读均值']
    samples['read_avg_rate'] = samples['阅读倍数']
    samples['dt'] = samples['日期']
    samples['similarity'] = samples['Similarity']
    # Keep sufficiently-read, recent articles with a sane read-rate range.
    keep = ('read_avg > 500 '
            'and read_avg_rate > 0 and read_avg_rate < 3 '
            'and dt > 20240914 and similarity > 0')
    samples = samples.query(keep).copy()

    model = CategoryRegressionV1()

    samples = model.preprocess_data(samples)
    model.build(samples)


if __name__ == '__main__':
    main()