|
|
@@ -7,6 +7,7 @@ from typing import Optional, List, Dict, Tuple
|
|
|
from applications.api import fetch_deepseek_completion
|
|
|
from applications.utils import yield_batch
|
|
|
from applications.tasks.llm_tasks.prompts import extract_article_features
|
|
|
+from applications.tasks.llm_tasks.prompts import extract_article_category
|
|
|
from tqdm.asyncio import tqdm
|
|
|
|
|
|
|
|
|
@@ -623,15 +624,15 @@ class ExtractTitleFeatures(Const):
|
|
|
self.aliyun_log = aliyun_log
|
|
|
self.trace_id = trace_id
|
|
|
|
|
|
async def get_tasks(self, version: int, batch_size=100):
    """Fetch a batch of rows from ``title_features`` that are still pending.

    Args:
        version: Value matched against the ``version`` column.
        batch_size: Upper bound on the number of rows selected.

    Returns:
        The result of ``self.pool.async_fetch`` for the query; each row
        carries the ``id`` and ``title`` columns.
    """
    sql = """
        select id, title
        from title_features
        where status = %s and version = %s
        limit %s;
    """
    # Placeholder order in the statement: status, version, limit.
    bind_values = (self.INIT_STATUS, version, batch_size)
    rows = await self.pool.async_fetch(query=sql, params=bind_values)
    return rows
|
|
|
|
|
|
async def update_status(self, title_id, ori_status, new_status):
|
|
|
@@ -681,21 +682,26 @@ class ExtractTitleFeatures(Const):
|
|
|
),
|
|
|
)
|
|
|
|
|
|
- async def deal(self, data):
|
|
|
- batch_size = data.get("batch_size", 50)
|
|
|
- task_list = await self.get_tasks(batch_size=batch_size)
|
|
|
async def set_category_for_each_title(self, title_id, category):
    """Store the category of one title and flag the row as successful.

    The update only applies while the row is still in
    ``PROCESSING_STATUS``; a row in any other state is left untouched.

    Args:
        title_id: Primary key of the ``title_features`` row.
        category: Value written to the ``category`` column.

    Returns:
        The result of ``self.pool.async_save`` for the statement.
    """
    sql = """
        UPDATE title_features
        SET category = %s, status = %s
        WHERE id = %s and status = %s;
    """
    # Placeholder order: new category, new status, row id, expected status.
    bind_values = (
        category,
        self.SUCCESS_STATUS,
        title_id,
        self.PROCESSING_STATUS,
    )
    outcome = await self.pool.async_save(query=sql, params=bind_values)
    return outcome
|
|
|
|
|
|
+ async def get_title_features(self, task_list: list):
|
|
|
title_list = [i["title"] for i in task_list]
|
|
|
id_list = [i["id"] for i in task_list]
|
|
|
title_id_map = {i["title"]: i["id"] for i in task_list}
|
|
|
-
|
|
|
prompt = extract_article_features(title_list)
|
|
|
-
|
|
|
# 设置状态为处理中
|
|
|
await self.update_status_batch(
|
|
|
id_list, self.INIT_STATUS, self.PROCESSING_STATUS
|
|
|
)
|
|
|
-
|
|
|
try:
|
|
|
feature_dict = fetch_deepseek_completion(
|
|
|
model="default", prompt=prompt, output_type="json"
|
|
|
@@ -745,12 +751,85 @@ class ExtractTitleFeatures(Const):
|
|
|
title_id = title_id_map[title]
|
|
|
await self.set_feature_for_each_title(title_id, features)
|
|
|
|
|
|
async def get_title_category(self, task_list: list):
    """Ask the LLM for a category per title and persist the results.

    The batch is first moved from ``INIT_STATUS`` to ``PROCESSING_STATUS``.
    If the completion call raises or returns nothing, the whole batch is
    moved to ``FAIL_STATUS``. Otherwise every task whose title has a
    category in the response is saved through
    ``set_category_for_each_title``; tasks the response did not cover are
    moved to ``FAIL_STATUS`` so they do not stay in processing forever.

    Args:
        task_list: Rows to classify; each item must provide ``"id"`` and
            ``"title"``.

    Returns:
        None.
    """
    if not task_list:
        # Nothing to classify: skip the status update and the LLM round trip.
        return

    title_list = [task["title"] for task in task_list]
    id_list = [task["id"] for task in task_list]
    generate_category_prompt = extract_article_category(title_list)

    # Mark the whole batch as processing before calling the model.
    await self.update_status_batch(
        id_list, self.INIT_STATUS, self.PROCESSING_STATUS
    )

    try:
        category_dict = fetch_deepseek_completion(
            model="DeepSeek-V3", prompt=generate_category_prompt, output_type="json"
        )
    except Exception as e:
        await self.aliyun_log.log(
            contents={
                "task": "extract_title_category",
                "function": "deal",
                "message": "fetch deepseek completion failed",
                "status": "fail",
                "data": {
                    "error_message": str(e),
                    "error_type": type(e).__name__,
                    "traceback": traceback.format_exc(),
                },
            }
        )
        await self.update_status_batch(
            id_list, self.PROCESSING_STATUS, self.FAIL_STATUS
        )
        return

    if not category_dict:
        await self.aliyun_log.log(
            contents={
                "task": "extract_title_category",
                "function": "deal",
                "message": "fetch deepseek completion return empty",
                "status": "fail",
                "data": {
                    "error_message": "fetch deepseek completion return empty",
                    "error_type": "EmptyResponseError",
                    # No exception is active here, so format_exc() only
                    # yields its "NoneType: None" placeholder; the key is
                    # kept so the log payload shape stays uniform.
                    "traceback": traceback.format_exc(),
                },
            }
        )
        await self.update_status_batch(
            id_list, self.PROCESSING_STATUS, self.FAIL_STATUS
        )
        return

    # Walk the tasks themselves rather than a title -> id map: two rows
    # sharing the same title must both be updated, and a map would keep
    # only one of their ids.
    unresolved_ids = []
    for task in tqdm(task_list):
        category = category_dict.get(task["title"], {})
        if not category:
            unresolved_ids.append(task["id"])
            continue

        await self.set_category_for_each_title(task["id"], category)

    if unresolved_ids:
        # Titles the model skipped would otherwise remain in processing
        # and never be selected again by get_tasks.
        await self.update_status_batch(
            unresolved_ids, self.PROCESSING_STATUS, self.FAIL_STATUS
        )
|
|
|
|
|
|
+ async def deal(self, data):
|
|
|
+ batch_size = data.get("batch_size", 50)
|
|
|
+ version = data.get("version", 1)
|
|
|
+ task_list = await self.get_tasks(version=version, batch_size=batch_size)
|
|
|
+
|
|
|
+ match version:
|
|
|
+ case 1:
|
|
|
+ await self.get_title_features(task_list)
|
|
|
+ case 2:
|
|
|
+ await self.get_title_category(task_list)
|
|
|
+ case _:
|
|
|
+ await self.aliyun_log.log(
|
|
|
+ contents={
|
|
|
+ "task": "extract_title_features",
|
|
|
+ "function": "deal",
|
|
|
+ "message": "version not supported",
|
|
|
+ "status": "fail",
|
|
|
+ "data": {
|
|
|
+ "error_message": "version not supported",
|
|
|
+ "error_type": "VersionNotSupportedError",
|
|
|
+ "traceback": traceback.format_exc(),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|