@@ -1,6 +1,7 @@
 """
 generate category for given title
 """
+
 import time
 import concurrent
 import traceback
@@ -60,11 +61,18 @@ class CategoryGenerationTask:

         update_rows = thread_db_client.save(
             query=update_query,
-            params=(self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS),
+            params=(
+                self.const.FAIL_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
         )
         return update_rows

-    def update_title_category(self, thread_db_client: DatabaseConnector, article_id: int, completion: dict):
+    def update_title_category(
+        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):

         try:
             category = completion.get(str(article_id))
@@ -78,8 +86,8 @@ class CategoryGenerationTask:
                 data={
                     "article_id": article_id,
                     "error": str(e),
-                    "traceback": traceback.format_exc()
-                }
+                    "traceback": traceback.format_exc(),
+                },
             )
             self.set_category_status_as_fail(thread_db_client, article_id)

@@ -96,8 +104,8 @@ class CategoryGenerationTask:
             params=(
                 self.const.INIT_STATUS,
                 self.const.PROCESSING_STATUS,
-                int(time.time()) - self.const.MAX_PROCESSING_TIME
-            )
+                int(time.time()) - self.const.MAX_PROCESSING_TIME,
+            ),
         )

         return update_rows
@@ -130,30 +138,27 @@ class CategoryGenerationTask:
         article_id = article["id"]
         title = article["article_title"]

-        id_tuple = (article_id, )
         title_batch = [(article_id, title)]

-        lock_rows = self.lock_task(thread_db_client, id_tuple)
-        if lock_rows:
-            prompt = category_generation_from_title(title_batch)
-            try:
-                completion = fetch_deepseek_completion(
-                    model="DeepSeek-V3", prompt=prompt, output_type="json"
-                )
-                self.update_title_category(thread_db_client, article_id, completion)
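+        # NOTE: the per-article lock_task call was dropped here; the row is assumed
+        # to be locked by the caller (deal_batch_in_each_thread) before this runs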
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+            self.update_title_category(thread_db_client, article_id, completion)

-            except Exception as e:
-                log(
-                    task=self.const.TASK_NAME,
-                    message="该文章存在敏感词,AI 拒绝返回",
-                    function="deal_each_article",
-                    data={
-                        "article_id": article_id,
-                        "error": str(e),
-                        "traceback": traceback.format_exc()
-                    }
-                )
-                self.set_category_status_as_fail(thread_db_client, article_id)
+        except Exception as e:
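+            # log message below: "the article contains sensitive words, the AI refused to respond"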
+            log(
+                task=self.const.TASK_NAME,
+                message="该文章存在敏感词,AI 拒绝返回",
+                function="deal_each_article",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)

     def deal_batch_in_each_thread(self, task_batch: list[dict]):
         """
@@ -181,8 +186,8 @@ class CategoryGenerationTask:
                 data={
                     "article_id": id_tuple,
                     "error": str(e),
-                    "traceback": traceback.format_exc()
-                }
+                    "traceback": traceback.format_exc(),
+                },
             )
         for article in tqdm(task_batch):
             self.deal_each_article(thread_db_client, article)
@@ -202,9 +207,13 @@ class CategoryGenerationTask:
         fetch_query = f"""
             select id, article_title from publish_single_video_source
             where category_status = %s and bad_status = %s
-            order by score desc limit 20;
+            order by score desc;
         """
-        fetch_result = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor, params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS))
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
+        )
         return fetch_result

     def deal(self):
@@ -213,23 +222,20 @@ class CategoryGenerationTask:

         task_list = self.get_task_list()
         task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
-        for task_batch in task_batch_list:
-            self.deal_batch_in_each_thread(task_batch)
-
-        # with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
-        #     futures = [
-        #         executor.submit(self.deal_in_each_thread, task_batch)
-        #         for task_batch in task_batch_list
-        #     ]
-        #
-        #     for _ in tqdm(
-        #         concurrent.futures.as_completed(futures),
-        #         total=len(futures),
-        #         desc="Processing batches",
-        #     ):
-        #         pass
-
-
-if __name__ == "__main__":
-    category_generation_task = CategoryGenerationTask()
-    category_generation_task.deal()
+
+        ## dev
+        # for task_batch in task_batch_list:
+        #     self.deal_batch_in_each_thread(task_batch)
+
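+        # submit each batch to the thread pool; tqdm advances as futures complete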
+        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+            futures = [
+                executor.submit(self.deal_batch_in_each_thread, task_batch)
+                for task_batch in task_batch_list
+            ]
+
+            for _ in tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(futures),
+                desc="Processing batches",
+            ):
+                pass