Ver Fonte

ai_tag_task

sunxy há 1 ano atrás
pai
commit
83b4ac3130
4 ficheiros alterados com 114 adições e 23 exclusões
  1. 70 14
      ai_tag_task.py
  2. 1 1
      config.py
  3. 12 8
      gpt_tag.py
  4. 31 0
      moon_shoot_api.py

+ 70 - 14
ai_tag_task.py

@@ -6,6 +6,7 @@ import time
 import traceback
 import requests
 import multiprocessing
+import ODPSQueryUtil
 from threading import Timer
 from utils import data_check, get_feature_data, asr_validity_discrimination
 from whisper_asr import get_whisper_asr
@@ -67,7 +68,8 @@ def get_video_ai_tags(video_id, asr_file, video_info):
                         try:
                             for item in json.loads(gpt_res2):
                                 if item['confidence'] > 0.5 and item['category'] in config_.TAGS_NEW:
-                                    confidence_up_list.append(f"AI标签-{item['category']}")
+                                    confidence_up_list.append(
+                                        f"AI标签-{item['category']}")
                         except:
                             pass
                         confidence_up = ','.join(confidence_up_list)
@@ -78,7 +80,8 @@ def get_video_ai_tags(video_id, asr_file, video_info):
                                                      json={'videoId': int(video_id), 'tagNames': confidence_up})
                             res_data = json.loads(response.text)
                             if res_data['code'] != 0:
-                                log_.error({'videoId': video_id, 'msg': 'add video ai tags fail!'})
+                                log_.error(
+                                    {'videoId': video_id, 'msg': 'add video ai tags fail!'})
                 except:
                     pass
         else:
@@ -99,7 +102,8 @@ def process(video_id, video_info, download_folder):
             video_type = filename.split('.')[-1]
             if video_type in ['mp4', 'm3u8']:
                 video_file = os.path.join(video_folder, filename)
