""" use llm function to recognize the account information """ import json from pymysql.cursors import DictCursor from tqdm import tqdm from threading import local import concurrent from concurrent.futures import ThreadPoolExecutor from applications.api import fetch_deepseek_completion from applications.db import DatabaseConnector from config import long_articles_config from tasks.ai_tasks.prompts import category_generation_for_each_account from tasks.ai_tasks.prompts import get_title_match_score_list thread_local = local() def get_db_client(): """ each thread get it's own db client """ if not hasattr(thread_local, "db_client"): thread_local.db_client = DatabaseConnector(long_articles_config) thread_local.db_client.connect() return thread_local.db_client def update_task_status(thread_db_client, task_id, ori_status, new_status): """ update task status """ update_query = f""" update crawler_candidate_account_pool set status = %s where id = %s and status = %s; """ thread_db_client.save(update_query, (new_status, task_id, ori_status)) def update_task_category_status(thread_db_client, task_id, ori_status, new_status): """ update task status """ update_query = f""" update crawler_candidate_account_pool set category_status = %s where id = %s and category_status = %s; """ thread_db_client.save(update_query, (new_status, task_id, ori_status)) def get_account_score(thread_db_client, account): """ recognize each account """ task_id = account["id"] # lock task update_task_status(thread_db_client, task_id, 0, 1) # process title_list = json.loads(account["title_list"]) if len(title_list) < 15 and account["platform"] == "toutiao": # 账号数量不足,直接跳过 print("bad account, skip") update_task_status(thread_db_client, task_id, 1, 11) return # 标题长度过长,需要过滤 title_total_length = sum(len(title) for title in title_list) avg_title_length = title_total_length / len(title_list) if avg_title_length > 45: print("title too long, skip") update_task_status(thread_db_client, task_id, 1, 14) return prompt = get_title_match_score_list(title_list) response = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt) response_score_str = response.strip() try: score_list = json.loads(response_score_str) avg_score = sum(score_list) / len(score_list) except Exception as e: score_list = [] avg_score = 0 if score_list and avg_score: update_query = f""" update crawler_candidate_account_pool set score_list = %s, avg_score = %s, status = %s where id = %s and status = %s; """ thread_db_client.save( update_query, (json.dumps(score_list), avg_score, 2, task_id, 1) ) else: update_task_status(thread_db_client, task_id, 1, 12) def get_account_category(thread_db_client, account): """ recognize each account """ task_id = account["id"] title_list = json.loads(account["title_list"]) # lock task update_task_category_status(thread_db_client, task_id, 0, 1) prompt = category_generation_for_each_account(title_list) response = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt) print(response) response_category = response.strip() if response_category: update_query = f""" update crawler_candidate_account_pool set category = %s, category_status = %s where id = %s and category_status = %s; """ thread_db_client.save(update_query, (response_category, 2, task_id, 1)) else: update_task_category_status(thread_db_client, task_id, 1, 99) def recognize_account_thread(account, task): """ recognize thread """ match task: case "score": thread_db_client = get_db_client() try: get_account_score(thread_db_client, account) except Exception as e: update_task_status( thread_db_client=thread_db_client, task_id=account["id"], ori_status=1, new_status=13, ) case "category": thread_db_client = get_db_client() try: get_account_category(thread_db_client, account) except Exception as e: update_task_category_status( thread_db_client=thread_db_client, task_id=account["id"], ori_status=1, new_status=99, ) case "_": return class CandidateAccountRecognizer: INIT_STATUS = 0 PROCESSING_STATUS = 1 SUCCESS_STATUS = 2 FAILED_STATUS = 99 AVG_SCORE_THRESHOLD = 65 def __init__(self): self.db_client = DatabaseConnector(long_articles_config) self.db_client.connect() class CandidateAccountQualityScoreRecognizer(CandidateAccountRecognizer): def get_task_list(self): """ get account tasks from the database """ fetch_query = f""" select id, title_list, platform from crawler_candidate_account_pool where avg_score is null and status = {self.INIT_STATUS} and title_list is not null; """ fetch_response = self.db_client.fetch(fetch_query, cursor_type=DictCursor) return fetch_response def deal(self): task_list = self.get_task_list() with ThreadPoolExecutor(max_workers=8) as executor: futures = [ executor.submit(recognize_account_thread, task, "score") for task in task_list ] for future in tqdm( concurrent.futures.as_completed(futures), total=len(task_list), desc="处理进度", ): future.result() class CandidateAccountCategoryRecognizer(CandidateAccountRecognizer): def get_task_list(self): fetch_query = f""" select id, title_list from crawler_candidate_account_pool where category_status = %s and avg_score >= %s; """ fetch_response = self.db_client.fetch( fetch_query, cursor_type=DictCursor, params=(self.INIT_STATUS, self.AVG_SCORE_THRESHOLD), ) return fetch_response def deal(self): task_list = self.get_task_list() with ThreadPoolExecutor(max_workers=8) as executor: futures = [ executor.submit(recognize_account_thread, task, "category") for task in task_list ] for future in tqdm( concurrent.futures.as_completed(futures), total=len(task_list), desc="处理进度", ): future.result()