# main.py

import asyncio
import json
import time
import traceback
from multiprocessing import Process, cpu_count
from typing import List, Dict

from application.config.common import LoggerManager, get_consumer, ack_message
from utils.trace_utils import generate_trace_id
from application.functions import MysqlService
from application.spiders.spider_registry import get_spider_class, SPIDER_CLASS_MAP


# ------------------------------- Per-topic coroutine core -------------------------------
async def async_handle_topic(topic: str):
    """
    Long-lived consumer loop for one topic: pull messages, resolve the spider
    class registered for the topic, run it, and acknowledge only on success.
    """
    consumer = get_consumer(topic_name=topic, group_id=topic)
    logger = LoggerManager.get_logger(topic, "worker")
    aliyun_logger = LoggerManager.get_aliyun_logger(topic, "worker")
    while True:
        try:
            # consume_message blocks for up to wait_seconds; run it in a thread
            # so the other topic coroutines in this process keep making progress.
            messages = await asyncio.to_thread(
                consumer.consume_message, wait_seconds=10, batch_size=1
            )
            if not messages:
                await asyncio.sleep(1)
                continue
            for message in messages:
                trace_id = generate_trace_id()
                try:
                    payload = json.loads(message.message_body)
                    platform = payload["platform"]
                    mode = payload["mode"]
                    task_id = payload["id"]
                    mysql_service = MysqlService(platform, mode, task_id)
                    user_list = mysql_service.get_user_list()
                    rule_dict = mysql_service.get_rule_dict()
                    CrawlerClass = get_spider_class(topic)
                    crawler = CrawlerClass(
                        rule_dict=rule_dict,
                        user_list=user_list,
                        trace_id=trace_id,
                    )
                    await crawler.run()
                    # Acknowledge only after the crawler finished without raising.
                    ack_message(mode, platform, message, consumer, trace_id=trace_id)
                    aliyun_logger.logging(
                        code="1000",
                        message="Task completed and message acknowledged",
                        trace_id=trace_id,
                    )
                except Exception as e:
                    # Left unacknowledged, the message will be redelivered later.
                    aliyun_logger.logging(
                        code="9001",
                        message=f"Failed to process message: {e}\n{traceback.format_exc()}",
                        trace_id=trace_id,
                        data=message.message_body,
                    )
        except Exception as err:
            logger.error(f"[{topic}] consume failed: {err}\n{traceback.format_exc()}")
            await asyncio.sleep(5)
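
# For reference: a message body that the loop above can parse needs at least the
# three keys read out of `payload`; the values here are made-up placeholders:
#
#     {"platform": "example_platform", "mode": "recommend", "id": 123}
#
# The loop also relies on only this much of a spider's interface (an
# illustrative sketch; the real classes are registered in
# application.spiders.spider_registry and may expose more):
#
#     class ExampleSpider:
#         def __init__(self, rule_dict, user_list, trace_id): ...
#         async def run(self): ...  # awaited once per consumed message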


async def run_all_topics(topics: List[str]):
    """
    Start one consumer coroutine per topic inside the current process and keep
    them all running concurrently.
    """
    tasks = [asyncio.create_task(async_handle_topic(topic)) for topic in topics]
    await asyncio.gather(*tasks)


def handle_topic_group(topics: List[str]):
    """
    Child-process entry point: run one event loop that serves every topic in
    this group.
    """
    asyncio.run(run_all_topics(topics))


# ------------------------------- Main scheduling -------------------------------
def split_topics(topics: List[str], num_groups: int) -> List[List[str]]:
    """
    Distribute all topics round-robin into num_groups groups (one group per
    worker process).
    """
    return [topics[i::num_groups] for i in range(num_groups)]
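
# Quick illustration of the round-robin split (topic names are placeholders):
#
#     >>> split_topics(["t1", "t2", "t3", "t4", "t5"], 2)
#     [['t1', 't3', 't5'], ['t2', 't4']]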


def start_worker_process(group_id: int, topic_group: List[str], process_map: Dict[int, Process]):
    """
    Spawn a child process to handle one group of topics and record it so the
    monitor loop can restart it if it dies.
    """
    p = Process(target=handle_topic_group, args=(topic_group,), name=f"Worker-{group_id}")
    p.start()
    process_map[group_id] = p
    print(f"[main] started process PID={p.pid} for topics={topic_group}")


def main():
    # Collect every registered spider topic.
    topic_list = list(SPIDER_CLASS_MAP.keys())
    print(f"Listening on topics: {topic_list}")

    # Use the CPU core count to decide the number of processes. Drop empty
    # groups (cores > topics), since a worker with nothing to consume would
    # exit immediately and be restarted forever by the monitor below.
    num_cpus = cpu_count()
    topic_groups = [group for group in split_topics(topic_list, num_cpus) if group]
    print(f"CPU cores: {num_cpus}, worker processes: {len(topic_groups)}")

    # Start the child processes.
    process_map: Dict[int, Process] = {}
    for group_id, topic_group in enumerate(topic_groups):
        start_worker_process(group_id, topic_group, process_map)

    # Keep monitoring the children; restart any process that exits abnormally.
    try:
        while True:
            time.sleep(5)
            for group_id, p in list(process_map.items()):
                if not p.is_alive():
                    print(f"[monitor] process {p.name} PID={p.pid} crashed, restarting...")
                    time.sleep(2)
                    start_worker_process(group_id, topic_groups[group_id], process_map)
    except KeyboardInterrupt:
        print("Received exit signal, terminating all processes...")
        for p in process_map.values():
            p.terminate()
        for p in process_map.values():
            p.join()


if __name__ == '__main__':
    main()