Browse Source

Update run_category_model_v1

StrayWarrior 4 tháng trước cách đây
mục cha
commit
3a4a4b8719
1 tập tin đã thay đổi với 14 bổ sung61 xóa
  1. 14 61
      run_category_model_v1.py

+ 14 - 61
run_category_model_v1.py

@@ -38,66 +38,16 @@ def prepare_raw_data(dt_begin, dt_end):
     df = df.drop_duplicates(['dt', 'gh_id', 'title'])
     return df
 
-def run_once(dt):
-    df = pd.read_excel('src/long_articles/20241101_read_rate_samples.xlsx')
-    df['read_avg'] = df['阅读均值']
-    df['read_avg_rate'] = df['阅读倍数']
-    df['dt'] = df['日期']
-    df['similarity'] = df['Similarity']
-    filter_condition = 'read_avg > 500 ' \
-        'and read_avg_rate > 0 and read_avg_rate < 3 ' \
-        'and dt > 20240914 and similarity > 0' 
-    df = df.query(filter_condition).copy()
-    #df = pd.read_excel('20241112-new-account-samples.xlsx')
 
-    cate_model = CategoryRegressionV1()
-
-    create_timestamp = int(time.time())
-    update_timestamp = create_timestamp
-
-    records_to_save = []
-    df = cate_model.preprocess_data(df)
-
-    param_to_category_map = cate_model.reverse_category_name_map
-    account_ids = df['ghID'].unique()
-    account_id_map = df[['账号名称', 'ghID']].drop_duplicates().set_index('ghID')['账号名称'].to_dict()
-
-    account_negative_cates = {k: [] for k in account_ids}
-    for account_id in account_ids:
-        sub_df = df[df['ghID'] == account_id]  
-        account_name = account_id_map[account_id]
-        sample_count = len(sub_df)
-        if sample_count < 5:
-            continue
-        params, t_stats, p_values = cate_model.run_ols_linear_regression(sub_df)
-        current_record = {}
-        current_record['dt'] = dt
-        current_record['gh_id'] = account_id
-        current_record['category_map'] = {}
-        param_names = cate_model.get_param_names()
-        for name, param, p_value in zip(param_names, params, p_values):
-            cate_name = param_to_category_map.get(name, None)
-            if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
-                #print(f"{account_id} {cate_name} {param:.3f} {p_value:.3f}")
-                current_record['category_map'][cate_name] = round(param, 6)
-            if param < -0.1 and cate_name is not None and p_value < 0.3:
-                account_negative_cates[account_id].append(cate_name)
-                print((account_name, cate_name, param, p_value))
-        current_record['category_map'] = json.dumps(current_record['category_map'], ensure_ascii=False)
-        current_record['status'] = 1
-        current_record['create_timestamp'] = create_timestamp
-        current_record['update_timestamp'] = update_timestamp
-        records_to_save.append(current_record) 
-    db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
-    #db_manager.batch_insert('account_category', records_to_save)
-
-    for account_id in [*account_negative_cates.keys()]:
-        if not account_negative_cates[account_id]:
-            account_negative_cates.pop(account_id)
-
-    print(json.dumps(account_negative_cates, ensure_ascii=False, indent=2))
-    for k, v in account_negative_cates.items():
-        print('{}\t{}'.format(k, ','.join(v)))
+def clear_old_version(db_manager, dt):
+    update_timestamp = int(time.time())
+    sql = f"""
+        UPDATE account_category
+        SET status = 0, update_timestamp = {update_timestamp}
+        WHERE dt < {dt} and status = 1
+    """
+    rows = db_manager.execute(sql)
+    print(f"updated rows: {rows}")
 
 
 def main():
@@ -119,8 +69,9 @@ def main():
     cate_model = CategoryRegressionV1()
     df = cate_model.preprocess_data(raw_df)
 
-    if args.dry_run:
+    if args.dry_run and False:
         cate_model.build(df)
+        return
 
     create_timestamp = int(time.time())
     update_timestamp = create_timestamp
@@ -150,7 +101,8 @@ def main():
             # 用于排序的品类相关性
             if abs(param) > 0.1 and p_value < 0.1 and cate_name is not None:
                 print(f"{account_id} {account_name} {cate_name} {param:.3f} {p_value:.3f}")
-                current_record['category_map'][cate_name] = round(param, 6)
+                truncate_param = round(max(min(param, 0.25), -0.3), 6)
+                current_record['category_map'][cate_name] = truncate_param
             # 用于冷启文章分配的负向品类
             if param < -0.1 and cate_name is not None and p_value < 0.3:
                 account_negative_cates[account_id].append(cate_name)
@@ -169,6 +121,7 @@ def main():
 
     db_manager = MySQLManager(Config().MYSQL_LONG_ARTICLES)
     db_manager.batch_insert('account_category', records_to_save)
+    clear_old_version(db_manager, dt_version)
 
     # 过滤空账号
     for account_id in [*account_negative_cates.keys()]: