1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
import importlib
from typing import Any, Callable, Dict, List, Optional

import yaml
from pydantic import BaseModel, BaseSettings, Field, validator
class DatabaseConfig(BaseSettings):
    """Database connection settings.

    Every field is populated from the environment variable given in its
    ``env`` argument, or from the ``.env`` file (decoded as UTF-8).
    Only ``charset`` has a default (``"utf8mb4"``); all other fields are
    required.
    """

    host: str = Field(default=..., env="DB_HOST")
    port: int = Field(default=..., env="DB_PORT")
    user: str = Field(default=..., env="DB_USER")
    password: str = Field(default=..., env="DB_PASSWORD")
    # Database (schema) name; read from DB_NAME.
    db: str = Field(default=..., env="DB_NAME")
    charset: str = Field(default="utf8mb4", env="DB_CHARSET")

    class Config:
        env_file_encoding = "utf-8"
        env_file = ".env"
class RocketMQConfig(BaseSettings):
    """Alibaba Cloud RocketMQ settings.

    Values come from the environment variables named in each ``env``
    argument, or from ``.env``. ``endpoint``, the access-key pair and
    ``instance_id`` are required; ``wait_seconds`` defaults to 10 and
    ``batch`` to 1.
    """

    endpoint: str = Field(..., env="ROCKETMQ_ENDPOINT")
    access_key_id: str = Field(..., env="ROCKETMQ_ACCESS_KEY_ID")
    access_key_secret: str = Field(..., env="ROCKETMQ_ACCESS_KEY_SECRET")
    instance_id: str = Field(..., env="ROCKETMQ_INSTANCE_ID")
    # NOTE(review): presumably the long-poll wait and the number of messages
    # fetched per pull, passed to the consumer -- confirm at the call site.
    wait_seconds: int = Field(10, env="ROCKETMQ_WAIT_SECONDS")
    batch: int = Field(1, env="ROCKETMQ_BATCH")

    class Config:
        env_file = ".env"
        # Decode .env explicitly as UTF-8 so every settings class reads the
        # shared file the same way; without this pydantic falls back to the
        # platform's locale encoding.
        env_file_encoding = "utf-8"
class GlobalConfig(BaseSettings):
    """Application-wide settings.

    Every field has a default and may be overridden by the environment
    variable named in its ``env`` argument, or by ``.env``.
    """

    env: str = Field("prod", env="ENV")
    base_url: str = Field("https://api.production.com", env="BASE_URL")
    # NOTE(review): the unit is not shown here -- presumably seconds; confirm.
    request_timeout: int = Field(30, env="REQUEST_TIMEOUT")
    log_level: str = Field("INFO", env="LOG_LEVEL")
    enable_aliyun_log: bool = Field(True, env="ENABLE_ALIYUN_LOG")

    class Config:
        env_file = ".env"
        # Decode .env explicitly as UTF-8 so every settings class reads the
        # shared file the same way; without this pydantic falls back to the
        # platform's locale encoding.
        env_file_encoding = "utf-8"
class ResponseParse(BaseModel):
    """Rules for extracting data from an API response.

    All three fields are required.
    """

    # Path to the next-page cursor in the response.
    next_cursor: str = Field(default=..., description="下一页游标路径")
    # Path to the main data payload in the response.
    data_path: str = Field(default=..., description="数据主体路径")
    # Field-mapping rules (both keys and values are strings).
    fields: Dict[str, str] = Field(default=..., description="字段映射规则")
class PlatformConfig(BaseModel):
    """Configuration for crawling a single platform."""

    platform: str
    mode: str
    path: str
    method: str = "POST"
    # pydantic copies field defaults per instance, so this shared literal is safe.
    request_body: Dict[str, Any] = {}
    loop_times: int = 1
    loop_interval: int = 0
    response_parse: ResponseParse
    # Configured as a dotted "package.module.func" string; the pre-validator
    # below swaps it for the imported function, so the field holds a callable.
    # (Typing it as str made pydantic reject the resolved function.)
    etl_hook: Optional[Callable[..., Any]] = None
    # NOTE(review): PostAction is neither defined nor imported in this file,
    # so this annotation raises NameError at import time -- define or import
    # PostAction before this class.
    post_actions: Optional[List[PostAction]] = None

    @validator("etl_hook", pre=True)
    def resolve_etl_hook(cls, v):
        """Resolve a dotted ``"module.func"`` path into the function it names.

        Falsy input (``None``, ``""``) yields ``None``. An object that is
        already callable is returned unchanged.

        Raises:
            ValueError: the value is not a dotted path, or the named attribute
                is not callable (pydantic reports this as a ValidationError).
            ImportError / AttributeError: the module or attribute cannot be
                found; pydantic does not wrap these, so they propagate as-is.
        """
        if not v:
            return None
        if callable(v):
            return v
        module_name, sep, func_name = str(v).rpartition(".")
        if not sep or not module_name or not func_name:
            raise ValueError(
                f"etl_hook must be a dotted path like 'package.module.func', got {v!r}"
            )
        hook = getattr(importlib.import_module(module_name), func_name)
        if not callable(hook):
            raise ValueError(f"etl_hook {v!r} does not refer to a callable")
        return hook
class SpiderConfig(BaseSettings):
    """Top-level configuration container.

    Holds the environment-driven settings (``default``, ``database``,
    ``mq``) together with the per-platform crawl definitions
    (``platforms``), which are normally loaded from YAML via :meth:`load`.

    The class was previously defined twice. The second definition replaced
    the first at import time, which made ``load()`` and ``platforms``
    unreachable, so both are merged here.
    """

    # Each nested settings object is built from its own env vars / .env when
    # not passed explicitly. Without a default factory pydantic treats these
    # fields as required and looks for SPIDER_DEFAULT / SPIDER_DATABASE /
    # SPIDER_MQ as JSON-encoded env vars instead.
    default: GlobalConfig = Field(default_factory=GlobalConfig)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    mq: RocketMQConfig = Field(default_factory=RocketMQConfig)
    platforms: Dict[str, PlatformConfig] = {}

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        env_prefix = "SPIDER_"  # env-var prefix for fields without an explicit env name
        case_sensitive = False  # env-var names are matched case-insensitively

    @classmethod
    def load(cls, path: str = "config/config.yaml") -> "SpiderConfig":
        """Build the configuration from a YAML file.

        Args:
            path: YAML file to read, decoded as UTF-8. Its top-level
                ``default`` and ``platforms`` sections are used when present.
                Missing sections, or an empty file, fall back to the field
                defaults.

        Returns:
            A validated ``SpiderConfig``. ``database`` and ``mq`` are still
            read from the environment.

        Raises:
            OSError: the file cannot be opened.
            pydantic.ValidationError: the YAML content does not match the models.
        """
        with open(path, encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            raw_config = yaml.safe_load(f) or {}
        # NOTE(review): GlobalConfig is a BaseSettings, which forbids unknown
        # keys by default, so a YAML `default` section may only use its fields.
        sections = {
            key: raw_config[key] for key in ("default", "platforms") if key in raw_config
        }
        return cls(**sections)
|