category_generation.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. """
  2. generate category for given title
  3. """
  4. from pymysql.cursors import DictCursor
  5. from applications.api import deep_seek_api_by_volcanoengine
  6. from applications.db import DatabaseConnector
  7. from applications.utils import yield_batch
  8. from config import long_articles_config
  9. from tasks.ai_tasks.prompts import category_generation_from_title
  10. class CategoryGenerationTask:
  11. """
  12. generate category for given title
  13. """
  14. def __init__(self):
  15. self.db_client = DatabaseConnector(long_articles_config)
  16. self.db_client.connect()
  17. def deal_in_each_thread(self, task_batch):
  18. thread_db_client = DatabaseConnector(long_articles_config)
  19. thread_db_client.connect()
  20. task_id_batch = [i['id'] for i in task_batch]
  21. title_batch = [i['title'] for i in task_batch]
  22. prompt = category_generation_from_title(title_batch)
  23. def get_task_list(self):
  24. """
  25. get task_list from a database
  26. """
  27. fetch_query = f"""
  28. select id, article_title from publish_single_video_source
  29. where category_status = 0
  30. order by score desc limit 100;
  31. """
  32. fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor)
  33. return fetch_result
  34. def deal(self):
  35. task_list = self.get_task_list()
  36. task_batch_list = yield_batch(data=task_list, batch_size=20)
  37. for task_batch in task_batch_list:
  38. self.deal_in_each_thread(task_batch)
  39. if __name__ == '__main__':
  40. category_generation_task = CategoryGenerationTask()
  41. category_generation_task.deal()