main.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from fastapi import FastAPI, HTTPException, Request
  2. from fastapi.responses import JSONResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from starlette.responses import Response, StreamingResponse
  5. from utils.params import DecodeContentParam, PatternContentParam, TopicSearchParam
  6. from dotenv import load_dotenv, find_dotenv
  7. from typing import Any, Dict, List, Optional
  8. from tasks.decode import begin_decode_task
  9. from tasks.detail import get_decode_detail_by_task_id
  10. from tasks.pattern import begin_pattern_task
  11. from tasks.topic_search import search_topics
  12. from loguru import logger
  13. import sys
  14. logger.add(sink=sys.stderr, level="ERROR", backtrace=True, diagnose=True)
  15. # 接口访问日志(与仅 ERROR 的 sink 并存,默认 stderr 仍会输出 INFO)
  16. _MAX_ACCESS_LOG_BYTES = 8192
  17. def _preview_bytes(data: bytes) -> str:
  18. if not data:
  19. return ""
  20. text = data.decode("utf-8", errors="replace")
  21. if len(text) > _MAX_ACCESS_LOG_BYTES:
  22. return text[:_MAX_ACCESS_LOG_BYTES] + f"...<truncated len={len(text)}>"
  23. return text
  24. # 响应消息映射
  25. RESPONSE_MSG_MAP = {
  26. 0: "success",
  27. 1002: "视频不存在",
  28. 2001: "解构/聚类任务创建失败",
  29. -1: "failed",
  30. 404: "任务不存在"
  31. }
  32. load_dotenv(find_dotenv(), override=False)
  33. app = FastAPI()
  34. app.add_middleware(
  35. CORSMiddleware,
  36. allow_origins=["*"],
  37. allow_credentials=True,
  38. allow_methods=["*"],
  39. allow_headers=["*"],
  40. )
  41. @app.middleware("http")
  42. async def api_access_log_middleware(request: Request, call_next):
  43. """记录每个接口的请求(路径、查询、body)与响应(状态码、body)。"""
  44. body_bytes = await request.body()
  45. async def receive() -> dict:
  46. return {"type": "http.request", "body": body_bytes, "more_body": False}
  47. wrapped = Request(request.scope, receive)
  48. req_body_preview = _preview_bytes(body_bytes) if body_bytes else ""
  49. logger.info(
  50. "api_request method={} path={} query={} body={}",
  51. request.method,
  52. request.url.path,
  53. str(request.query_params),
  54. req_body_preview if req_body_preview else "<empty>",
  55. )
  56. response = await call_next(wrapped)
  57. if isinstance(response, StreamingResponse):
  58. logger.info(
  59. "api_response status={} path={} body=<streaming skipped>",
  60. response.status_code,
  61. request.url.path,
  62. )
  63. return response
  64. resp_chunks: List[bytes] = []
  65. async for chunk in response.body_iterator:
  66. resp_chunks.append(chunk)
  67. resp_body = b"".join(resp_chunks)
  68. resp_preview = _preview_bytes(resp_body) if resp_body else ""
  69. logger.info(
  70. "api_response status={} path={} body={}",
  71. response.status_code,
  72. request.url.path,
  73. resp_preview if resp_preview else "<empty>",
  74. )
  75. return Response(
  76. content=resp_body,
  77. status_code=response.status_code,
  78. headers=dict(response.headers),
  79. media_type=response.media_type,
  80. )
  81. @app.exception_handler(HTTPException)
  82. async def http_exception_handler(request: Request, exc: HTTPException):
  83. """统一处理 HTTPException,保证返回结构与其它接口一致"""
  84. msg = RESPONSE_MSG_MAP.get(exc.status_code, "failed")
  85. content = {
  86. "code": exc.status_code,
  87. "msg": msg,
  88. "data": None,
  89. }
  90. # 将异常 detail 写入 reason,便于排查问题
  91. if exc.detail:
  92. content["reason"] = str(exc.detail)
  93. return JSONResponse(status_code=200, content=content)
  94. def _build_api_response(
  95. code: int,
  96. data: Any = None,
  97. reason: Optional[str] = None
  98. ) -> JSONResponse:
  99. """构建统一的API响应"""
  100. msg = RESPONSE_MSG_MAP.get(code, "failed")
  101. content = {
  102. "code": code,
  103. "msg": msg,
  104. "data": data
  105. }
  106. # 失败时添加 reason 字段
  107. if code != 0 and reason:
  108. content["reason"] = reason
  109. return JSONResponse(status_code=200, content=content)
  110. @app.post("/api/v1/content/tasks/decode")
  111. def decode_content(param: DecodeContentParam):
  112. """创建解构任务"""
  113. res = begin_decode_task(param)
  114. code = res.get("code", -1)
  115. task_id = res.get("task_id")
  116. reason = res.get("reason", "")
  117. return _build_api_response(
  118. code=code,
  119. data={"task_id": task_id} if task_id else None,
  120. reason=reason
  121. )
  122. @app.get("/api/v1/content/tasks/{taskId}")
  123. def get_task_detail(taskId: str):
  124. """获取任务详情"""
  125. result = get_decode_detail_by_task_id(taskId)
  126. # 任务不存在
  127. if result is None:
  128. return _build_api_response(code=404, data=None)
  129. # 直接返回结果(已经包含 code、msg、data、reason)
  130. return JSONResponse(status_code=200, content=result)
  131. @app.post("/api/v1/content/tasks/pattern")
  132. def pattern_content(param: PatternContentParam):
  133. """创建聚类任务"""
  134. res = begin_pattern_task(param)
  135. code = res.get("code", -1)
  136. task_id = res.get("task_id")
  137. reason = res.get("reason", "")
  138. return _build_api_response(
  139. code=code,
  140. data={"task_id": task_id} if task_id else None,
  141. reason=reason
  142. )
  143. @app.post("/api/v1/content/topics/search")
  144. def search_content_topics(param: TopicSearchParam):
  145. """视频选题检索:根据关键词在解构结果中匹配,返回匹配度最高的 top5"""
  146. results = search_topics(param)
  147. return _build_api_response(code=0, data=results)