123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- """
- Use LLM functions to recognize account information.
- """
- import json
- from pymysql.cursors import DictCursor
- from tqdm import tqdm
- from threading import local
- import concurrent
- from concurrent.futures import ThreadPoolExecutor
- from applications.api import fetch_deepseek_completion
- from applications.db import DatabaseConnector
- from config import long_articles_config
- from tasks.ai_tasks.prompts import category_generation_for_each_account
- from tasks.ai_tasks.prompts import get_title_match_score_list
# Per-thread storage: attributes set on this object are visible only to the
# thread that set them.
thread_local = local()


def get_db_client():
    """
    Return the database client that belongs to the calling thread.

    The first call made from a thread builds a DatabaseConnector from
    long_articles_config, stores it on ``thread_local`` and connects it;
    later calls from the same thread return that stored client.
    """
    # Fast path: this thread already has a client.
    if hasattr(thread_local, "db_client"):
        return thread_local.db_client

    client = DatabaseConnector(long_articles_config)
    # Stored before connecting, so the attribute exists even if connect() raises.
    thread_local.db_client = client
    client.connect()
    return client
def update_task_status(thread_db_client, task_id, ori_status, new_status):
    """
    Move one row of crawler_candidate_account_pool to a new ``status``.

    The update is conditional: it only applies when the row's current
    ``status`` equals ``ori_status``.

    :param thread_db_client: database client exposing ``save(query, params)``
    :param task_id: value of the row's ``id`` column
    :param ori_status: status the row is expected to have right now
    :param new_status: status to write
    :return: None
    """
    sql = """
        update crawler_candidate_account_pool
        set status = %s
        where id = %s and status = %s;
    """
    # Parameter order follows the placeholders: new value, row id, expected value.
    params = (new_status, task_id, ori_status)
    thread_db_client.save(sql, params)
def update_task_category_status(thread_db_client, task_id, ori_status, new_status):
    """
    Move one row of crawler_candidate_account_pool to a new ``category_status``.

    The update is conditional: it only applies when the row's current
    ``category_status`` equals ``ori_status``.

    :param thread_db_client: database client exposing ``save(query, params)``
    :param task_id: value of the row's ``id`` column
    :param ori_status: category_status the row is expected to have right now
    :param new_status: category_status to write
    :return: None
    """
    sql = """
        update crawler_candidate_account_pool
        set category_status = %s
        where id = %s and category_status = %s;
    """
    # Parameter order follows the placeholders: new value, row id, expected value.
    params = (new_status, task_id, ori_status)
    thread_db_client.save(sql, params)
def get_account_score(thread_db_client, account):
    """
    Score one candidate account by having the LLM rate its titles.

    The row is first moved from status 0 to 1, then ends in one of:
        2  - score_list and avg_score were saved
        11 - a "toutiao" account with fewer than 15 titles
        12 - the LLM reply did not yield a non-empty list with a non-zero average
        13 - the title list is empty
        14 - the average title length is greater than 45 characters

    :param thread_db_client: database client exposing ``save(query, params)``
    :param account: row dict with "id", "title_list" (JSON text) and "platform"
    :return: None
    """
    task_id = account["id"]

    # lock task
    update_task_status(thread_db_client, task_id, 0, 1)

    # process
    title_list = json.loads(account["title_list"])
    if len(title_list) < 15 and account["platform"] == "toutiao":
        # too few titles for this account, skip it
        print("bad account, skip")
        update_task_status(thread_db_client, task_id, 1, 11)
        return

    if not title_list:
        # An empty list has no average length; dividing below would raise
        # ZeroDivisionError. Mark it 13 directly, the same status that
        # recognize_account_thread assigns when scoring raises.
        print("empty title list, skip")
        update_task_status(thread_db_client, task_id, 1, 13)
        return

    # titles are too long on average, filter the account out
    title_total_length = sum(len(title) for title in title_list)
    avg_title_length = title_total_length / len(title_list)
    if avg_title_length > 45:
        print("title too long, skip")
        update_task_status(thread_db_client, task_id, 1, 14)
        return

    prompt = get_title_match_score_list(title_list)
    response = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt)
    response_score_str = response.strip()
    try:
        score_list = json.loads(response_score_str)
        avg_score = sum(score_list) / len(score_list)
    except Exception:
        # Deliberate best-effort: any reply that is not a JSON list of numbers
        # (or is an empty list) is treated as a failed scoring, status 12 below.
        score_list = []
        avg_score = 0

    if score_list and avg_score:
        update_query = """
            update crawler_candidate_account_pool
            set score_list = %s, avg_score = %s, status = %s
            where id = %s and status = %s;
        """
        thread_db_client.save(
            update_query, (json.dumps(score_list), avg_score, 2, task_id, 1)
        )
    else:
        update_task_status(thread_db_client, task_id, 1, 12)
