|
|
@@ -1,9 +1,10 @@
|
|
|
import os
|
|
|
import logging
|
|
|
-from fastapi import FastAPI, HTTPException
|
|
|
+from fastapi import FastAPI, HTTPException, Request
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from contextlib import asynccontextmanager
|
|
|
from dotenv import load_dotenv
|
|
|
+import time as time_module
|
|
|
|
|
|
from ..models.schemas import QuestionRequest, QueryResponse, HealthResponse
|
|
|
from ..agent.query_agent import QueryGenerationAgent
|
|
|
@@ -104,15 +105,49 @@ app = FastAPI(
|
|
|
)
|
|
|
|
|
|
# 添加CORS中间件
|
|
|
+# 从环境变量读取允许的来源,默认允许所有
|
|
|
+allowed_origins = os.getenv("CORS_ORIGINS", "*")
|
|
|
+if allowed_origins == "*":
|
|
|
+ origins = ["*"]
|
|
|
+else:
|
|
|
+ # 支持多个来源,用逗号分隔
|
|
|
+ origins = [origin.strip() for origin in allowed_origins.split(",")]
|
|
|
+
|
|
|
+logger.info(f"CORS配置 - 允许的来源: {origins}")
|
|
|
+
|
|
|
app.add_middleware(
|
|
|
CORSMiddleware,
|
|
|
- allow_origins=["*"],
|
|
|
+ allow_origins=origins,
|
|
|
allow_credentials=True,
|
|
|
allow_methods=["*"],
|
|
|
allow_headers=["*"],
|
|
|
+ expose_headers=["*"], # 允许浏览器访问所有响应头
|
|
|
)
|
|
|
|
|
|
|
|
|
+# 添加请求日志中间件
|
|
|
+@app.middleware("http")
|
|
|
+async def log_requests(request: Request, call_next):
|
|
|
+ """记录所有HTTP请求"""
|
|
|
+ start_time = time_module.time()
|
|
|
+
|
|
|
+ # 记录请求信息
|
|
|
+ logger.info(f"收到请求: {request.method} {request.url.path} - 来源: {request.headers.get('origin', 'N/A')} - User-Agent: {request.headers.get('user-agent', 'N/A')[:50]}")
|
|
|
+
|
|
|
+ # 如果是OPTIONS预检请求,特别标注
|
|
|
+ if request.method == "OPTIONS":
|
|
|
+ logger.info(f"OPTIONS预检请求 - Path: {request.url.path}")
|
|
|
+
|
|
|
+ # 处理请求
|
|
|
+ response = await call_next(request)
|
|
|
+
|
|
|
+ # 记录响应信息
|
|
|
+ process_time = time_module.time() - start_time
|
|
|
+ logger.info(f"响应: {request.method} {request.url.path} - 状态码: {response.status_code} - 耗时: {process_time:.3f}s")
|
|
|
+
|
|
|
+ return response
|
|
|
+
|
|
|
+
|
|
|
@app.get("/health", response_model=HealthResponse)
|
|
|
async def health_check():
|
|
|
"""健康检查接口"""
|