huangzhichao 2 недель назад
Родитель
Сommit
1b967a043a
6 измененных файлов с 290 добавлено и 11 удалено
  1. 2 0
      .gitignore
  2. 5 2
      app/api/routes.py
  3. 252 5
      app/providers/understand_image_provider.py
  4. 14 1
      app/schemas/base.py
  5. 8 2
      app/services/vl_service.py
  6. 9 1
      start.sh

+ 2 - 0
.gitignore

@@ -21,3 +21,5 @@ temp/
 # logs
 logs/
 */log
+
+note

+ 5 - 2
app/api/routes.py

@@ -1,6 +1,6 @@
 from fastapi import APIRouter, Depends
 from .deps import get_speech_service, get_understand_image_service, get_copywriting_evaluation_service
-from ..schemas.base import DataResponse, TextToSpeechRequest, UnderstandImageRequest, CopywritingEvaluationRequest
+from ..schemas.base import DataResponse, TextToSpeechRequest, UnderstandImageRequest, CopywritingEvaluationRequest, BusinessLicenseExtractRequest
 from ..services.speech_service import SpeechService
 from ..services.vl_service import VLService
 from ..services.evaluation_service import EvaluationService
@@ -20,7 +20,10 @@ def text_to_speech(req: TextToSpeechRequest, service: SpeechService = Depends(ge
 def understand_image(req: UnderstandImageRequest, service: VLService = Depends(get_understand_image_service)):
     return service.understand_image(req)
 
+@router.post('/llm/extract-business-license', response_model=DataResponse, tags=["llm"])
+def extract_business_license(req: BusinessLicenseExtractRequest, service: VLService = Depends(get_understand_image_service)):
+    return service.extract_business_license(req)
+
 @router.post('/llm/copywriting-evaluation', response_model=DataResponse, tags=["llm"])
 def copywriting_evaluation(req: CopywritingEvaluationRequest, service: EvaluationService = Depends(get_copywriting_evaluation_service)):
     return service.copywriting_evaluation(req)
-

+ 252 - 5
app/providers/understand_image_provider.py

@@ -1,5 +1,5 @@
 from openai import OpenAI
-from ..schemas.base import DataResponse
+from ..schemas.base import DataResponse, BusinessLicensePayload
 from ..core.config import get_settings
 from ..core.logger import get_logger
 from openai.types.chat import ChatCompletionToolParam
@@ -7,6 +7,12 @@ import json
 
 settings = get_settings()
 logger = get_logger("understand_image_provider")
+FIELD_LABELS_ZH = {
+    "company_name": "公司名称",
+    "unified_social_credit_code": "统一社会信用代码",
+    "legal_representative": "法定代表人",
+    "business_address": "住所/经营场所",
+}
 
 SYSTEM_PROMPT = """
 <SystemPrompt>
@@ -73,7 +79,90 @@ SYSTEM_PROMPT = """
 </SystemPrompt>
 """
 
-tools: list[ChatCompletionToolParam] = [
+BUSINESS_LICENSE_SYSTEM_PROMPT = """
+<SystemPrompt>
+    <角色>
+        你是一名企业证照信息提取助手,负责从中国大陆营业执照图片中准确提取关键字段信息,并判断是否需要人工复核。
+    </角色>
+
+    <字段定义>
+        <字段>
+            <名称>company_name</名称>
+            <中文名称>公司名称</中文名称>
+            <描述>公司名称(位于营业执照上“名称”字段)</描述>
+        </字段>
+        <字段>
+            <名称>unified_social_credit_code</名称>
+            <中文名称>统一社会信用代码</中文名称>
+            <描述>
+                统一社会信用代码(位于营业执照左上角),包括了18位的主体内容,如有后缀内容(会以括号形式展示),须全部提取。
+            </描述>
+        </字段>
+        <字段>
+            <名称>legal_representative</名称>
+            <中文名称>法定代表人</中文名称>
+            <描述>法定代表人(营业执照上“法定代表人”字段)</描述>
+        </字段>
+        <字段>
+            <名称>business_address</名称>
+            <中文名称>住所/经营场所</中文名称>
+            <描述>住所(营业执照上“住所”字段,若无则使用“经营场所”)</描述>
+        </字段>
+        <字段>
+            <名称>need_manual_review</名称>
+            <描述>
+                是否需要人工复核(布尔值)。当识别结果不符合规则,存在异常时,设为 true,否则为 false。
+            </描述>
+        </字段>
+        <字段>
+            <名称>inaccurate_fields</名称>
+            <描述>
+                可能识别不准确的字段key数组。仅允许以下值:
+                "company_name"、"unified_social_credit_code"、"legal_representative"、"business_address"。
+                当 need_manual_review 为 false 时必须返回 []。
+            </描述>
+        </字段>
+        <字段>
+            <名称>inaccurate_fields_zh</名称>
+            <中文名称>可能不准确字段(中文)</中文名称>
+            <描述>
+                可能识别不准确的字段中文名称数组。仅允许以下值:
+                "公司名称"、"统一社会信用代码"、"法定代表人"、"住所/经营场所"。
+                当 need_manual_review 为 false 时必须返回 []。
+            </描述>
+        </字段>
+    </字段定义>
+
+    <约束>
+        <规则>1. 所有字段必须仅根据图像中可见内容提取,禁止补全、猜测或逻辑推断。</规则>
+        <规则>2. unified_social_credit_code:
+            a) 如有后缀,须完整保留括号后缀(如“(1-1)”);
+            b) 主体必须为18位字符,若不足18位或含有明显识别错误,应设 need_manual_review 为 true;
+        </规则>
+        <规则>3. 若无法识别某字段内容,应输出空字符串 "",不要用 null 或其他占位符。</规则>
+        <规则>4. 所有字段输出必须为 JSON 格式结构,字段命名需与定义一致,不含解释性文字或多余内容。</规则>
+        <规则>5. 当 need_manual_review=true 时,inaccurate_fields 必须给出至少一个可能不准确字段。</规则>
+    </约束>
+
+    <输出格式>
+        {
+            "company_name": "",
+            "unified_social_credit_code": "",
+            "legal_representative": "",
+            "business_address": "",
+            "need_manual_review": false|true,
+            "inaccurate_fields": [],
+            "inaccurate_fields_zh": []
+        }
+    </输出格式>
+
+    <输入说明>
+        输入是一张中国大陆营业执照图片,请依据图像内容提取字段并输出结构化结果。如识别不全,标记为需人工复核。
+    </输入说明>
+</SystemPrompt>
+"""
+
+copywriting_tools: list[ChatCompletionToolParam] = [
     {
         "type": "function",
         "function": {
@@ -94,15 +183,76 @@ tools: list[ChatCompletionToolParam] = [
     }
 ]
 
+business_license_tools: list[ChatCompletionToolParam] = [
+    {
+        "type": "function",
+        "function": {
+            "name": "extract_business_license_fields",
+            "description": "从营业执照提取公司名称、统一社会信用代码、法定代表人、住所/经营场所,并标记是否需要人工复核",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "company_name": {
+                        "type": "string",
+                        "description": "公司名称"
+                    },
+                    "unified_social_credit_code": {
+                        "type": "string",
+                        "description": "统一社会信用代码,包含括号后缀"
+                    },
+                    "legal_representative": {
+                        "type": "string",
+                        "description": "法定代表人"
+                    },
+                    "business_address": {
+                        "type": "string",
+                        "description": "住所/经营场所(优先使用“住所”字段)"
+                    },
+                    "need_manual_review": {
+                        "type": "boolean",
+                        "description": "是否需要人工复核。当统一社会信用代码主体不足18位或识别异常时应为 true"
+                    },
+                    "inaccurate_fields": {
+                        "type": "array",
+                        "description": "可能识别不准确的字段名列表;当 need_manual_review 为 false 时返回空数组",
+                        "items": {
+                            "type": "string",
+                            "enum": [
+                                "company_name",
+                                "unified_social_credit_code",
+                                "legal_representative",
+                                "business_address"
+                            ]
+                        }
+                    }
+                },
+                "required": [
+                    "company_name",
+                    "unified_social_credit_code",
+                    "legal_representative",
+                    "business_address",
+                    "need_manual_review",
+                    "inaccurate_fields"
+                ],
+                "additionalProperties": False
+            }
+        }
+    }
+]
+
 class UnderstandImageProvider:
     print("UnderstandImageProvider called")
-    def understand_image(self, image_url: str, *, model: str) -> DataResponse:
 
-        client = OpenAI(
+    def _create_client(self) -> OpenAI:
+        return OpenAI(
             api_key = settings.dashscope_api_key or "",
             base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
         )
 
+    def understand_image(self, image_url: str, *, model: str) -> DataResponse:
+
+        client = self._create_client()
+
         if not client:
             logger.error("OpenAI client is not initialized.")
             return DataResponse(code=1, data=None, msg=f"OpenAI client is not initialized")
@@ -116,7 +266,7 @@ class UnderstandImageProvider:
                     "content": [{ "type": "image_url", "image_url": { "url": image_url } }],
                 },
             ],
-            tools=tools,
+            tools=copywriting_tools,
             tool_choice={
                 "type": "function",
                 "function": {"name": "generate_ocr_text"}
@@ -148,4 +298,101 @@ class UnderstandImageProvider:
         print("✅ OCR_TEXT:\n", ocr_text)
 
         return DataResponse(code=0, data=ocr_text, msg="success")
+
+    def extract_business_license(self, image_url: str, *, model: str) -> DataResponse:
+        client = self._create_client()
+
+        if not client:
+            logger.error("OpenAI client is not initialized.")
+            return DataResponse(code=1, data=None, msg="OpenAI client is not initialized")
+
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": BUSINESS_LICENSE_SYSTEM_PROMPT},
+                {
+                    "role": "user",
+                    "content": [{ "type": "image_url", "image_url": { "url": image_url } }],
+                },
+            ],
+            tools=business_license_tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "extract_business_license_fields"}
+            },
+            temperature=0.2
+        )
+
+        msg = completion.choices[0].message
+        payload = BusinessLicensePayload(
+            company_name="",
+            unified_social_credit_code="",
+            legal_representative="",
+            business_address="",
+            need_manual_review=False,
+            inaccurate_fields=[],
+            inaccurate_fields_zh=[],
+        )
+        try:
+            tool_calls = getattr(msg, "tool_calls", None) or []
+            if tool_calls:
+                call = tool_calls[0]
+                arg_str = getattr(getattr(call, "function", None), "arguments", None)
+                if isinstance(arg_str, str) and arg_str.strip():
+                    args = json.loads(arg_str)
+                    if isinstance(args, dict):
+                        allowed_fields = {
+                            "company_name",
+                            "unified_social_credit_code",
+                            "legal_representative",
+                            "business_address",
+                        }
+                        raw_inaccurate_fields = args.get("inaccurate_fields", [])
+                        inaccurate_fields: list[str] = []
+                        if isinstance(raw_inaccurate_fields, list):
+                            inaccurate_fields = [
+                                str(field).strip()
+                                for field in raw_inaccurate_fields
+                                if str(field).strip() in allowed_fields
+                            ]
+
+                        need_manual_review = bool(args.get("need_manual_review", False))
+                        company_name = str(args.get("company_name", "")).strip()
+                        unified_social_credit_code = str(args.get("unified_social_credit_code", "")).strip()
+                        legal_representative = str(args.get("legal_representative", "")).strip()
+                        business_address = str(args.get("business_address", "")).strip()
+
+                        # Fallback for model omissions: if marked for review but no fields provided, infer likely problematic ones.
+                        if need_manual_review and not inaccurate_fields:
+                            if not company_name:
+                                inaccurate_fields.append("company_name")
+                            if not unified_social_credit_code or len(unified_social_credit_code) < 18:
+                                inaccurate_fields.append("unified_social_credit_code")
+                            if not legal_representative:
+                                inaccurate_fields.append("legal_representative")
+                            if not business_address:
+                                inaccurate_fields.append("business_address")
+                            if not inaccurate_fields:
+                                inaccurate_fields.append("unified_social_credit_code")
+
+                        if not need_manual_review:
+                            inaccurate_fields = []
+                        inaccurate_fields_zh = [
+                            FIELD_LABELS_ZH[field] for field in inaccurate_fields if field in FIELD_LABELS_ZH
+                        ]
+
+                        payload = BusinessLicensePayload(
+                            company_name=company_name,
+                            unified_social_credit_code=unified_social_credit_code,
+                            legal_representative=legal_representative,
+                            business_address=business_address,
+                            need_manual_review=need_manual_review,
+                            inaccurate_fields=inaccurate_fields,
+                            inaccurate_fields_zh=inaccurate_fields_zh,
+                        )
+        except Exception as e:
+            logger.error("parse business license tool call failed: %s", e, exc_info=True)
+            return DataResponse(code=1, data=None, msg=f"parse tool call failed: {e}")
+
+        return DataResponse(code=0, data=payload, msg="success")
     

+ 14 - 1
app/schemas/base.py

@@ -36,6 +36,15 @@ class CopywritingEvaluationPayload(BaseModel):
     reason: str
     corrected_msg: str
 
+class BusinessLicensePayload(BaseModel):
+    company_name: str
+    unified_social_credit_code: str
+    legal_representative: str
+    business_address: str
+    need_manual_review: bool
+    inaccurate_fields: List[str]
+    inaccurate_fields_zh: List[str]
+
 class DataResponse(BaseModel):
     code: int
     data: object
@@ -53,7 +62,11 @@ class UnderstandImageRequest(BaseModel):
     image_url: str
     model: str
 
+class BusinessLicenseExtractRequest(BaseModel):
+    image_url: str
+    model: str
+
 class CopywritingEvaluationRequest(BaseModel):
     image_url: str
     text: str
-    model: str
+    model: str

+ 8 - 2
app/services/vl_service.py

@@ -1,4 +1,4 @@
-from ..schemas.base import DataResponse, UnderstandImageRequest
+from ..schemas.base import DataResponse, UnderstandImageRequest, BusinessLicenseExtractRequest
 from ..providers.understand_image_provider import UnderstandImageProvider
 
 class VLService:
@@ -9,4 +9,10 @@ class VLService:
         return self._provider.understand_image(
             req.image_url,
             model = req.model
-        )
+        )
+
+    def extract_business_license(self, req: BusinessLicenseExtractRequest) -> DataResponse:
+        return self._provider.extract_business_license(
+            req.image_url,
+            model=req.model,
+        )

+ 9 - 1
start.sh

@@ -2,5 +2,13 @@
 
 echo "🚀 启动 Ai Server..."
 
+if [ ! -f "./.venv/bin/activate" ]; then
+  echo "❌ 未找到虚拟环境: ./.venv/bin/activate"
+  exit 1
+fi
+
+# 激活虚拟环境
+source ./.venv/bin/activate
+
 # 运行主程序
-uvicorn app.main:app --reload --port 8000
+uvicorn app.main:app --reload --port 8000