sunxy 11 months ago
Parent
Commit
df3ba4c999
4 changed files with 90 additions and 31 deletions
  1. ai_tag_task.py  +2 -2
  2. config.py  +1 -1
  3. result_save.py  +72 -0
  4. utils.py  +15 -28

+ 2 - 2
ai_tag_task.py

@@ -13,7 +13,7 @@ from whisper_asr import get_whisper_asr
 from gpt_tag import request_gpt
 from config import set_config
 from log import Log
-import mysql_connect
+from result_save import insert_content
 config_ = set_config()
 log_ = Log()
 features = ['videoid', 'title', 'video_path']
@@ -79,7 +79,7 @@ def get_video_ai_tags(video_id, asr_file, video_info):
                     log_message.update(parseRes)
 
                     # 6. Save the result
-                    mysql_connect.insert_content(parseRes)
+                    insert_content(parseRes)
 
                 except:
                     log_.error(traceback.format_exc())
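
For context, ai_tag_task.py now saves GPT parse results through the new result_save module instead of mysql_connect. A minimal sketch of the new call path, assuming parseRes carries a video_id plus the tag fields result_save expects (the payload below is illustrative, not from this repo):

    from result_save import insert_content
    # Hypothetical GPT parse result; field names come from result_save.py
    parseRes = {'video_id': '10001', 'key_words': ['health', 'wellness'], 'theme': 'wellness'}
    insert_content(parseRes)  # one row per tag into video_content_mapping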

+ 1 - 1
config.py

@@ -5,7 +5,7 @@ class BaseConfig(object):
     # Daily AI-tag video info table
     DAILY_VIDEO = {
         'project': 'loghubods',
-        'table': 'video_aigc_tag_yesterday_upload'
+        'table': 'vid_daily_top_not_taged'
     }
 
     # ODPS service configuration
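
The renamed table is presumably read via config_.DAILY_VIDEO elsewhere in the pipeline. A hedged sketch of a typical read, reusing get_feature_data from utils.py (the dt value is illustrative):

    from config import set_config
    from utils import get_feature_data
    config_ = set_config()
    # Pull the day's untagged top videos from the renamed ODPS table
    df = get_feature_data(project=config_.DAILY_VIDEO['project'],
                          table=config_.DAILY_VIDEO['table'],
                          dt='20240101',
                          features=['videoid', 'title', 'video_path'])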

+ 72 - 0
result_save.py

@@ -0,0 +1,72 @@
+import mysql.connector
+import json
+
+# Database connection parameters
+db_config = {
+    'host': 'rm-bp19uc56sud25ag4o.mysql.rds.aliyuncs.com',
+    'database': 'longvideo',
+    'port': 3306,
+    'user': 'wx2016_longvideo',
+    'password': 'wx2016_longvideoP@assword1234',
+}
+
+json_field_names = ['key_words', 'search_keys', 'extra_keys', 'category_list']
+
+normal_field_names = ['tone', 'target_audience',
+                      'target_age', 'target_gender', 'address', 'theme']
+
+
+def insert_content(gpt_res):
+    """ 连接MySQL数据库并插入一行数据 """
+    try:
+        # Connect to MySQL using the configured host, port and credentials
+        conn = mysql.connector.connect(
+            host=db_config['host'],
+            database=db_config['database'],
+            port=db_config['port'],
+            user=db_config['user'],
+            password=db_config['password'],
+        )
+
+        if conn.is_connected():
+            print('Connected to database')
+            cursor = conn.cursor()
+
+            insert_batch = []
+
+            # SQL statement for inserting tag rows
+            sql = """
+            INSERT INTO video_content_mapping (video_id, tag, tag_type)
+            VALUES (%s, %s, %s)
+            """
+
+            video_id = gpt_res.get('video_id', '')
+            for field_name in json_field_names:
+                # Get the tag list for this field
+                tags = gpt_res.get(field_name, '')
+                # Skip empty values and anything that is not a list
+                if not tags or not isinstance(tags, list):
+                    continue
+                # Queue one (video_id, tag, tag_type) row per tag
+                for tag in tags:
+                    insert_batch.append((video_id, tag, field_name))
+
+            for field_name in normal_field_names:
+                # Get the scalar field value
+                value = gpt_res.get(field_name, '')
+                # Queue a single row for this field
+                insert_batch.append((video_id, value, field_name))
+
+            # Execute the batch insert
+            cursor.executemany(sql, insert_batch)
+            print(f"Inserted {len(insert_batch)} rows")
+            insert_batch.clear()
+
+            # Commit the transaction
+            conn.commit()
+
+            # Close the cursor and connection
+            cursor.close()
+            conn.close()
+            print('Database connection closed')
+    except mysql.connector.Error as e:
+        print('Database connection or operation error:', e)
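
A quick usage sketch for the new module; the payload below is made up, but the field names come from json_field_names and normal_field_names above. List-valued fields fan out to one row per tag; scalar fields insert a single row each:

    sample = {
        'video_id': '10001',
        'key_words': ['cooking', 'family'],  # -> 2 rows, tag_type='key_words'
        'category_list': ['food'],           # -> 1 row, tag_type='category_list'
        'tone': 'warm',                      # -> 1 row, tag_type='tone'
    }
    insert_content(sample)

Note that scalar fields missing from the payload are still inserted with an empty-string tag; whether that is intended is worth a second look.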

+ 15 - 28
utils.py

@@ -11,8 +11,7 @@ log_ = Log()
 config_ = set_config()
 
 
-def get_data_from_odps(date, project, table, connect_timeout=3000, read_timeout=500000,
-                       pool_maxsize=1000, pool_connections=1000):
+def get_data_from_odps(date, project, table):
     """
     Fetch data from ODPS.
     :param date: date string, format '%Y%m%d'
@@ -28,11 +27,7 @@ def get_data_from_odps(date, project, table, connect_timeout=3000, read_timeout=
         access_id=config_.ODPS_CONFIG['ACCESSID'],
         secret_access_key=config_.ODPS_CONFIG['ACCESSKEY'],
         project=project,
-        endpoint=config_.ODPS_CONFIG['ENDPOINT'],
-        connect_timeout=connect_timeout,
-        read_timeout=read_timeout,
-        pool_maxsize=pool_maxsize,
-        pool_connections=pool_connections
+        endpoint=config_.ODPS_CONFIG['ENDPOINT']
     )
     records = odps.read_table(name=table, partition='dt=%s' % date)
     return records
@@ -51,8 +46,7 @@ def get_feature_data(project, table, dt, features):
     return feature_df
 
 
-def check_table_partition_exits(date, project, table, connect_timeout=3000, read_timeout=500000,
-                                pool_maxsize=1000, pool_connections=1000):
+def check_table_partition_exits(date, project, table):
     """
     Check whether the given partition exists in the table.
     :param date: date string, format '%Y%m%d'
@@ -68,11 +62,7 @@ def check_table_partition_exits(date, project, table, connect_timeout=3000, read
         access_id=config_.ODPS_CONFIG['ACCESSID'],
         secret_access_key=config_.ODPS_CONFIG['ACCESSKEY'],
         project=project,
-        endpoint=config_.ODPS_CONFIG['ENDPOINT'],
-        connect_timeout=connect_timeout,
-        read_timeout=read_timeout,
-        pool_maxsize=pool_maxsize,
-        pool_connections=pool_connections
+        endpoint=config_.ODPS_CONFIG['ENDPOINT']
     )
     t = odps.get_table(name=table)
     return t.exist_partition(partition_spec=f'dt={date}')
@@ -84,15 +74,12 @@ def data_check(project, table, dt):
         access_id=config_.ODPS_CONFIG['ACCESSID'],
         secret_access_key=config_.ODPS_CONFIG['ACCESSKEY'],
         project=project,
-        endpoint=config_.ODPS_CONFIG['ENDPOINT'],
-        connect_timeout=3000,
-        read_timeout=500000,
-        pool_maxsize=1000,
-        pool_connections=1000
+        endpoint=config_.ODPS_CONFIG['ENDPOINT']
     )
 
     try:
-        check_res = check_table_partition_exits(date=dt, project=project, table=table)
+        check_res = check_table_partition_exits(
+            date=dt, project=project, table=table)
         if check_res:
             sql = f'select * from {project}.{table} where dt = {dt}'
             with odps.execute_sql(sql=sql).open_reader() as reader:
@@ -113,7 +100,8 @@ def request_post(request_url, headers, request_data):
     :return: res_data json格式
     """
     try:
-        response = requests.post(url=request_url, json=request_data, headers=headers)
+        response = requests.post(
+            url=request_url, json=request_data, headers=headers)
         # print(response)
         if response.status_code == 200:
             res_data = json.loads(response.text)
@@ -121,7 +109,8 @@ def request_post(request_url, headers, request_data):
         else:
             return None
     except Exception as e:
-        log_.error('url: {}, exception: {}, traceback: {}'.format(request_url, e, traceback.format_exc()))
+        log_.error('url: {}, exception: {}, traceback: {}'.format(
+            request_url, e, traceback.format_exc()))
         return None
 
 
@@ -134,14 +123,16 @@ def request_get(request_url, headers, params=None):
     :return: res_data json格式
     """
     try:
-        response = requests.get(url=request_url, headers=headers, params=params)
+        response = requests.get(
+            url=request_url, headers=headers, params=params)
         if response.status_code == 200:
             res_data = json.loads(response.text)
             return res_data
         else:
             return None
     except Exception as e:
-        log_.error('url: {}, exception: {}, traceback: {}'.format(request_url, e, traceback.format_exc()))
+        log_.error('url: {}, exception: {}, traceback: {}'.format(
+            request_url, e, traceback.format_exc()))
         return None
 
 
@@ -173,9 +164,6 @@ def asr_validity_discrimination(text):
     return True
 
 
-
-
-
 if __name__ == '__main__':
     text = """现场和电视机前的观众朋友,大家晚上好。
 这里是非常说明的访谈现场,
@@ -913,4 +901,3 @@ Haha哈哈那个。
 他还是。"""
     res = asr_validity_discrimination(text=text)
     print(res)
-
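
The utils.py changes mostly drop the tuned timeout/pool arguments in favor of the pyodps defaults and reflow long lines; behavior is otherwise unchanged. A hedged end-to-end sketch of the simplified helpers (project and table names come from config.py; the date, URL and headers are placeholders):

    records = get_data_from_odps(date='20240101', project='loghubods',
                                 table='vid_daily_top_not_taged')
    for record in records:
        print(record['videoid'], record['title'])

    res = request_post(request_url='http://example.com/api/tag',
                       headers={'Content-Type': 'application/json'},
                       request_data={'video_id': '10001'})
    if res is None:
        print('request failed; traceback is in the log')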