Browse source code

Switch the overall runtime to a process + coroutine model

zhangliang 1 week ago
parent
commit
fc8c21b105

+ 22 - 0
.env.prod

@@ -0,0 +1,22 @@
+# Environment configuration
+ENV=prod
+LOG_LEVEL=INFO
+ENABLE_ALIYUN_LOG=true
+
+
+# Database configuration
+DB_HOST="rm-bp1159bu17li9hi94.mysql.rds.aliyuncs.com"
+DB_PORT=3306
+DB_USER="crawler"
+DB_PASSWORD="crawler123456@"
+DB_NAME="piaoquan-crawler"
+DB_CHARSET="utf8mb4"
+
+# RocketMQ Alibaba Cloud configuration (HTTP protocol)
+
+ROCKETMQ_ENDPOINT="http://1894469520484605.mqrest.cn-qingdao-public.aliyuncs.com"
+ROCKETMQ_ACCESS_KEY_ID="LTAI4G7puhXtLyHzHQpD6H7A"
+ROCKETMQ_ACCESS_KEY_SECRET="nEbq3xWNQd1qLpdy2u71qFweHkZjSG"
+ROCKETMQ_INSTANCE_ID="MQ_INST_1894469520484605_BXhXuzkZ"
+ROCKETMQ_WAIT_SECONDS=10
+ROCKETMQ_BATCH=1
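
For reference, a minimal sketch of reading these values through the utils/env_loader helpers added in this commit (variable names as defined above):

from utils.env_loader import load_env, get_env, get_int_env

load_env("prod")                                  # reads .env.prod from the project root
endpoint = get_env("ROCKETMQ_ENDPOINT")
wait_seconds = get_int_env("ROCKETMQ_WAIT_SECONDS", 10)
batch = get_int_env("ROCKETMQ_BATCH", 1)          # falls back to 1 if unset or malformed
print(endpoint, wait_seconds, batch)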

+ 7 - 5
.env → .env.test

@@ -1,14 +1,16 @@
 # Environment configuration
-ENV=prod
+ENV=test
 LOG_LEVEL=INFO
 ENABLE_ALIYUN_LOG=true
 
+
 # Database configuration
-DB_HOST=127.0.0.1
+DB_HOST="rm-bp1159bu17li9hi94.mysql.rds.aliyuncs.com"
 DB_PORT=3306
-DB_USER=root
-DB_PASSWORD=123456
-DB_NAME=crawler_db
+DB_USER="crawler"
+DB_PASSWORD="crawler123456@"
+DB_NAME="piaoquan-crawler"
+DB_CHARSET="utf8mb4"
 
 # Message queue configuration (RabbitMQ / other)
 MQ_HOST=localhost

+ 0 - 0
scheduler/__init__.py → application/base/__init__.py


+ 135 - 0
application/base/async_mysql_client.py

@@ -0,0 +1,135 @@
+"""
+Module overview:
+    Async MySQL client wrapper (built on asyncmy):
+    - manages the connection pool automatically
+    - supports async with context management
+    - provides common helpers: fetch_all, fetch_one, execute, executemany
+    - reuses an internal singleton to avoid creating duplicate pools
+
+Intended for:
+    - high-concurrency async task systems
+    - general-purpose business database access
+"""
+
+import asyncmy
+from typing import List, Dict, Any, Optional
+
+
+class AsyncMySQLClient:
+    """
+    General-purpose async MySQL client built on asyncmy
+    """
+
+    # Class attribute holding the singleton instance
+    _instance: Optional["AsyncMySQLClient"] = None
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Singleton: ensure only one pool instance is created per process
+        """
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        user: str,
+        password: str,
+        db: str,
+        charset: str,
+        minsize: int = 1,
+        maxsize: int = 5,
+    ):
+        # __new__ returns the shared singleton, so skip re-initialization
+        # to avoid discarding an already-created pool
+        if getattr(self, "_db_settings", None):
+            return
+        self._db_settings = {
+            "host": host,
+            "port": port,
+            "user": user,
+            "password": password,
+            "db": db,
+            "autocommit": True,
+            "charset": charset,
+        }
+        self._minsize = minsize
+        self._maxsize = maxsize
+        self._pool: Optional[asyncmy.Pool] = None
+
+    async def __aenter__(self):
+        """支持 async with 上下文初始化连接池"""
+        await self.init_pool()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """支持 async with 自动关闭连接池"""
+        await self.close()
+
+    async def init_pool(self):
+        """
+        Initialize the connection pool (if not already initialized)
+        """
+        if not self._pool:
+            self._pool = await asyncmy.create_pool(
+                **self._db_settings,
+                minsize=self._minsize,
+                maxsize=self._maxsize,
+            )
+
+    async def close(self):
+        """
+        Close the connection pool
+        """
+        if self._pool:
+            self._pool.close()
+            await self._pool.wait_closed()
+            self._pool = None
+
+    async def fetch_all(self, sql: str, params: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
+        """
+        Query multiple rows; returns a list of dicts
+        """
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(sql, params or [])
+                rows = await cur.fetchall()
+                columns = [desc[0] for desc in cur.description]  # column names
+                # convert each row tuple to a dict
+                result = [dict(zip(columns, row)) for row in rows]
+                return result
+
+    async def fetch_one(self, sql: str, params: Optional[List[Any]] = None) -> Optional[Dict[str, Any]]:
+        """
+        Query a single row; returns a dict or None
+        """
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(sql, params or [])
+                row = await cur.fetchone()
+                if row is None:
+                    return None
+                columns = [desc[0] for desc in cur.description]
+                return dict(zip(columns, row))
+
+    async def execute(self, sql: str, params: Optional[List[Any]] = None) -> int:
+        """
+        Execute a single write (insert/update/delete)
+        :param sql: SQL statement
+        :param params: parameter list
+        :return: number of affected rows
+        """
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(sql, params or [])
+                return cur.rowcount
+
+    async def executemany(self, sql: str, params_list: List[List[Any]]) -> int:
+        """
+        Execute a batch write
+        :param sql: SQL statement
+        :param params_list: list of parameter groups
+        :return: total number of affected rows
+        """
+        async with self._pool.acquire() as conn:
+            async with conn.cursor() as cur:
+                await cur.executemany(sql, params_list)
+                return cur.rowcount

+ 84 - 0
application/base/async_rocketmq_consumer.py

@@ -0,0 +1,84 @@
+import asyncio
+import json
+from typing import Callable, List, Optional
+from mq_http_sdk.mq_client import MQClient
+from mq_http_sdk.mq_exception import MQExceptionBase
+from mq_http_sdk.consumer import Message
+
+
+class AsyncRocketMQConsumer:
+    """
+    Async consumer wrapper for Alibaba Cloud RocketMQ over HTTP
+    - native async consume model built on asyncio
+    - long-polling batch message pulls
+    - manual consume acknowledgement
+    """
+
+    def __init__(
+        self,
+        endpoint: str,
+        access_key_id: str,
+        access_key_secret: str,
+        instance_id: str,
+        topic_name: str,
+        group_id: str,
+        wait_seconds: int = 3,
+        batch: int = 1,
+    ):
+        self.endpoint = endpoint
+        self.access_key_id = access_key_id
+        self.access_key_secret = access_key_secret
+        self.instance_id = instance_id
+        self.topic_name = topic_name
+        self.group_id = group_id
+        self.wait_seconds = wait_seconds
+        self.batch = batch
+
+        # Initialize the MQ client
+        self.client = MQClient(self.endpoint, self.access_key_id, self.access_key_secret)
+        self.consumer = self.client.get_consumer(self.instance_id, self.topic_name, self.group_id)
+
+    async def receive_messages(self) -> List[Message]:
+        """
+        Pull messages asynchronously (wraps the synchronous SDK call in asyncio.to_thread)
+        """
+        try:
+            return await asyncio.to_thread(
+                self.consumer.receive_message,
+                self.batch,
+                self.wait_seconds,
+            )
+        except MQExceptionBase as e:
+            if hasattr(e, "type") and e.type == "MessageNotExist":
+                return []
+            else:
+                raise e
+
+    async def ack_message(self, receipt_handle: str) -> None:
+        """
+        Acknowledge that the message was consumed successfully
+        """
+        try:
+            await asyncio.to_thread(self.consumer.ack_message, [receipt_handle])
+        except Exception as e:
+            raise RuntimeError(f"Failed to acknowledge message: {e}")
+
+    async def run_forever(self, handler: Callable):
+        """
+        Start the consume loop: keep pulling messages and invoking the handler
+
+        :param handler: async function taking a single message: Message
+        """
+        print(f"[AsyncRocketMQConsumer] Start consuming Topic={self.topic_name} Group={self.group_id}")
+        while True:
+            try:
+                messages = await self.receive_messages()
+                for msg in messages:
+                    try:
+                        await handler(msg)
+                        await self.ack_message(msg.receipt_handle)
+                    except Exception as e:
+                        print(f"处理消息失败: {e}\n消息内容: {msg.message_body}")
+            except Exception as e:
+                print(f"拉取消息异常: {e}")
+                await asyncio.sleep(2)
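
A minimal sketch of plugging a handler into run_forever; the endpoint, keys, instance, topic, and group below are placeholders:

import asyncio
import json
from application.base.async_rocketmq_consumer import AsyncRocketMQConsumer

async def handle(msg):
    # msg is an mq_http_sdk Message; run_forever acks after handle returns
    print("got task:", json.loads(msg.message_body))

consumer = AsyncRocketMQConsumer(
    endpoint="http://example.mqrest.cn-qingdao-public.aliyuncs.com",
    access_key_id="<AK>", access_key_secret="<SK>",
    instance_id="MQ_INST_xxx", topic_name="topic_demo", group_id="GID_demo",
)
asyncio.run(consumer.run_forever(handle))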

+ 0 - 1
application/config/__init__.py

@@ -1,3 +1,2 @@
 from .ipconfig import ip_config
 from .mysql_config import env_dict
-from .topic_group_queue import TopicGroup

+ 1 - 1
application/config/common/log/local_log.py

@@ -3,7 +3,7 @@ import sys
 from datetime import date, timedelta, datetime
 from loguru import logger
 from pathlib import Path
-from utils.project_paths import log_dir
+from utils.path_utils import log_dir
 
 class Local:
     # 日期常量

+ 69 - 0
application/functions/async_mysql_service.py

@@ -0,0 +1,69 @@
+# application/functions/async_mysql_service.py
+import asyncio
+import json
+import os
+from typing import List, Optional, Dict, Any
+from application.base.async_mysql_client import AsyncMySQLClient
+from utils.env_loader import load_env
+
+
+class AsyncMysqlService:
+    """
+    Business-logic wrapper built on AsyncMySQLClient for async database access
+
+    Features:
+    - wraps business-specific SQL operations
+    - initializes its configuration from environment variables
+    - decoupled from crawler and task-processing logic
+    """
+
+    def __init__(self):
+        """
+        Read configuration from environment variables and build the underlying pooled client
+        """
+        db_config = {
+            "host": os.getenv("DB_HOST"),
+            "port": int(os.getenv("DB_PORT")),
+            "user": os.getenv("DB_USER"),
+            "password": os.getenv("DB_PASSWORD"),
+            "db": os.getenv("DB_NAME"),
+            "charset": os.getenv("DB_CHARSET")
+        }
+        self.client = AsyncMySQLClient(**db_config)
+
+    async def init(self):
+        """连接池初始化,在服务启动时调用一次"""
+        await self.client.init_pool()
+
+    async def get_user_list(self, task_id: int) -> List[Dict[str, Any]]:
+        sql = "SELECT uid, link, nick_name FROM crawler_user_v3 WHERE task_id = %s"
+        return await self.client.fetch_all(sql, [task_id])
+
+    async def get_rule_dict(self, rule_id: int) -> Optional[Dict[str, Any]]:
+        sql = "SELECT rule FROM crawler_task_v3 WHERE id = %s"
+        row = await self.client.fetch_one(sql, [rule_id])
+        if not row or "rule" not in row:
+            return None
+
+        try:
+            # merge the list[dict] into a single dict
+            return {k: v for item in json.loads(row["rule"]) for k, v in item.items()}
+        except json.JSONDecodeError as e:
+            print(f"[get_rule_dict] JSON 解析失败: {e}")
+            return None
+
+
+async def main():
+    mysql_service = AsyncMysqlService()
+    await mysql_service.init()
+    users = await mysql_service.get_user_list(18)
+    rules = await mysql_service.get_rule_dict(18)
+    print(users, rules)
+    await mysql_service.client.close()
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
+
+
+

+ 82 - 0
application/functions/rocketmq_consumer.py

@@ -0,0 +1,82 @@
+import asyncio
+import json
+import os
+from typing import Awaitable, Callable, List, Optional
+from mq_http_sdk.mq_client import MQClient
+from mq_http_sdk.mq_exception import MQExceptionBase
+from mq_http_sdk.consumer import Message
+
+from utils.env_loader import load_env, get_env, get_int_env  # shared env-loading helpers
+
+# Ensure environment variables are loaded
+load_env()
+
+
+class AsyncRocketMQConsumer:
+    """
+    Async consumer wrapper for Alibaba Cloud RocketMQ over HTTP
+    - reads its configuration from environment variables
+    - native async consume model built on asyncio
+    - manual consume acknowledgement
+    """
+
+    def __init__(
+        self,
+        topic_name: Optional[str],
+        group_id: Optional[str],
+        wait_seconds: Optional[int] = None,
+        batch: Optional[int] = None,
+    ):
+        # Read configuration from environment variables
+        self.endpoint = get_env("ROCKETMQ_ENDPOINT")
+        self.access_key_id = get_env("ROCKETMQ_ACCESS_KEY_ID")  # key names match .env.prod
+        self.access_key_secret = get_env("ROCKETMQ_ACCESS_KEY_SECRET")
+        self.instance_id = get_env("ROCKETMQ_INSTANCE_ID")
+        self.wait_seconds = wait_seconds or get_int_env("ROCKETMQ_WAIT_SECONDS", 10)
+        self.batch = batch or get_int_env("ROCKETMQ_BATCH", 1)
+        self.topic_name = topic_name
+        self.group_id = group_id
+
+        # Initialize the MQ client (run_forever logs topic_name/group_id below)
+        self.client = MQClient(self.endpoint, self.access_key_id, self.access_key_secret)
+        self.consumer = self.client.get_consumer(self.instance_id, topic_name, group_id)
+
+    async def receive_messages(self) -> List[Message]:
+        """异步封装消息拉取"""
+        try:
+            return await asyncio.to_thread(
+                self.consumer.receive_message,
+                self.batch,
+                self.wait_seconds,
+            )
+        except MQExceptionBase as e:
+            if getattr(e, "type", "") == "MessageNotExist":
+                return []
+            raise e
+
+    async def ack_message(self, receipt_handle: str) -> None:
+        """确认消费成功"""
+        try:
+            await asyncio.to_thread(self.consumer.ack_message, [receipt_handle])
+        except Exception as e:
+            raise RuntimeError(f"Failed to acknowledge message: {e}")
+
+    async def run_forever(self, handler: Callable[[Message], Awaitable[None]]):
+        """
+        Loop forever, pulling and handling messages; suited to development or small workloads
+
+        :param handler: async message handler, e.g. async def handler(msg: Message)
+        """
+        print(f"[AsyncRocketMQConsumer] Consuming: Topic={self.topic_name}, Group={self.group_id}")
+        while True:
+            try:
+                messages = await self.receive_messages()
+                for msg in messages:
+                    try:
+                        await handler(msg)
+                        await self.ack_message(msg.receipt_handle)
+                    except Exception as e:
+                        print(f"[处理失败] {e}\n消息: {msg.message_body}")
+            except Exception as e:
+                print(f"[拉取失败] {e}")
+                await asyncio.sleep(2)
+
+
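
Because this variant reads its connection settings from the environment, callers only name the topic and group; a sketch with placeholder names:

import asyncio
from application.functions.rocketmq_consumer import AsyncRocketMQConsumer

async def handle(msg):
    print(msg.message_body)

# endpoint, keys, and instance id come from .env.prod via load_env()
consumer = AsyncRocketMQConsumer(topic_name="topic_demo", group_id="GID_demo", batch=4)
asyncio.run(consumer.run_forever(handle))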

+ 1 - 1
application/spiders/base_spider.py

@@ -3,7 +3,7 @@ import aiohttp
 from abc import ABC
 from typing import List, Dict, Optional
 import time
-from application.config.common import LoggerManager
+from application.config.common.log.logger_manager import LoggerManager
 from utils.extractors import safe_extract
 from application.config.common import MQ
 from utils.config_loader import ConfigLoader  # 新增导入

+ 74 - 56
main.py

@@ -5,88 +5,101 @@ from multiprocessing import Process, cpu_count
 from typing import List, Dict
 import asyncio
 
-from application.config.common import LoggerManager
+from application.config.common.log.logger_manager import LoggerManager
 from utils.trace_utils import generate_trace_id
 from application.config.common import get_consumer, ack_message
-from application.functions import MysqlService
+from application.functions.async_mysql_service import AsyncMysqlService
 from application.spiders.spider_registry import get_spider_class, SPIDER_CLASS_MAP
-
+from application.functions.rocketmq_consumer import AsyncRocketMQConsumer
 
 # ------------------------------- Per-topic coroutine core -------------------------------
+
+# mysql service instance shared within each process (module-level global)
+mysql_service: AsyncMysqlService = None
+
+
 async def async_handle_topic(topic: str):
-    consumer = get_consumer(topic_name=topic, group_id=topic)
+    """
+    Consume logic for a single topic, run as a coroutine:
+    - consume messages from MQ;
+    - run the matching spider based on the message content;
+    - query configuration through the async database service;
+    - log and acknowledge the message.
+    """
     logger = LoggerManager.get_logger(topic, "worker")
     aliyun_logger = LoggerManager.get_aliyun_logger(topic, "worker")
 
-    while True:
+    # Create an independent consumer instance per topic
+    consumer = AsyncRocketMQConsumer(topic_name=topic, group_id=topic)
+
+    async def handle_single_message(message):
+        trace_id = generate_trace_id()
         try:
-            messages = consumer.consume_message(wait_seconds=10, batch_size=1)
-            if not messages:
-                await asyncio.sleep(1)
-                continue
-
-            for message in messages:
-                trace_id = generate_trace_id()
-                try:
-                    payload = json.loads(message.message_body)
-                    platform = payload["platform"]
-                    mode = payload["mode"]
-                    task_id = payload["id"]
-
-                    mysql_service = MysqlService(platform, mode, task_id)
-                    user_list = mysql_service.get_user_list()
-                    rule_dict = mysql_service.get_rule_dict()
-
-                    CrawlerClass = get_spider_class(topic)
-                    crawler = CrawlerClass(
-                        rule_dict=rule_dict,
-                        user_list=user_list,
-                        trace_id=trace_id
-                    )
-
-                    await crawler.run()
-
-                    ack_message(mode, platform, message, consumer, trace_id=trace_id)
-                    aliyun_logger.logging(code="1000", message="任务成功完成并确认消息", trace_id=trace_id)
-
-                except Exception as e:
-                    aliyun_logger.logging(
-                        code="9001",
-                        message=f"处理消息失败: {e}\n{traceback.format_exc()}",
-                        trace_id=trace_id,
-                        data=message.message_body,
-                    )
-        except Exception as err:
-            logger.error(f"[{topic}] 消费失败: {err}\n{traceback.format_exc()}")
-            await asyncio.sleep(5)
+            payload = json.loads(message.message_body)
+            platform = payload["platform"]
+            mode = payload["mode"]
+            task_id = payload["id"]
+
+            user_list = await mysql_service.get_user_list(task_id)
+            rule_dict = await mysql_service.get_rule_dict(task_id)
+
+            CrawlerClass = get_spider_class(topic)
+            crawler = CrawlerClass(
+                rule_dict=rule_dict,
+                user_list=user_list,
+                trace_id=trace_id
+            )
+            await crawler.run()
+
+            # ack only after run() completes successfully
+            await consumer.ack_message(message.receipt_handle)
+            aliyun_logger.logging(code="1000", message="任务成功完成并确认消息", trace_id=trace_id)
+
+        except Exception as e:
+            aliyun_logger.logging(
+                code="9001",
+                message=f"处理消息失败: {e}\n{traceback.format_exc()}",
+                trace_id=trace_id,
+                data=message.message_body,
+            )
+
+    # Start the consume loop
+    await consumer.run_forever(handle_single_message)
 
 
 async def run_all_topics(topics: List[str]):
     """
-       Start coroutine tasks for all topics in the current process.
+    Start coroutine listeners for every topic in this process.
+    Also initializes the global AsyncMysqlService instance.
     """
+    global mysql_service
+    mysql_service = AsyncMysqlService()
+    await mysql_service.init()  # initialize the connection pool
+
     tasks = [asyncio.create_task(async_handle_topic(topic)) for topic in topics]
     await asyncio.gather(*tasks)
 
 
 def handle_topic_group(topics: List[str]):
     """
-        Subprocess entry: run one event loop handling all topics in this group.
+    Subprocess entry point:
+    starts an asyncio event loop to process this group of topics.
     """
     asyncio.run(run_all_topics(topics))
 
 
 # ------------------------------- Main scheduling -------------------------------
+
 def split_topics(topics: List[str], num_groups: int) -> List[List[str]]:
     """
-        Split all topics evenly into num_groups groups (for multiple processes)
+    Split all topics evenly into num_groups groups for assignment to subprocesses
     """
     return [topics[i::num_groups] for i in range(num_groups)]
 
 
 def start_worker_process(group_id: int, topic_group: List[str], process_map: Dict[int, Process]):
     """
-    Start a new subprocess to handle a group of topics.
+    Start a subprocess to handle one group of topics.
     """
     p = Process(target=handle_topic_group, args=(topic_group,), name=f"Worker-{group_id}")
     p.start()
@@ -95,31 +108,36 @@ def start_worker_process(group_id: int, topic_group: List[str], process_map: Dic
 
 
 def main():
-    # Get the list of all registered spider topics
+    """
+    Main scheduling entry:
+    - collect all registered spider topics;
+    - group them by CPU core count;
+    - start subprocesses to run each group;
+    - monitor subprocess health and restart on crash.
+    """
     topic_list = list(SPIDER_CLASS_MAP.keys())
-    print(f"Listening topics: {topic_list}")
+    print(f"[Main] Listening topics: {topic_list}")
 
-    # Use the CPU core count to decide the number of processes
     num_cpus = cpu_count()
     topic_groups = split_topics(topic_list, num_cpus)
-    print(f"CPU cores: {num_cpus}, processes to start: {len(topic_groups)}")
+    print(f"[Main] CPU cores: {num_cpus}, worker processes to start: {len(topic_groups)}")
 
-    # Start subprocesses
     process_map: Dict[int, Process] = {}
+
     for group_id, topic_group in enumerate(topic_groups):
         start_worker_process(group_id, topic_group, process_map)
 
-    # Keep monitoring subprocess state; restart automatically on abnormal exit
+    # The main process keeps monitoring subprocess state
     try:
         while True:
             time.sleep(5)
             for group_id, p in list(process_map.items()):
                 if not p.is_alive():
-                    print(f"[监控] 进程 {p.name} PID={p.pid} 已崩溃,尝试重启中...")
+                    print(f"[监控] 进程 {p.name} PID={p.pid} 已崩溃,正在重启...")
                     time.sleep(2)
                     start_worker_process(group_id, topic_groups[group_id], process_map)
     except KeyboardInterrupt:
-        print("接收到退出信号,终止所有进程...")
+        print("[主进程] 接收到退出信号,终止所有进程...")
         for p in process_map.values():
             p.terminate()
         for p in process_map.values():
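
The round-robin slice in split_topics spreads topics evenly across groups; a quick illustration of the same expression:

topics = ["t1", "t2", "t3", "t4", "t5"]
groups = [topics[i::2] for i in range(2)]  # same as split_topics(topics, 2)
print(groups)  # [['t1', 't3', 't5'], ['t2', 't4']]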

+ 10 - 0
run.sh

@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Accept an environment argument; default to prod
+ENV=${1:-prod}
+
+export ENV
+
+echo "当前运行环境: $ENV"
+
+python3 main.py

+ 0 - 41
scheduler/scheduler_main.py

@@ -1,41 +0,0 @@
-# scheduler_main.py - crawler scheduler main program
-import asyncio
-import traceback
-import sys
-import os
-from application.config.common import AliyunLogger
-from application.spiders.universal_crawler import AsyncCrawler
-
-
-async def main():
-    """主函数"""
-    # 设置日志
-    logger = AliyunLogger(platform="system", mode="manager")
-
-    try:
-        # Read configuration from environment variables
-        config_topic = os.getenv("CONFIG_TOPIC", "crawler_config")
-        config_group = os.getenv("CONFIG_GROUP", "crawler_config_group")
-
-        # Create the crawler controller
-        controller = AsyncCrawler(
-            platform: str,
-            mode: str,
-        )
-        # Start the controller
-        await controller.run()
-
-        # Keep the main thread alive
-        while True:
-            await asyncio.sleep(60)
-
-    except Exception as e:
-        tb = traceback.format_exc()
-        message = f"主程序发生错误: {e}\n{tb}"
-        logger.logging(code="1006", message=message)
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    # Run the main event loop
-    asyncio.run(main())

+ 2 - 11
test/test1.py

@@ -1,15 +1,6 @@
 import asyncio
 import time
 
+topics = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
 
-def f1():
-    for i in range(100,10000):
-        time.sleep(10)
-        print(i)
-async def run():
-    print(1)
-    f1()
-    print(2)
-
-
-asyncio.run(run())
+print(topics[1::8])

+ 1 - 1
utils/config_loader.py

@@ -1,7 +1,7 @@
 import yaml
 import os
 from urllib.parse import urljoin
-from utils.project_paths import config_spiders_path
+from utils.path_utils import config_spiders_path
 
 
 class ConfigLoader:

+ 16 - 4
utils/env_loader.py

@@ -1,9 +1,16 @@
 import os
 from dotenv import load_dotenv
 
-# Auto-load the .env file
-dotenv_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
-load_dotenv(dotenv_path)
+def load_env(env: str = None):
+    """
+    Load the .env file matching the given environment name; defaults to .env.prod
+    """
+    if env is None:
+        env = os.getenv("ENV", "prod")  # default: prod
+
+    dotenv_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), f".env.{env}")
+    load_dotenv(dotenv_path)
+    print(f"加载环境配置文件: {dotenv_path}")
 
 def get_env(key: str, default: str = "") -> str:
     """获取环境变量"""
@@ -13,5 +20,10 @@ def get_int_env(key: str, default: int = 0) -> int:
     """获取整数类型环境变量"""
     try:
         return int(os.getenv(key, default))
-    except ValueError:
+    except (TypeError, ValueError):
         return default
+
+
+# Auto-load environment variables at import time
+load_env()
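
Since load_env() falls back to os.getenv("ENV", "prod"), the ENV exported by run.sh selects which file is read; a sketch (set ENV before the first import, because the module auto-loads on import):

import os
os.environ["ENV"] = "test"             # run.sh does the equivalent with `export ENV`

from utils.env_loader import get_env   # importing triggers load_env(), reading .env.test
print(get_env("DB_NAME"))              # "piaoquan-crawler" per .env.test above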
+