zhangliang 2 weeks ago
parent
commit
4715ef3c5e

+ 21 - 0
application/common/log/logger_manager.py

@@ -0,0 +1,21 @@
+from application.common.log import Local, AliyunLogger
+
+class LoggerManager:
+    _local_loggers = {}
+    _aliyun_loggers = {}
+
+    @staticmethod
+    def get_logger(platform, mode, log_to_console=True):
+        key = f"{platform}_{mode}"
+        if key not in LoggerManager._local_loggers:
+            LoggerManager._local_loggers[key] = Local.init_logger(
+                platform=platform, mode=mode, log_to_console=log_to_console
+            )
+        return LoggerManager._local_loggers[key]
+
+    @staticmethod
+    def get_aliyun_logger(platform, mode):
+        key = f"{platform}_{mode}"
+        if key not in LoggerManager._aliyun_loggers:
+            LoggerManager._aliyun_loggers[key] = AliyunLogger(platform=platform, mode=mode)
+        return LoggerManager._aliyun_loggers[key]

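A minimal usage sketch for the new LoggerManager (not part of this commit). The platform and mode values are placeholders taken from the config below, and the only calls used are the two getters added here plus the info/logging calls already used elsewhere in this change:

from application.common.log.logger_manager import LoggerManager

# Both getters cache one instance per (platform, mode) pair.
logger = LoggerManager.get_logger("xiaoniangaotuijianliu", "recommend")
logger.info("written through the cached local logger")

aliyun_logger = LoggerManager.get_aliyun_logger("xiaoniangaotuijianliu", "recommend")
aliyun_logger.logging(code="1000", message="written through the cached Aliyun logger", trace_id="demo-trace")

# Repeated calls with the same key return the same object.
assert LoggerManager.get_logger("xiaoniangaotuijianliu", "recommend") is logger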
+ 1 - 1
application/functions/mysql_service.py

@@ -5,7 +5,7 @@ from application.common import MysqlHelper, AliyunLogger,Local
 
 
 class MysqlService:
-    def __init__(self, task_id, mode, platform):
+    def __init__(self, platform, mode, task_id):
         self.env = "prod"
         self.task_id = task_id
         self.mode = mode

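The constructor's parameter order changes from (task_id, mode, platform) to (platform, mode, task_id), so positional call sites must pass arguments in the new order. A call-site sketch (placeholder values; the two methods shown are the ones used in main.py below):

from application.functions.mysql_service import MysqlService

# New positional order: platform, then mode, then task_id.
mysql_service = MysqlService("xiaoniangaotuijianliu", "recommend", 123)
user_list = mysql_service.get_user_list()
rule_dict = mysql_service.get_rule_dict()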
+ 4 - 0
configs/spiders_config.yaml

@@ -27,6 +27,10 @@ bszf_recommend_prod:
       video_url: "$.video_url"
       out_video_id: "$.nid"
 
+
+
+
+
 xngtjl_recommend_prod:
   platform: xiaoniangaotuijianliu
   mode: recommend

+ 10 - 9
crawler_worker/universal_crawler.py

@@ -10,6 +10,8 @@ import cv2
 from datetime import datetime
 from typing import Dict, Any, List, Optional, Union
 from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type, RetryCallState
+
+from application.common.log.logger_manager import LoggerManager
 from utils.extractors import safe_extract, extract_multiple
 
 # Add the common module path
@@ -20,12 +22,9 @@ from application.items import VideoItem
 from application.pipeline import PiaoQuanPipeline
 from application.common.messageQueue import MQ
 from application.common.log import AliyunLogger
-# from application.common.mysql import MysqlHelper
-from configs.messages import MESSAGES
-from configs import codes
-from utils.config_loader import ConfigLoader
 from application.common.log import Local
 from configs.config import base_url
+from application.functions.mysql_service import MysqlService
 
 
 def before_send_log(retry_state: RetryCallState) -> None:
@@ -59,8 +58,9 @@ class UniversalCrawler:
         self.trace_id = trace_id
         self.env = env
         self.config = platform_config
-        self.aliyun_log = AliyunLogger(platform=self.platform, mode=self.mode)
-        self.logger = Local.init_logger(platform=self.platform, mode=self.mode, log_level="INFO", log_to_console=True)
+        # Initialize loggers
+        self.logger = LoggerManager.get_logger(platform=self.platform, mode=self.mode)
+        self.aliyun_logr = LoggerManager.get_aliyun_logger(platform=self.platform, mode=self.mode)
         self.mq = MQ(topic_name=f"topic_crawler_etl_{env}")
 
         self.has_enough_videos = False
@@ -258,9 +258,10 @@ class UniversalCrawler:
         return True
 
     def _get_video_count_from_db(self) -> int:
-        """从数据库获取视频数量(示例方法,需根据实际业务实现)"""
-        # 实际实现中应该查询数据库
-        return 0  # 占位符
+        """从数据库获取视频数量"""
+        mysql = MysqlService(self.platform,self.mode,self.trace_id)
+        video_count = mysql.get_today_videos()
+        return video_count
 
     def _process_video(self, video_data: Dict) -> bool:
         """

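_get_video_count_from_db now delegates to MysqlService.get_today_videos(), which is not shown in this commit. Purely to illustrate the kind of per-day count such a method might run, here is a hypothetical standalone version using pymysql directly; the table and column names are invented, and the real service presumably goes through the project's MysqlHelper instead:

import pymysql

def get_today_videos(platform: str, mode: str) -> int:
    # Hypothetical query: count today's rows for this platform/strategy.
    conn = pymysql.connect(host="127.0.0.1", user="crawler", password="***",
                           database="crawler", charset="utf8mb4")
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT COUNT(*) FROM crawler_video "
                "WHERE platform = %s AND strategy = %s AND DATE(create_time) = CURDATE()",
                (platform, mode),
            )
            (count,) = cursor.fetchone()
            return int(count)
    finally:
        conn.close()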
+ 70 - 51
main.py

@@ -1,36 +1,35 @@
 import importlib
-import threading
-import traceback
 import json
 import time
-import uuid
+import traceback
+from concurrent.futures import ThreadPoolExecutor, as_completed, Future
+from typing import Dict
 
-from application.common import AliyunLogger, get_consumer, ack_message
-from application.common.log import Local
+from application.common.log.logger_manager import LoggerManager
+from utils.trace_utils import generate_trace_id
+from application.common import get_consumer, ack_message
 from crawler_worker.universal_crawler import UniversalCrawler
 from application.config import TopicGroup
 from application.functions.mysql_service import MysqlService
 from utils.config_loader import ConfigLoader
 
 
-def generate_trace_id():
-    return f"{uuid.uuid4().hex}{int(time.time() * 1000)}"
-
-
-def import_custom_class(class_path):
+def import_custom_class(class_path: str):
     """
-    Dynamically import a class from a module, e.g. crawler_worker.universal_crawler.UniversalCrawler
+    Dynamically import a crawler class, e.g. crawler_worker.custom.xx.Crawler
     """
     module_path, class_name = class_path.rsplit(".", 1)
-    print(module_path, class_name)
     module = importlib.import_module(module_path)
     return getattr(module, class_name)
 
 
 def handle_message(topic: str, mode: str):
+    """
+    Core logic for consuming messages of one topic in a single thread; keeps polling the MQ
+    """
     consumer = get_consumer(topic_name=topic, group_id=topic)
-    logger = AliyunLogger(platform=topic, mode=mode)
     platform_config = ConfigLoader().get_platform_config(topic)
+
     while True:
         try:
             messages = consumer.consume_message(wait_seconds=10, batch_size=1)
@@ -46,65 +45,85 @@ def handle_message(topic: str, mode: str):
                     platform = payload["platform"]
                     mode = payload["mode"]
                     task_id = payload["id"]
-                    mysql_service = MysqlService(task_id, mode, platform)
-                    logger.logging(
-                        1001,
-                        "Starting a crawl round",
-                        data=payload,
-                        trace_id=trace_id
-                    )
-                    Local.init_logger(platform, mode).info(f"[trace_id={trace_id}] Received task: {body}")
 
-                    # Load user_list and rule_dict
+                    # Initialize loggers
+                    logger = LoggerManager.get_logger(platform, mode)
+                    aliyun_logger = LoggerManager.get_aliyun_logger(platform, mode)
+                    logger.info(f"[trace_id={trace_id}] Received task: {body}")
+
+                    # Initialize config, user list and rules
+                    mysql_service = MysqlService(platform, mode, task_id)
                     user_list = mysql_service.get_user_list()
                     rule_dict = mysql_service.get_rule_dict()
-                    custom_class = platform_config.get("custom_class")  # custom crawler class
-                    try:
-                        if custom_class:
-                            CrawlerClass = import_custom_class(custom_class)
-                        else:
-                            CrawlerClass = UniversalCrawler
-
-                        crawler = CrawlerClass(
-                            platform_config=platform_config,  # 把整段配置传进去
-                            rule_dict=rule_dict,
-                            user_list=user_list,
-                            trace_id=trace_id
-                        )
-                        crawler.run()
-                    except Exception as e:
-                        print(f"[{topic}] crawler run failed: {e}")
-
-                    # ack after successful execution
+                    custom_class = platform_config.get("custom_class")
+
+                    # Instantiate the crawler class
+                    CrawlerClass = import_custom_class(custom_class) if custom_class else UniversalCrawler
+                    crawler = CrawlerClass(
+                        platform_config=platform_config,
+                        rule_dict=rule_dict,
+                        user_list=user_list,
+                        trace_id=trace_id
+                    )
+                    crawler.run()
+
+                    # Crawl succeeded; acknowledge the message
                     ack_message(mode, platform, message, consumer, trace_id=trace_id)
-                    logger.logging(code="1000", message="Task completed successfully; message acknowledged", trace_id=trace_id)
+                    aliyun_logger.logging(code="1000", message="Task completed successfully; message acknowledged", trace_id=trace_id)
 
                 except Exception as e:
-                    logger.logging(
+                    aliyun_logger.logging(
                         code="9001",
                         message=f"Failed to process message (not acked): {e}\n{traceback.format_exc()}",
                         trace_id=trace_id,
                         data=body,
                     )
-                    # Do not ack; wait for the next retry
         except Exception as err:
-            logger.logging(code="9002", message=f"Consume failed: {err}\n{traceback.format_exc()}")
+            logger = LoggerManager.get_logger(topic, mode)
+            logger.error(f"[{topic}] consume failed: {err}\n{traceback.format_exc()}")
+            time.sleep(5)  # avoid rapid restarts after a crash
+
+
+def monitor_and_restart(future: Future, topic: str, mode: str, pool: ThreadPoolExecutor, thread_map: Dict[str, Future]):
+    """
+    Crash-recovery monitor: automatically restarts a worker thread after it dies
+    """
+    try:
+        future.result()  # fetch the result to re-raise any exception
+    except Exception as e:
+        print(f"[monitor] thread {topic} exited abnormally: {e}; restarting in 5 seconds")
+        time.sleep(5)
+        # Resubmit the task
+        new_future = pool.submit(handle_message, topic, mode)
+        thread_map[topic] = new_future
+        # Register a new callback on the restarted future
+        new_future.add_done_callback(lambda f: monitor_and_restart(f, topic, mode, pool, thread_map))
 
 
 def main():
     topic_list = TopicGroup()
     print(f"Listening topics: {topic_list}")
 
-    threads = []
+    # Cap the number of worker threads at the number of topics
+    pool = ThreadPoolExecutor(max_workers=len(topic_list))
+    thread_map: Dict[str, Future] = {}
+
     for topic in topic_list:
         mode = topic.split("_")[1]
-        t = threading.Thread(target=handle_message, args=(topic, mode,))
-        t.start()
-        threads.append(t)
+        future = pool.submit(handle_message, topic, mode)
+        thread_map[topic] = future
+
+        # Attach a monitor: restart automatically if the task crashes
+        future.add_done_callback(lambda f, t=topic, m=mode: monitor_and_restart(f, t, m, pool, thread_map))
 
-    for t in threads:
-        t.join()
+    # Block the main thread so the process stays alive (the pool keeps the worker threads running)
+    try:
+        while True:
+            time.sleep(60)
+    except KeyboardInterrupt:
+        print("Received shutdown signal; closing the thread pool...")
+        pool.shutdown(wait=True)
 
 
 if __name__ == '__main__':
-    main()
+    main()

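A self-contained sketch of the crash-restart pattern main.py now relies on, reduced to a placeholder worker instead of the real MQ consumer so that it terminates; the topics, sleep times and single retry are illustrative only:

import time
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict

attempts: Dict[str, int] = {}

def worker(topic: str) -> None:
    # Stand-in for handle_message: crash on the first attempt, succeed on the second.
    attempts[topic] = attempts.get(topic, 0) + 1
    if attempts[topic] == 1:
        raise RuntimeError(f"simulated crash in {topic}")
    print(f"{topic} finished on attempt {attempts[topic]}")

def monitor_and_restart(future: Future, topic: str, pool: ThreadPoolExecutor,
                        thread_map: Dict[str, Future]) -> None:
    try:
        future.result()  # re-raises the worker's exception, if any
    except Exception as e:
        print(f"{topic} crashed: {e}; restarting in 1s")
        time.sleep(1)
        new_future = pool.submit(worker, topic)
        thread_map[topic] = new_future
        new_future.add_done_callback(lambda f: monitor_and_restart(f, topic, pool, thread_map))

pool = ThreadPoolExecutor(max_workers=2)
thread_map: Dict[str, Future] = {}
for topic in ("topic_a_recommend", "topic_b_recommend"):
    f = pool.submit(worker, topic)
    thread_map[topic] = f
    f.add_done_callback(lambda fut, t=topic: monitor_and_restart(fut, t, pool, thread_map))

time.sleep(3)            # give crashed workers time to restart and finish
pool.shutdown(wait=True)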
+ 5 - 0
utils/trace_utils.py

@@ -0,0 +1,5 @@
+import time
+import uuid
+
+def generate_trace_id():
+    return f"{uuid.uuid4().hex[:8]}-{int(time.time() * 1000)}"
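The new helper combines an 8-hex-character random prefix with a millisecond timestamp, so trace IDs stay short but still sort roughly by time. Usage (the printed value is only an example of the shape):

from utils.trace_utils import generate_trace_id

trace_id = generate_trace_id()
print(trace_id)  # e.g. 3f9c1a2b-1718000000000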