1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- """
- generate category for given title
- """
- from pymysql.cursors import DictCursor
- from applications.api import deep_seek_api_by_volcanoengine
- from applications.db import DatabaseConnector
- from applications.utils import yield_batch
- from config import long_articles_config
- from tasks.ai_tasks.prompts import category_generation_from_title
class CategoryGenerationTask:
    """
    Generate categories for article titles stored in ``publish_single_video_source``.

    Pending rows (``category_status = 0``) are fetched highest ``score`` first,
    split into batches, and each batch's titles are turned into a
    category-generation prompt.
    """

    def __init__(self):
        # Connection used by get_task_list to read pending tasks.
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()

    def deal_in_each_thread(self, task_batch):
        """
        Build the category-generation prompt for one batch of tasks.

        :param task_batch: a batch of rows as returned by ``get_task_list``;
            each row is read via its ``article_title`` key.
        :return: the prompt produced by ``category_generation_from_title``.

        NOTE(review): the prompt is not sent to a model yet and no results are
        written back (``deep_seek_api_by_volcanoengine`` is imported but not
        called here). TODO: when write-back is added, open a per-batch DB
        connection here and make sure it is closed afterwards. The previous
        version opened one per batch without ever using or closing it.
        """
        # get_task_list selects ``article_title``; reading ``title`` raised
        # KeyError on every row.
        title_batch = [task['article_title'] for task in task_batch]
        return category_generation_from_title(title_batch)

    def get_task_list(self):
        """
        Fetch up to 100 pending tasks, highest score first.

        :return: the result of ``db_client.fetch`` using a pymysql
            ``DictCursor``, so rows are dicts keyed by the selected columns
            (``id``, ``article_title``). NOTE(review): the exact container type
            (and whether it can be None on failure) depends on
            ``DatabaseConnector.fetch``; confirm.
        """
        fetch_query = """
            select id, article_title from publish_single_video_source
            where category_status = 0
            order by score desc limit 100;
        """
        return self.db_client.fetch(fetch_query, cursor_type=DictCursor)

    def deal(self, batch_size=20):
        """
        Fetch pending tasks and process them batch by batch.

        :param batch_size: number of tasks handed to ``deal_in_each_thread``
            at a time (default 20, matching the previous hard-coded value).

        Batches are processed sequentially in the calling thread.
        """
        task_list = self.get_task_list()
        if not task_list:
            # Nothing pending (or the fetch produced no rows); skip batching.
            return
        for task_batch in yield_batch(data=task_list, batch_size=batch_size):
            self.deal_in_each_thread(task_batch)
if __name__ == '__main__':
    # Script entry point: run one pass over the pending category tasks.
    task = CategoryGenerationTask()
    task.deal()
|