Browse Source

Merge branch 'feature/opti-generate-bgm-prompt' of Web/pq-web-ai into master

huangzhichao 1 month ago
parent
commit
e9e5681989

+ 6 - 0
app/api/deps.py

@@ -4,6 +4,8 @@ from ..providers.speech_provider import SpeechProvider
 from ..providers.understand_image_provider import UnderstandImageProvider    
 from ..services.speech_service import SpeechService
 from ..services.vl_service import VLService
+from ..services.evaluation_service import EvaluationService
+from ..services.evaluation_service import EvaluationProvider
 
 
 def get_speech_service() -> SpeechService:
@@ -15,3 +17,7 @@ def get_understand_image_service() -> VLService:
     provider = UnderstandImageProvider()
     return VLService(provider)
 
+def get_copywriting_evaluation_service() -> EvaluationService:
+    provider = EvaluationProvider()
+    return EvaluationService(provider)
+

+ 8 - 5
app/api/routes.py

@@ -1,13 +1,12 @@
 from fastapi import APIRouter, Depends
-from .deps import get_speech_service, get_understand_image_service
-from ..schemas.base import DataResponse, TextToSpeechRequest, UnderstandImageRequest
+from .deps import get_speech_service, get_understand_image_service, get_copywriting_evaluation_service
+from ..schemas.base import DataResponse, TextToSpeechRequest, UnderstandImageRequest, CopywritingEvaluationRequest
 from ..services.speech_service import SpeechService
 from ..services.vl_service import VLService
-
+from ..services.evaluation_service import EvaluationService
 
 router = APIRouter()
 
-
 @router.get("/ping", tags=["default"])
 def ping():
     return {"message": "pong"}
