|
@@ -1,14 +1,19 @@
|
|
|
"""
|
|
|
generate category for given title
|
|
|
"""
|
|
|
+import concurrent
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
from pymysql.cursors import DictCursor
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
-from applications.api import deep_seek_api_by_volcanoengine
|
|
|
+from applications.api.deep_seek_api_official import fetch_deepseek_completion
|
|
|
+from applications.const import CategoryGenerationTaskConst
|
|
|
from applications.db import DatabaseConnector
|
|
|
from applications.utils import yield_batch
|
|
|
from config import long_articles_config
|
|
|
from tasks.ai_tasks.prompts import category_generation_from_title
|
|
|
|
|
|
+
|
|
|
class CategoryGenerationTask:
|
|
|
"""
|
|
|
generate category for given title
|
|
@@ -16,14 +21,58 @@ class CategoryGenerationTask:
|
|
|
def __init__(self):
|
|
|
self.db_client = DatabaseConnector(long_articles_config)
|
|
|
self.db_client.connect()
|
|
|
+ self.const = CategoryGenerationTaskConst()
|
|
|
|
|
|
- def deal_in_each_thread(self, task_batch):
|
|
|
- thread_db_client = DatabaseConnector(long_articles_config)
|
|
|
- thread_db_client.connect()
|
|
|
+ def set_category(self, thread_db_client, article_id, category):
|
|
|
+ """
|
|
|
+ set category for given article
|
|
|
+ """
|
|
|
+ update_query = f"""
|
|
|
+ update publish_single_video_source
|
|
|
+ set category = %s, category_status = %s
|
|
|
+ where id = %s and category_status = %s;
|
|
|
+ """
|
|
|
+ update_rows = thread_db_client.save(
|
|
|
+ query=update_query,
|
|
|
+ params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS)
|
|
|
+ )
|
|
|
+ return update_rows
|
|
|
|
|
|
- task_id_batch = [i['id'] for i in task_batch]
|
|
|
- title_batch = [i['title'] for i in task_batch]
|
|
|
- prompt = category_generation_from_title(title_batch)
|
|
|
+ def lock_task(self, thread_db_client, article_id_tuple):
|
|
|
+ """
|
|
|
+ lock_task
|
|
|
+ """
|
|
|
+ update_query = f"""
|
|
|
+ update publish_single_video_source
|
|
|
+ set category_status = %s
|
|
|
+ where id in %s and category_status = %s;
|
|
|
+ """
|
|
|
+ update_rows = thread_db_client.save(
|
|
|
+ query=update_query,
|
|
|
+ params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS)
|
|
|
+ )
|
|
|
+ return update_rows
|
|
|
+
|
|
|
+ def deal_in_each_thread(self, task_batch):
|
|
|
+ try:
|
|
|
+ thread_db_client = DatabaseConnector(long_articles_config)
|
|
|
+ thread_db_client.connect()
|
|
|
+ title_batch = [(i['id'], i['article_title']) for i in task_batch]
|
|
|
+ id_tuple = tuple([i['id'] for i in task_batch])
|
|
|
+ # lock task
|
|
|
+ lock_rows = self.lock_task(thread_db_client, id_tuple)
|
|
|
+ if lock_rows:
|
|
|
+ prompt = category_generation_from_title(title_batch)
|
|
|
+ # print(prompt)
|
|
|
+ completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json')
|
|
|
+ for article in title_batch:
|
|
|
+ article_id = str(article[0])
|
|
|
+ category = completion.get(article_id)
|
|
|
+ self.set_category(thread_db_client, article_id, category)
|
|
|
+ else:
|
|
|
+ return
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
|
|
|
def get_task_list(self):
|
|
|
"""
|
|
@@ -31,17 +80,28 @@ class CategoryGenerationTask:
|
|
|
"""
|
|
|
fetch_query = f"""
|
|
|
select id, article_title from publish_single_video_source
|
|
|
- where category_status = 0
|
|
|
- order by score desc limit 100;
|
|
|
+ where category_status = 0 and bad_status = 0
|
|
|
+ order by score desc;
|
|
|
"""
|
|
|
fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor)
|
|
|
return fetch_result
|
|
|
|
|
|
def deal(self):
|
|
|
task_list = self.get_task_list()
|
|
|
- task_batch_list = yield_batch(data=task_list, batch_size=20)
|
|
|
- for task_batch in task_batch_list:
|
|
|
- self.deal_in_each_thread(task_batch)
|
|
|
+ task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
|
|
|
+ max_workers = 5
|
|
|
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
+ # 提交所有任务到线程池
|
|
|
+ futures = [executor.submit(self.deal_in_each_thread, task_batch)
|
|
|
+ for task_batch in task_batch_list]
|
|
|
+
|
|
|
+ # 用 tqdm 跟踪任务完成进度
|
|
|
+ for _ in tqdm(
|
|
|
+ concurrent.futures.as_completed(futures),
|
|
|
+ total=len(futures),
|
|
|
+ desc="Processing batches"
|
|
|
+ ):
|
|
|
+ pass # 仅用于更新进度条,不需要结果
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|