-                get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
+                get_video_ai_tags(
+                    video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
                 # 将处理过的视频进行删除
                 shutil.rmtree(os.path.join(download_folder, video_id))
             else:
@@ -108,12 +112,15 @@ def process(video_id, video_info, download_folder):
 
 def ai_tags(project, table, dt):
     # 获取特征数据
-    feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
+    feature_df = get_feature_data(
+        project=project, table=table, dt=dt, features=features)
     video_id_list = feature_df['videoid'].to_list()
     video_info = {}
     for video_id in video_id_list:
-        title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
-        video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
+        title = feature_df[feature_df['videoid']
+                           == video_id]['title'].values[0]
+        video_path = feature_df[feature_df['videoid']
+                                == video_id]['video_path'].values[0]
         if title is None:
             continue
         title = title.strip()
@@ -151,7 +158,8 @@ def ai_tags(project, table, dt):
                     video_type = filename.split('.')[-1]
                     if video_type in ['mp4', 'm3u8']:
                         video_file = os.path.join(video_folder, filename)
-                        get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
+                        get_video_ai_tags(
+                            video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
                         # 将处理过的视频进行删除
                         shutil.rmtree(os.path.join(download_folder, video_id))
                     else:
@@ -160,12 +168,15 @@ def ai_tags(project, table, dt):
 
 def ai_tags_new(project, table, dt):
     # 获取特征数据
-    feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
+    feature_df = get_feature_data(
+        project=project, table=table, dt=dt, features=features)
     video_id_list = feature_df['videoid'].to_list()
     video_info = {}
     for video_id in video_id_list:
-        title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
-        video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
+        title = feature_df[feature_df['videoid']
+                           == video_id]['title'].values[0]
+        video_path = feature_df[feature_df['videoid']
+                                == video_id]['video_path'].values[0]
         if title is None:
             continue
         title = title.strip()
@@ -191,7 +202,8 @@ def ai_tags_new(project, table, dt):
             if video_info.get(video_id, None) is None:
                 os.remove(asr_file)
             else:
-                get_video_ai_tags(video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
+                get_video_ai_tags(
+                    video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
                 os.remove(asr_file)
 
 
@@ -201,7 +213,8 @@ def timer_check():
         table = config_.DAILY_VIDEO['table']
         now_date = datetime.datetime.today()
         print(f"now_date: {datetime.datetime.strftime(now_date, '%Y%m%d')}")
-        dt = datetime.datetime.strftime(now_date-datetime.timedelta(days=1), '%Y%m%d')
+        dt = datetime.datetime.strftime(
+            now_date-datetime.timedelta(days=1), '%Y%m%d')
         # 查看数据是否已准备好
         data_count = data_check(project=project, table=table, dt=dt)
         if data_count > 0:
@@ -219,8 +232,51 @@ def timer_check():
             # 数据没准备好,1分钟后重新检查
             Timer(60, timer_check).start()
     except Exception as e:
-        print(f"视频ai打标签失败, exception: {e}, traceback: {traceback.format_exc()}")
+        print(
+            f"视频ai打标签失败, exception: {e}, traceback: {traceback.format_exc()}")
 
 
 if __name__ == '__main__':
-    timer_check()
+    # timer_check()
+    size = 500
+    for i in range(0, 500, size):
+        print(f"query_videos start i = {i} ...")
+        records = ODPSQueryUtil.query_videos(i, size)
+        if records is None or len(records) == 0:
+            continue
+        print(f"Got {len(records)} records")
+
+        video_info = {}
+        # 遍历 records,将每个视频的信息添加到字典中
+        for record in records:
+            # 将 video_id 从字符串转换为整数,这里假设 video_id 格式总是 "vid" 后跟数字
+            video_id = int(record['videoid'])
+            title = record['title']
+            video_path = record['video_path']
+
+            # 使用 video_id 作为键,其他信息作为值
+            video_info[video_id] = {'title': title, 'video_path': video_path}
+
+        # 打印结果查看
+        print(video_info)
+
+    asr_folder = 'asr_res'
+    retry = 0
+    while retry < 30:
+        asr_file_list = os.listdir(asr_folder)
+        if len(asr_file_list) < 1:
+            retry += 1
+            time.sleep(60)
+            continue
+        retry = 0
+        for asr_filename in asr_file_list:
+            video_id = int(asr_filename[:-4])
+            if video_id not in video_info:
+                continue
+            asr_file = os.path.join(asr_folder, asr_filename)
+            if video_info.get(video_id, None) is None:
+                os.remove(asr_file)
+            else:
+                get_video_ai_tags(
+                    video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
+                os.remove(asr_file)

+ 1 - 1
config.py

@@ -101,7 +101,7 @@ class BaseConfig(object):
 仅以json格式返回,key为summary, keywords。分别代表概要,关键词。
 -----------------------------
 """,
-            'prompt7': f"""请根据以下的视频信息对其进行分类。类别为其中的一个:【{' '.join(TAGS_NEW)}】。
+            'prompt7': f"""```标题:[title]\n概况:[summary]\n```\n请根据上述的视频信息对其进行分类。类别为其中的一个:【{' '.join(TAGS_NEW)}】。
 仅以json array格式返回,{json_format},key为category与confidence,分别代表类别与分类置信度。给出top 3的分类结果。
 -----------------------------
 """,

+ 12 - 8
gpt_tag.py

@@ -4,6 +4,7 @@ import requests
 import traceback
 from config import set_config
 from log import Log
+from moon_shoot_api import MoonShotHandle
 
 config_ = set_config()
 log_ = Log()
@@ -30,11 +31,13 @@ def get_tag(prompt):
                     },
                 ],
             }
-            response = requests.post(url=config_.GPT_HOST, headers=headers, json=json_data, proxies=proxies)
+            response = requests.post(
+                url=config_.GPT_HOST, headers=headers, json=json_data, proxies=proxies)
             print(response.json())
             print(response.json()['choices'][0]['message']['content'])
             print('\n')
-            result_content = response.json()['choices'][0]['message']['content']
+            result_content = response.json(
+            )['choices'][0]['message']['content']
             return result_content
         except Exception as e:
             print(e)
@@ -51,18 +54,19 @@ def request_gpt(prompt):
     while retry_count < config_.RETRY_MAX_COUNT:
         retry_count += 1
         try:
+            result_content = MoonShotHandle.chat_with_chatgpt(prompt)
             # response = requests.post(url=config_.GPT_URL, json={'content': prompt, 'auth': config_.GPT_OPENAI_API_KEY})
-            response = requests.post(url=config_.GPT_URL, json={'content': prompt})
+            # response = requests.post(url=config_.GPT_URL, json={'content': prompt})
             # print(response.json())
             # print(response.json()['choices'][0]['message']['content'])
             # print('\n')
             # result_content = response.json()['choices'][0]['message']['content']
             # log_.info(f"response.text: {response.text}")
-            res_data = json.loads(response.text)
-            if res_data['code'] != 0:
-                time.sleep(10)
-                continue
-            result_content = res_data['data']['choices'][0]['message']['content']
+            # res_data = json.loads(response.text)
+            # if res_data['code'] != 0:
+            #     time.sleep(10)
+            #     continue
+            # result_content = res_data['data']['choices'][0]['message']['content']
             return result_content
         except Exception:
             time.sleep(10)

+ 31 - 0
moon_shoot_api.py

@@ -0,0 +1,31 @@
+from openai import OpenAI
+
+
+class MoonShotHandle():
+    def __init__(self, api_key=None, api_base=None):
+        self.OPENAI_API_KEY = 'sk-tz1VaKqksTzk0F8HxlU4YVGwj7oa1g0c0puGNUZrdn9MDtzm'
+        self.model = "moonshot-v1-8k"
+
+    def chat(self, question):
+        return self.chat_with_chatgpt(question)
+
+    def chat_with_chatgpt(self, prompt):
+        client = OpenAI(
+            api_key=self.OPENAI_API_KEY,
+            base_url="https://api.moonshot.cn/v1",
+        )
+        chat_completion = client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+            model=self.model,
+        )
+        response = chat_completion.choices[0].message.content
+        return response
+
+
+# res = MoonShotHandle().chat("请问你是谁?")
+# print(res)