# settings.py (3.0 KB)
  import importlib
  from typing import Dict, Any, Optional, List

  from pydantic import BaseSettings, Field, validator, BaseModel
  class DatabaseConfig(BaseSettings):
      """Database connection settings.

      Each field is read from the environment variable given by its ``env``
      argument, with the ``.env`` file (decoded as UTF-8) as a fallback source.
      Only ``charset`` has a default; every other field is required.
      """

      host: str = Field(default=..., env="DB_HOST")
      port: int = Field(default=..., env="DB_PORT")
      user: str = Field(default=..., env="DB_USER")
      # NOTE(review): consider pydantic's SecretStr so the password is masked
      # in repr()/logs -- would change the attribute type for callers.
      password: str = Field(default=..., env="DB_PASSWORD")
      db: str = Field(default=..., env="DB_NAME")
      charset: str = Field(default="utf8mb4", env="DB_CHARSET")

      class Config:
          env_file_encoding = "utf-8"
          env_file = ".env"
  class RocketMQConfig(BaseSettings):
      """Alibaba Cloud RocketMQ connection settings.

      Values come from the ``ROCKETMQ_*`` environment variables named below,
      with the ``.env`` file as a fallback source. The endpoint, credentials
      and instance id are required; ``wait_seconds`` and ``batch`` have
      defaults.
      """

      endpoint: str = Field(..., env="ROCKETMQ_ENDPOINT")
      access_key_id: str = Field(..., env="ROCKETMQ_ACCESS_KEY_ID")
      access_key_secret: str = Field(..., env="ROCKETMQ_ACCESS_KEY_SECRET")
      instance_id: str = Field(..., env="ROCKETMQ_INSTANCE_ID")
      # presumably the consumer's poll wait time in seconds -- confirm with
      # the consumer code that reads it
      wait_seconds: int = Field(10, env="ROCKETMQ_WAIT_SECONDS")
      # presumably the number of messages fetched per poll -- confirm
      batch: int = Field(1, env="ROCKETMQ_BATCH")

      class Config:
          env_file = ".env"
          # Match DatabaseConfig: decode .env as UTF-8 explicitly. Without it
          # the platform default encoding is used, which can fail on
          # non-ASCII content (e.g. on Windows).
          env_file_encoding = "utf-8"
  class GlobalConfig(BaseSettings):
      """Global application settings.

      Every field has a default and may be overridden by the environment
      variable named in its ``env`` argument or by the ``.env`` file.
      """

      env: str = Field("prod", env="ENV")
      base_url: str = Field("https://api.production.com", env="BASE_URL")
      # presumably seconds -- confirm against the HTTP client that consumes it
      request_timeout: int = Field(30, env="REQUEST_TIMEOUT")
      log_level: str = Field("INFO", env="LOG_LEVEL")
      enable_aliyun_log: bool = Field(True, env="ENABLE_ALIYUN_LOG")

      class Config:
          env_file = ".env"
          # Match DatabaseConfig: decode .env as UTF-8 explicitly. Without it
          # the platform default encoding is used, which can fail on
          # non-ASCII content (e.g. on Windows).
          env_file_encoding = "utf-8"
  class ResponseParse(BaseModel):
      """Rules for extracting data from an API response.

      All three fields are required.
      """

      # path to the next-page cursor
      next_cursor: str = Field(default=..., description="下一页游标路径")
      # path to the main data payload
      data_path: str = Field(default=..., description="数据主体路径")
      # field-mapping rules
      fields: Dict[str, str] = Field(default=..., description="字段映射规则")
  class PlatformConfig(BaseModel):
      """Per-platform crawl configuration.

      ``etl_hook`` is supplied as a dotted path such as ``"pkg.module.func"``;
      validation replaces it with the imported object it names.
      """

      platform: str
      mode: str
      path: str
      method: str = "POST"
      # pydantic copies field defaults per instance, so a literal {} is safe.
      request_body: Dict[str, Any] = {}
      loop_times: int = 1
      loop_interval: int = 0
      response_parse: ResponseParse
      # Annotated as the accepted *input* (a dotted-path string). After
      # validation the attribute holds the resolved callable instead.
      etl_hook: Optional[str] = None
      # NOTE(review): ``PostAction`` is neither defined nor imported in this
      # file, so the bare name raised NameError when the module was imported.
      # It is quoted as a forward reference: the model now builds and works
      # while ``post_actions`` is omitted. Define/import ``PostAction`` and call
      # ``PlatformConfig.update_forward_refs()`` before passing post_actions.
      post_actions: Optional[List["PostAction"]] = None

      @validator("etl_hook")
      def resolve_etl_hook(cls, v):
          """Import and return the hook named by the dotted path ``v``.

          This runs *after* the ``str`` type check. As a ``pre=True`` validator,
          the returned function was fed into pydantic's ``str`` validation and
          rejected, so any non-empty ``etl_hook`` failed.

          Returns:
              ``None`` for an empty/None value, otherwise the imported callable.

          Raises:
              ValueError: for a malformed path, an unimportable module, a
                  missing attribute or a non-callable target. pydantic reports
                  it as a ``ValidationError`` on ``etl_hook``.
          """
          if not v:
              return None
          module_name, sep, func_name = v.rpartition(".")
          if not sep or not module_name or not func_name:
              raise ValueError(
                  f"etl_hook must be a dotted path 'module.func', got {v!r}"
              )
          try:
              module = importlib.import_module(module_name)
          except ImportError as exc:
              raise ValueError(
                  f"etl_hook: cannot import module {module_name!r}"
              ) from exc
          try:
              hook = getattr(module, func_name)
          except AttributeError as exc:
              raise ValueError(
                  f"etl_hook: module {module_name!r} has no attribute {func_name!r}"
              ) from exc
          if not callable(hook):
              raise ValueError(f"etl_hook: {v!r} is not callable")
          return hook
  class YamlSpiderConfig(BaseModel):
      """Spider configuration container loaded from a YAML file.

      This class was originally named ``SpiderConfig``. The module defines a
      second ``SpiderConfig`` (the environment-based settings container) later
      on, which rebound the name before any importer could see this class, so
      ``load()`` was unreachable. The rename makes it usable without affecting
      the other class.

      The YAML document is expected to be a mapping with a required
      ``default`` section and an optional ``platforms`` mapping of
      platform name -> ``PlatformConfig`` data.
      """

      default: dict = Field(...)  # global default settings
      platforms: Dict[str, PlatformConfig] = {}

      @classmethod
      def load(cls, path: str = "config/config.yaml"):
          """Build an instance from the YAML file at ``path``.

          Args:
              path: YAML file location. The default is the previously
                  hard-coded path, resolved against the current working
                  directory.

          Raises:
              OSError: if the file cannot be opened.
              ValueError: if the document's top level is not a mapping (e.g.
                  the file is empty).
              KeyError: if the ``default`` section is missing.
          """
          # PyYAML was referenced but never imported; only this loader needs it.
          import yaml

          with open(path, encoding="utf-8") as f:
              raw_config = yaml.safe_load(f)
          if not isinstance(raw_config, dict):
              raise ValueError(
                  f"{path}: expected a YAML mapping at top level, "
                  f"got {type(raw_config).__name__}"
              )
          return cls(
              default=raw_config["default"],
              # 'platforms' has a model default; tolerate a missing/null section.
              platforms=raw_config.get("platforms") or {},
          )
  class SpiderConfig(BaseSettings):
      """Top-level settings container that aggregates the sub-configurations.

      When a section is not passed explicitly, it defaults to a fresh instance
      of its settings class. That instance reads its own environment variables
      and ``.env`` entries (``DB_HOST``, ``ROCKETMQ_ENDPOINT``, ...). A section
      can still be supplied explicitly, or through a JSON-encoded
      ``SPIDER_<FIELD>`` environment variable (e.g. ``SPIDER_DATABASE``). If a
      defaulted section's required variables are missing, construction fails
      with that section's ``ValidationError``.
      """

      # Previously required with no default: SpiderConfig() failed unless
      # SPIDER_DEFAULT / SPIDER_DATABASE / SPIDER_MQ held JSON, even though
      # each section class declares its own environment variables.
      default: GlobalConfig = Field(default_factory=GlobalConfig)
      database: DatabaseConfig = Field(default_factory=DatabaseConfig)
      mq: RocketMQConfig = Field(default_factory=RocketMQConfig)

      class Config:
          env_file = ".env"
          # Match DatabaseConfig: decode .env as UTF-8 explicitly.
          env_file_encoding = "utf-8"
          env_prefix = "SPIDER_"  # prefix for this container's own env vars
          case_sensitive = False  # env var names are matched case-insensitively