Browse Source

上传读取标题功能

罗俊辉 1 year ago
parent
commit
2858e64fd1
1 changed file with 15 additions and 10 deletions
  1. 15 10
      process_data.py

+ 15 - 10
process_data.py

@@ -22,7 +22,7 @@ class DataProcessor(object):
         self.client = MysqlClient()
         self.flag = flag
 
-    def generate_train_label(self,item, y_ori_data, cate):
+    def generate_train_label(self, item, y_ori_data, cate):
         """
         生成训练数据,用 np.array矩阵的方式返回,
         :return: x_train, 训练数据, y_train, 训练 label
@@ -48,8 +48,12 @@ class DataProcessor(object):
         ]
         item_features = [item[i] for i in userful_features]
         keywords_textrank, keywords_tf = self.title_processor(video_id)
-        item_features.append(",".join(keywords_textrank))
-        item_features.append(",".join(keywords_tf))
+        if keywords_tf and keywords_textrank:
+            item_features.append(",".join(keywords_textrank))
+            item_features.append(",".join(keywords_tf))
+        else:
+            item_features.append(None)
+            item_features.append(None)
         label_dt = generate_label_date(dt)
         label_obj = y_ori_data.get(label_dt, {}).get(video_id)
         if label_obj:
@@ -65,10 +69,14 @@ class DataProcessor(object):
         :return: tag_list [tag, tag, tag, tag......]
         """
         sql = f"""SELECT video_title from crawler_video where video_id = {video_id};"""
-        title = self.client.select(sql)[0][0]
-        keywords_textrank = jieba.analyse.textrank(title, topK=3)
-        keywords_tfidf = jieba.analyse.extract_tags(title, topK=3)
-        return list(keywords_textrank), list(keywords_tfidf)
+        try:
+            title = self.client.select(sql)
+            keywords_textrank = jieba.analyse.textrank(title, topK=3)
+            keywords_tfidf = jieba.analyse.extract_tags(title, topK=3)
+            return list(keywords_textrank), list(keywords_tfidf)
+        except Exception as e:
+            print(video_id, "\t", e)
+            return [], []
 
     def producer(self):
         """
@@ -105,6 +113,3 @@ class DataProcessor(object):
 if __name__ == "__main__":
     D = DataProcessor(flag="train")
     D.producer()
-
-
-