Quellcode durchsuchen

Update run_category_model_v1: compare model results

StrayWarrior vor 4 Monaten
Ursprung
Commit
52c3e48505
1 geänderte Dateien mit 40 neuen und 7 gelöschten Zeilen
  1. 40 7
      run_category_model_v1.py

+ 40 - 7
run_category_model_v1.py

@@ -15,6 +15,7 @@ from datetime import datetime, timedelta
 import pandas as pd
 from argparse import ArgumentParser
 from long_articles.category_models import CategoryRegressionV1
+from long_articles.consts import reverse_category_name_map
 from common.database import MySQLManager
 from common import db_operation
 from common.logging import LOG
@@ -49,6 +50,37 @@ def clear_old_version(db_manager, dt):
     rows = db_manager.execute(sql)
     print(f"updated rows: {rows}")
 
+def get_last_version(db_manager, dt):
+    sql = f"""
+        SELECT gh_id, category_map
+        FROM account_category
+        WHERE dt = (SELECT max(dt) FROM account_category WHERE dt < {dt})
+    """
+    data = db_manager.select(sql)
+    return data
+
+def compare_version(db_manager, dt_version, new_version, account_id_map):
+    last_version = get_last_version(db_manager, dt_version)
+    last_version = { entry[0]: json.loads(entry[1]) for entry in last_version }
+    new_version = { entry['gh_id']: json.loads(entry['category_map']) for entry in new_version }
+    # new record
+    all_gh_ids = set(list(new_version.keys()) + list(last_version.keys()))
+    for gh_id in all_gh_ids:
+        account_name = account_id_map[gh_id]
+        if gh_id not in last_version:
+            print(f"new account {account_name}: {new_version[gh_id]}")
+        elif gh_id not in new_version:
+            print(f"old account {account_name}: {last_version[gh_id]}")
+        else:
+            new_cates = new_version[gh_id]
+            old_cates = last_version[gh_id]
+            for cate in new_cates:
+                if cate not in old_cates:
+                    print(f"account {account_name} new cate: {cate} {new_cates[cate]}")
+            for cate in old_cates:
+                if cate not in new_cates:
+                    print(f"account {account_name} old cate: {cate} {old_cates[cate]}")
+
 
 def main():
     parser = ArgumentParser()
@@ -78,19 +110,20 @@ def main():
 
     records_to_save = []
 
-    param_to_category_map = cate_model.reverse_category_name_map
+    param_to_category_map = reverse_category_name_map
     account_ids = df['gh_id'].unique()
     account_id_map = df[['account_name', 'gh_id']].drop_duplicates() \
         .set_index('gh_id')['account_name'].to_dict()
 
     account_negative_cates = {k: [] for k in account_ids}
     for account_id in account_ids:
-        sub_df = df[df['gh_id'] == account_id]  
+        sub_df = df[df['gh_id'] == account_id]
         account_name = account_id_map[account_id]
         sample_count = len(sub_df)
         if sample_count < 5:
             continue
-        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
+        print_error = False
+        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df, print_error)
         current_record = {}
         current_record['dt'] = dt_version
         current_record['gh_id'] = account_id
@@ -113,13 +146,13 @@ def main():
         current_record['status'] = 1
         current_record['create_timestamp'] = create_timestamp
         current_record['update_timestamp'] = update_timestamp
-        records_to_save.append(current_record) 
+        records_to_save.append(current_record)
+
+    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
     if args.dry_run:
-        for record in records_to_save:
-            print(record)
+        compare_version(db_manager, dt_version, records_to_save, account_id_map)
         return
 
-    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
     db_manager.batch_insert('account_category', records_to_save)
     clear_old_version(db_manager, dt_version)