@@ -17,7 +16,11 @@ def text_to_speech(req: TextToSpeechRequest, service: SpeechService = Depends(ge
     return service.text_to_speech(req)
 
 
-@router.post('/llm/understand-image', response_model=DataResponse, tags=["understand-image"])
+@router.post('/llm/understand-image', response_model=DataResponse, tags=["llm"])
 def understand_image(req: UnderstandImageRequest, service: VLService = Depends(get_understand_image_service)):
     return service.understand_image(req)
 
+@router.post('/llm/copywriting-evaluation', response_model=DataResponse, tags=["llm"])
+def copywriting_evaluation(req: CopywritingEvaluationRequest, service: EvaluationService = Depends(get_copywriting_evaluation_service)):
+    return service.copywriting_evaluation(req)
+

+ 128 - 0
app/providers/evaluation_provider.py

@@ -0,0 +1,128 @@
+from openai import OpenAI
+from ..schemas.base import DataResponse, CopywritingEvaluationPayload
+from ..core.config import get_settings
+from ..core.logger import get_logger
+from openai.types.chat import ChatCompletionToolParam
+import json
+
+settings = get_settings()
+logger = get_logger("evaluation_provider")
+
+SYSTEM_PROMPT = """
+<SystemPrompt>
+    <角色>
+        你是一名广告文案质检专家。你的任务是:根据输入的广告图片文字(OCR结果)和生成的广告文案,仅从“格式”和“内容一致性”两个维度判断文案是否合格。
+    </角色>
+
+    <校验标准>
+        1. 格式要求:
+            - 文案必须以「[行动指令],[低门槛/优惠承诺]」连续开头。
+            - 行动指令示例:“长按二维码”“扫码二维码”“识别二维码”“长按识别”。
+            - 低门槛/优惠承诺示例:“0元入群”“免费进群”“0元加入”“限时免费加入”。
+            - 开头后应包含核心价值(如领取/获取/享受 + 方案/建议/课程/资料等)。
+            - 结尾应包含紧迫感/稀缺性提醒(如“名额有限”“限时”“赶快行动”)。
+            - 标点要求:动作、优惠、收益之间用逗号;紧迫提醒前用分号。
+            - 全句 ≤ 50 字。
+        2. 内容一致性要求:
+            - 文案内容必须与广告图片文字(OCR结果)一致。
+            - 优惠、产品/服务、动作入口等必须能在图片中找到对应信息。
+            - 不得凭空捏造图片中没有的要素。
+    </校验标准>
+
+    <判定逻辑>
+        - 符合以上两条 → pass=true,reason=""。
+        - 若仅轻微偏差(如分号缺失、字数略超) → pass=true,reason="建议优化:…"。
+        - 若明显不符合格式或文案与图片内容不一致 → pass=false,reason=简要说明。
+    </判定逻辑>
+
+    <输出要求>
+        始终调用函数 check_ad_copy,输出格式如下:
+        {
+          "pass": true/false,
+          "reason": "若不通过写原因;若通过则为空字符串或给出优化建议"
+        }
+    </输出要求>
+</SystemPrompt>
+"""
+
+tools: list[ChatCompletionToolParam] = [
+  {
+    "type": "function",
+    "function": {
+      "name": "check_ad_copy",
+      "description": "校验广告文案格式与内容是否合格",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "pass": {
+            "type": "boolean",
+            "description": "文案是否合格:true 表示通过,false 表示不通过"
+          },
+          "reason": {
+            "type": "string",
+            "description": "若不通过,说明原因;若通过则为空字符串"
+          }
+        },
+        "required": ["pass", "reason"],
+        "additionalProperties": False
+      }
+    }
+  }
+]
+
+class EvaluationProvider:
+  print("EvaluationProvider called")
+
+  def copywriting_evaluation(self, image_url: str, text: str, model: str) -> DataResponse:
+        client = OpenAI(
+            api_key = settings.dashscope_api_key or "",
+            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+        )
+
+        if not client:
+            logger.error("OpenAI client is not initialized.")
+            return DataResponse(code=1, data=None, msg=f"OpenAI client is not initialized")
+        
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": image_url}},
+                        {"type": "text", "text": text}
+                    ],
+                },
+            ],
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "check_ad_copy"}
+            },
+            temperature=0.3
+        )
+
+        msg = completion.choices[0].message
+        print(msg)
+        # Safely parse tool call arguments (if any)
+        content = True
+        reason = ""
+        try:
+            tool_calls = getattr(msg, "tool_calls", None) or []
+            if tool_calls:
+                call = tool_calls[0]
+                arg_str = getattr(getattr(call, "function", None), "arguments", None)
+                if isinstance(arg_str, str) and arg_str.strip():
+                    args = json.loads(arg_str)
+                    if isinstance(args, dict):
+                        content = bool(args.get("pass", True))
+                        reason = str(args.get("reason", "")).strip()
+        except Exception as e:
+            logger.error("parse tool call failed: %s", e, exc_info=True)
+            return DataResponse(code=1, data=None, msg=f"parse tool call failed: {e}")
+
+        print("✅ PASS:\n", content)
+        print("✅ REASON:\n", reason)
+
+        return DataResponse(code=0, data=CopywritingEvaluationPayload(content=content, reason=reason), msg="success")

+ 9 - 0
app/schemas/base.py

@@ -31,6 +31,10 @@ class ChatResponse(BaseModel):
 class TextToSpeechPayload(BaseModel):
     audio_url: str
 
+class CopywritingEvaluationPayload(BaseModel):
+    content: bool
+    reason: str
+
 class DataResponse(BaseModel):
     code: int
     data: object
@@ -46,4 +50,9 @@ class TextToSpeechRequest(BaseModel):
 
 class UnderstandImageRequest(BaseModel):
     image_url: str
+    model: str
+
+class CopywritingEvaluationRequest(BaseModel):
+    image_url: str
+    text: str
     model: str

+ 13 - 0
app/services/evaluation_service.py

@@ -0,0 +1,13 @@
+from ..schemas.base import DataResponse, CopywritingEvaluationRequest
+from ..providers.evaluation_provider import EvaluationProvider
+
+class EvaluationService:
+    def __init__(self, provider: EvaluationProvider) -> None:
+        self._provider = provider
+
+    def copywriting_evaluation(self, req: CopywritingEvaluationRequest) -> DataResponse:
+        return self._provider.copywriting_evaluation(
+            req.image_url,
+            req.text,
+            req.model
+        )