def get_account_category(thread_db_client, account):
    """
    Ask the LLM for a category for one candidate account and store it.

    The row is first moved from category_status 0 to 1, then ends in one of:
        2  - a non-empty category string was saved
        99 - the LLM reply was empty after stripping whitespace

    :param thread_db_client: database client exposing ``save(query, params)``
    :param account: row dict with "id" and "title_list" (JSON text)
    :return: None
    """
    task_id = account["id"]

    # Lock the task before doing anything that can raise. The failure handler
    # in recognize_account_thread only moves rows from category_status 1 to 99,
    # so a title_list that failed to parse before the lock would leave the row
    # at 0 and it would be fetched again on every run.
    update_task_category_status(thread_db_client, task_id, 0, 1)

    title_list = json.loads(account["title_list"])

    prompt = category_generation_for_each_account(title_list)
    response = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt)
    print(response)
    response_category = response.strip()
    if response_category:
        update_query = """
            update crawler_candidate_account_pool
            set category = %s, category_status = %s
            where id = %s and category_status = %s;
        """
        thread_db_client.save(update_query, (response_category, 2, task_id, 1))
    else:
        update_task_category_status(thread_db_client, task_id, 1, 99)
- def recognize_account_thread(account, task):
- """
- recognize thread
- """
- match task:
- case "score":
- thread_db_client = get_db_client()
- try:
- get_account_score(thread_db_client, account)
- except Exception as e:
- update_task_status(
- thread_db_client=thread_db_client,
- task_id=account["id"],
- ori_status=1,
- new_status=13,
- )
- case "category":
- thread_db_client = get_db_client()
- try:
- get_account_category(thread_db_client, account)
- except Exception as e:
- update_task_category_status(
- thread_db_client=thread_db_client,
- task_id=account["id"],
- ori_status=1,
- new_status=99,
- )
- case "_":
- return
class CandidateAccountRecognizer:
    """
    Common base for the candidate-account recognizers.

    Holds the shared constants and opens one database connection per instance.
    """

    # Row status codes. INIT_STATUS selects the rows a recognizer still has
    # to process; the other three are declared here for subclasses and callers.
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    SUCCESS_STATUS = 2
    FAILED_STATUS = 99

    # Minimum avg_score a row needs to be selected for category recognition.
    AVG_SCORE_THRESHOLD = 65

    def __init__(self):
        """Build a DatabaseConnector from long_articles_config and connect it."""
        connector = self.db_client = DatabaseConnector(long_articles_config)
        connector.connect()
class CandidateAccountQualityScoreRecognizer(CandidateAccountRecognizer):
    """Runs the "score" task over every candidate account that has no score yet."""

    def get_task_list(self):
        """
        Fetch the rows that still need a quality score.

        Selects id, title_list and platform for rows whose avg_score is null,
        whose status is INIT_STATUS and whose title_list is not null.

        :return: whatever ``db_client.fetch`` returns for a DictCursor query
        """
        sql = f"""
            select id, title_list, platform
            from crawler_candidate_account_pool
            where avg_score is null and status = {self.INIT_STATUS} and title_list is not null;
        """
        return self.db_client.fetch(sql, cursor_type=DictCursor)

    def deal(self):
        """
        Score all pending accounts on a pool of 8 worker threads.

        Shows a tqdm progress bar and calls ``result()`` on every finished
        future, so an exception that escaped a worker is re-raised here.
        """
        accounts = self.get_task_list()
        with ThreadPoolExecutor(max_workers=8) as pool:
            pending = [
                pool.submit(recognize_account_thread, account, "score")
                for account in accounts
            ]
            progress = tqdm(
                concurrent.futures.as_completed(pending),
                total=len(accounts),
                desc="处理进度",
            )
            for finished in progress:
                finished.result()
class CandidateAccountCategoryRecognizer(CandidateAccountRecognizer):
    """Runs the "category" task over every scored account above the threshold."""

    def get_task_list(self):
        """
        Fetch the rows that still need a category.

        Selects id and title_list for rows whose category_status equals
        INIT_STATUS and whose avg_score is at least AVG_SCORE_THRESHOLD.

        :return: whatever ``db_client.fetch`` returns for a DictCursor query
        """
        sql = """
            select id, title_list
            from crawler_candidate_account_pool
            where category_status = %s and avg_score >= %s;
        """
        query_params = (self.INIT_STATUS, self.AVG_SCORE_THRESHOLD)
        return self.db_client.fetch(
            sql,
            cursor_type=DictCursor,
            params=query_params,
        )

    def deal(self):
        """
        Categorize all pending accounts on a pool of 8 worker threads.

        Shows a tqdm progress bar and calls ``result()`` on every finished
        future, so an exception that escaped a worker is re-raised here.
        """
        accounts = self.get_task_list()
        with ThreadPoolExecutor(max_workers=8) as pool:
            pending = [
                pool.submit(recognize_account_thread, account, "category")
                for account in accounts
            ]
            progress = tqdm(
                concurrent.futures.as_completed(pending),
                total=len(accounts),
                desc="处理进度",
            )
            for finished in progress:
                finished.result()
|