瀏覽代碼

Update category_models: support print error for analysis

StrayWarrior 4 月之前
父節點
當前提交
d60d3f41ec
共有 1 個文件被更改,包括 27 次插入和 20 次删除
  1. 27 20
      src/long_articles/category_models.py

+ 27 - 20
src/long_articles/category_models.py

@@ -15,33 +15,20 @@ from sklearn.linear_model import LogisticRegression, LinearRegression
 from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
 from sklearn.metrics import mean_squared_error, r2_score
 import statsmodels.api as sm
+from .consts import category_name_map, reverse_category_name_map
 
 class CategoryRegressionV1:
     def __init__(self):
         self.features = [
-            #'ViewCountRate',
             'CateOddities', 'CateFamily', 'CateHeartwarm',
             'CateHistory', 'CateHealth', 'CateLifeKnowledge', 'CateGossip',
-            'CatePolitics', 'CateMilitary'
+            'CatePolitics', 'CateMilitary', 'CateMovie', 'CateSociety',
+            'view_count_rate'
         ]
-        self.category_name_map = {
-            '奇闻趣事': 'CateOddities',
-            '历史人物': 'CateHistory',
-            '家长里短': 'CateFamily',
-            '温情故事': 'CateHeartwarm',
-            '健康养生': 'CateHealth',
-            '生活知识': 'CateLifeKnowledge',
-            '名人八卦': 'CateGossip',
-            '政治新闻': 'CatePolitics',
-            '军事新闻': 'CateMilitary',
-        }
-        self.reverse_category_name_map = {
-            v: k for k, v in self.category_name_map.items()
-        }
 
     def preprocess_data(self, df):
-        for cate in self.category_name_map:
-            colname = self.category_name_map[cate]
+        for cate in category_name_map:
+            colname = category_name_map[cate]
             df[colname] = df['category'] == cate
             df[colname] = df[colname].astype(int)
 
@@ -51,7 +38,7 @@ class CategoryRegressionV1:
 
     def build_and_print(self, df, account_name):
         if account_name is not None:
-            sub_df = df[df['account_name'] == account_name]  
+            sub_df = df[df['account_name'] == account_name]
         else:
             sub_df = df
         if len(sub_df) < 5:
@@ -74,7 +61,7 @@ class CategoryRegressionV1:
     def get_param_names(self):
         return ['bias'] + self.features
 
-    def run_ols_linear_regression(self, df):
+    def run_ols_linear_regression(self, df, print_residual=False):
         X = df[self.features]  # 特征列
         y = df['RegressionY']  # 目标变量
         X = sm.add_constant(X)
@@ -86,6 +73,26 @@ class CategoryRegressionV1:
         p_values = model.pvalues
         conf_int = model.conf_int()
 
+        if print_residual:
+            predict_y = model.predict(X)
+            residuals = y - predict_y
+            new_x = df[['title', 'category']].copy()
+            new_x['residual'] = residuals
+            new_x['y'] = y
+            for index, row in new_x.iterrows():
+                param_name = category_name_map.get(row['category'], None)
+                if not param_name:
+                    continue
+                param_index = self.features.index(param_name) + 1
+                param = params[param_index]
+                p_value = p_values[param_index]
+                if p_value < 0.1:
+                    print(f"{row['y']:.3f}\t{row['residual']:.3f}\t{row['category']}\t{param:.2f}\t{row['title']}")
+            r_min = residuals.min()
+            r_max = residuals.max()
+            r_avg = residuals.mean()
+            print(f"residuals min: {r_min:.3f}, max: {r_max:.3f}, mean: {r_avg:.3f}")
+
         return params, t_stats, p_values
 
 def main():