feat: complete development

huangzhichao 1 month ago
parent commit
c98a8eeb86

+ 2 - 0
README.md

@@ -104,3 +104,5 @@ uvicorn app.main:app --reload --port 8000
 - Testing: add `pytest` and a `tests/` directory
 - Code quality: `ruff`/`black`/`mypy`, plus `pre-commit`
 - Containerization: write a `Dockerfile` and `docker-compose.yml`
+
+> http://rescdn.yishihui.com/temp/1757907756869_full_temp.jpg

+ 7 - 0
app/api/deps.py

@@ -1,10 +1,17 @@
 from fastapi import Depends
 
 from ..providers.speech_provider import SpeechProvider
+from ..providers.understand_image_provider import UnderstandImageProvider
 from ..services.speech_service import SpeechService
+from ..services.vl_service import VLService
 
 
 def get_speech_service() -> SpeechService:
     provider = SpeechProvider()
     return SpeechService(provider)
 
+
+def get_understand_image_service() -> VLService:
+    provider = UnderstandImageProvider()
+    return VLService(provider)
+

+ 9 - 3
app/api/routes.py

@@ -1,7 +1,8 @@
 from fastapi import APIRouter, Depends
-from .deps import get_speech_service
-from ..schemas.speech import TextToSpeechResponse, TextToSpeechRequest
+from .deps import get_speech_service, get_understand_image_service
+from ..schemas.base import DataResponse, TextToSpeechRequest, UnderstandImageRequest
 from ..services.speech_service import SpeechService
+from ..services.vl_service import VLService
 
 
 router = APIRouter()
@@ -11,7 +12,12 @@ router = APIRouter()
 def ping():
     return {"message": "pong"}
 
-@router.post('/llm/text-to-speech', response_model=TextToSpeechResponse, tags=["llm"])
+@router.post('/llm/text-to-speech', response_model=DataResponse, tags=["llm"])
 def text_to_speech(req: TextToSpeechRequest, service: SpeechService = Depends(get_speech_service)):
     return service.text_to_speech(req)
 
+
+@router.post('/llm/understand-image', response_model=DataResponse, tags=["understand-image"])
+def understand_image(req: UnderstandImageRequest, service: VLService = Depends(get_understand_image_service)):
+    return service.understand_image(req)
+
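With the server running (`uvicorn app.main:app --reload --port 8000`, as shown in the README), the new route can be exercised end to end. A minimal client sketch; the image URL and model name below are illustrative placeholders, not values fixed by this commit:

# Client sketch for POST /llm/understand-image; payload values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/llm/understand-image",
    json={"image_url": "https://example.com/ad.jpg", "model": "qwen-vl-plus"},
    timeout=60,
)
print(resp.json())  # DataResponse envelope: {"code": ..., "data": ..., "msg": ...}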

+ 2 - 2
app/providers/base.py

@@ -1,6 +1,6 @@
 from typing import List, Optional, Protocol
 
-from ..schemas.speech import ChatMessage, ChatResponse, TextToSpeechResponse
+from ..schemas.base import ChatMessage, ChatResponse, DataResponse
 
 
 class LLMProvider(Protocol):
@@ -20,5 +20,5 @@ class SpeechProvider(Protocol):
         text: str,
         *,
         model: Optional[str] = None,
-    ) -> TextToSpeechResponse:
+    ) -> DataResponse:
         ...

+ 11 - 11
app/providers/speech_provider.py

@@ -11,7 +11,7 @@ from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthes
 
 import requests
 
-from ..schemas.speech import TextToSpeechResponse, DataPayload
+from ..schemas.base import DataResponse, TextToSpeechPayload
 from ..core.config import get_settings
 from ..core.logger import get_logger
 
@@ -30,7 +30,7 @@ def _safe_filename(name: str) -> str:
 
 
 class SpeechProvider:
-    def text_to_speech(self, volume: int, pitch: float, rate: float, filename: str, text: str, *, model: Optional[str] = None, format: Optional[str] = None) -> TextToSpeechResponse:
+    def text_to_speech(self, volume: int, pitch: float, rate: float, filename: str, text: str, *, model: Optional[str] = None, format: Optional[str] = None) -> DataResponse:
         # Resolve output path under project-root/temp and ensure directory exists
         project_root = Path(__file__).resolve().parents[2]  # repo root
         audio_dir = project_root / "temp"
@@ -38,21 +38,21 @@ class SpeechProvider:
             audio_dir.mkdir(parents=True, exist_ok=True)
         except Exception as e:
             logger.error("Failed to create audio directory %s: %s", audio_dir, e, exc_info=True)
-            return TextToSpeechResponse(code=1, data=None, msg=f"create audio dir failed: {e}")
+            return DataResponse(code=1, data=None, msg=f"create audio dir failed: {e}")
 
         # Basic input validation
         if not isinstance(text, str) or not text.strip():
             msg = "text is required"
             logger.error(msg)
-            return TextToSpeechResponse(code=1, data=None, msg=msg)
+            return DataResponse(code=1, data=None, msg=msg)
         if not isinstance(filename, str) or not filename.strip():
             msg = "filename is required"
             logger.error(msg)
-            return TextToSpeechResponse(code=1, data=None, msg=msg)
+            return DataResponse(code=1, data=None, msg=msg)
         if not dashscope.api_key:
             msg = "DASHSCOPE_API_KEY is missing"
             logger.error(msg)
-            return TextToSpeechResponse(code=1, data=None, msg=msg)
+            return DataResponse(code=1, data=None, msg=msg)
 
         # determine desired output format (default mp3 for smaller size)
         audio_format = (format or 'mp3').lower()
@@ -97,7 +97,7 @@ class SpeechProvider:
                 callback.on_complete()
             except Exception:
                 pass
-            return TextToSpeechResponse(code=1, data=None, msg=str(e))
+            return DataResponse(code=1, data=None, msg=str(e))
 
         if callback.had_error:
             # TTS reported an error via callback
@@ -108,7 +108,7 @@ class SpeechProvider:
             else:
                 msg = base_msg
             logger.error("TTS callback error: %s", msg)
-            return TextToSpeechResponse(code=1, data=None, msg=msg)
+            return DataResponse(code=1, data=None, msg=msg)
 
         # After synthesis completes, upload the file to OSS
         try:
@@ -118,15 +118,15 @@ class SpeechProvider:
                 Path(out_path).unlink(missing_ok=True)
             except Exception as del_err:
                 logger.warning("Failed to delete local audio %s: %s", out_path, del_err)
-            return TextToSpeechResponse(
+            return DataResponse(
                 code=0,
-                data=DataPayload(audio_url=url),
+                data=TextToSpeechPayload(audio_url=url),
                 msg='success'
             )
         except Exception as e:
             # Keep local file for inspection; report error message
             logger.error("Upload failed", exc_info=True)
-            return TextToSpeechResponse(code=1, data=None, msg=str(e))
+            return DataResponse(code=1, data=None, msg=str(e))
 
 
 class Callback(ResultCallback):

+ 117 - 0
app/providers/understand_image_provider.py

@@ -0,0 +1,117 @@
+from openai import OpenAI
+from ..schemas.base import DataResponse
+from ..core.config import get_settings
+from ..core.logger import get_logger
+from openai.types.chat import ChatCompletionToolParam
+import json
+
+settings = get_settings()
+logger = get_logger("understand_image_provider")
+
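+# System prompt (in Chinese): casts the model as a senior ad copywriter that
+# produces one short, compliant slogan for readers aged 50+, following a fixed
+# structure (action + low-barrier offer + core benefit + urgency) and answering
+# only through a forced function call.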
+SYSTEM_PROMPT = """
+<SystemPrompt>
+    <角色>
+        你是一名资深广告文案专家。你的任务是根据输入的一张广告图片中的文字内容,生成一句简洁有力的广告文案。
+    </角色>
+    <受众>
+        目标用户:50岁以上中老年人。语言需亲切、直白、易理解,避免专业术语与复杂长句。
+    </受众>
+
+    <结构公式>
+        [行动指令] + [低门槛/优惠承诺] + [核心价值/具体收益] + [紧迫感/稀缺性提醒]
+    </结构公式>
+
+    <约束>
+        1. 文案必须准确传达广告图片中的产品/服务信息,不得杜撰不存在的内容。
+        2. 加入紧迫感或稀缺性(如“限时”“名额有限”“马上领取”等),但不得虚构或夸大事实。
+        3. 避免医疗或功效的绝对化/保证性用语(如“治愈”“根治”“无副作用”“永久有效”)。
+        4. 不得包含违法、虚假、低俗、敏感、歧视性内容,不引导危险行为,不传播迷信。
+        5. 涉及健康/养生场景时,表述应为辅助/改善/建议性质,不承诺疗效;避免“祖传秘方”等违规表述。
+        6. 仅输出一句中文广告文案,简短醒目,适合作为宣传主标题。
+        7. 文案必须注意标点与短句分隔:动作、优惠承诺、核心收益之间用逗号分隔;紧迫感/稀缺性提醒用分号与前半部分隔开,避免把多个短语连写在一起,字数50字以内。
+    </约束>
+
+    <示例 few-shot="true">
+        长按二维码,0元入群,领取中医调理养生秘方;名额有限,赶快行动吧
+    </示例>
+
+    <输出要求>
+        始终通过工具调用(function calling)输出,参数仅包含生成的一句文案。
+    </输出要求>
+</SystemPrompt>
+"""
+
+tools: list[ChatCompletionToolParam] = [
+    {
+        "type": "function",
+        "function": {
+            "name": "generate_ocr_text",
+            "description": "生成一句适合中老年用户的广告文案(遵循结构公式与约束)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "ocr_text": {
+                        "type": "string",
+                        "description": "最终的一句广告文案(中文,简短醒目,合规)"
+                    }
+                },
+                "required": ["ocr_text"],
+                "additionalProperties": False
+            }
+        }
+    }
+]
+
+class UnderstandImageProvider:
+    print("UnderstandImageProvider called")
+    def understand_image(self, image_url: str, *, model: str) -> DataResponse:
+
+
+        client = OpenAI(
+            api_key = settings.dashscope_api_key or "",
+            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+        )
+        if not client:
+            logger.error("OpenAI client is not initialized.")
+
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {
+                    "role": "user", 
+                    "content": [{ "type": "image_url", "image_url": { "url": image_url } }],
+                },
+            ],
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "generate_ocr_text"}
+            },
+            temperature=1.3
+        )
+
+        msg = completion.choices[0].message
+        # Safely parse tool call arguments (if any)
+        ocr_text = ""
+        try:
+            tool_calls = getattr(msg, "tool_calls", None) or []
+            if tool_calls:
+                call = tool_calls[0]
+                arg_str = getattr(getattr(call, "function", None), "arguments", None)
+                if isinstance(arg_str, str) and arg_str.strip():
+                    args = json.loads(arg_str)
+                    if isinstance(args, dict):
+                        ocr_text = str(args.get("ocr_text", "")).strip()
+        except Exception as e:
+            logger.error("parse tool call failed: %s", e, exc_info=True)
+
+        # Fallback: if no tool-calls returned, try to read text content
+        content = getattr(msg, "content", None)
+        if not ocr_text and isinstance(content, str):
+            ocr_text = content.strip()
+
+        print("✅ OCR_TEXT:\n", ocr_text)
+
+        # code=0 marks success, matching the convention used by SpeechProvider.
+        return DataResponse(code=0, data=ocr_text, msg="Image understood successfully")
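Because tool_choice pins the call to generate_ocr_text, the reply normally arrives as a JSON arguments string rather than free-form content; the parsing above covers both paths. A sketch of the expected shape, reusing the prompt's own few-shot example as the illustrative value:

# Illustrative shape of the forced tool call consumed above:
#   message.tool_calls[0].function.name      == "generate_ocr_text"
#   message.tool_calls[0].function.arguments == '{"ocr_text": "..."}'
import json

args = json.loads('{"ocr_text": "长按二维码,0元入群,领取中医调理养生秘方;名额有限,赶快行动吧"}')
print(args["ocr_text"])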

+ 7 - 3
app/schemas/speech.py → app/schemas/base.py

@@ -28,12 +28,12 @@ class ChatResponse(BaseModel):
     model: Optional[str] = None
     usage: Optional[Usage] = None
 
-class DataPayload(BaseModel):
+class TextToSpeechPayload(BaseModel):
     audio_url: str
 
-class TextToSpeechResponse(BaseModel):
+class DataResponse(BaseModel):
     code: int
-    data: Optional[DataPayload] = None
+    data: Optional[object] = None
     msg: Optional[str] = None
 
 class TextToSpeechRequest(BaseModel):
@@ -43,3 +43,7 @@ class TextToSpeechRequest(BaseModel):
     filename: str
     text: str
     model: str
+
+class UnderstandImageRequest(BaseModel):
+    image_url: str
+    model: str
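With the file renamed to base.py, DataResponse becomes a shared envelope: data may hold the typed TTS payload, the plain string produced by image understanding, or None on error. A small sketch with illustrative values:

# Shared response envelope after the rename (all values illustrative):
from app.schemas.base import DataResponse, TextToSpeechPayload

tts_ok = DataResponse(code=0, data=TextToSpeechPayload(audio_url="https://example.com/a.mp3"), msg="success")
vl_ok = DataResponse(code=0, data="一句广告文案", msg="success")
err = DataResponse(code=1, data=None, msg="text is required")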

+ 2 - 2
app/services/speech_service.py

@@ -1,4 +1,4 @@
-from ..schemas.speech import TextToSpeechRequest, TextToSpeechResponse
+from ..schemas.base import TextToSpeechRequest, DataResponse
 from ..providers.speech_provider import SpeechProvider
 
 
@@ -6,7 +6,7 @@ class SpeechService:
     def __init__(self, provider: SpeechProvider) -> None:
         self._provider = provider
 
-    def text_to_speech(self, req: TextToSpeechRequest) -> TextToSpeechResponse:
+    def text_to_speech(self, req: TextToSpeechRequest) -> DataResponse:
         return self._provider.text_to_speech(
             req.volume,
             req.pitch,

+ 12 - 0
app/services/vl_service.py

@@ -0,0 +1,12 @@
+from ..schemas.base import DataResponse, UnderstandImageRequest
+from ..providers.understand_image_provider import UnderstandImageProvider
+
+class VLService:
+    def __init__(self, provider: UnderstandImageProvider) -> None:
+        self._provider = provider
+
+    def understand_image(self, req: UnderstandImageRequest) -> DataResponse:
+        return self._provider.understand_image(
+            req.image_url,
+            model=req.model,
+        )

+ 2 - 0
requirements.txt

@@ -4,3 +4,5 @@ uvicorn[standard]>=0.30.0
 dashscope>=0.1.0
 python-dotenv>=1.0.1
 requests>=2.31.0
+openai==1.107.2
+httpx[socks]