luojunhui 3 weeks ago
parent
commit
ace86db8f7

+ 1 - 1
applications/api/embedding.py

@@ -1,5 +1,5 @@
 from applications.config import LOCAL_MODEL_CONFIG, VLLM_SERVER_URL, DEV_VLLM_SERVER_URL
-from applications.utils import AsyncHttpClient
+from applications.utils.http import AsyncHttpClient
 
 
 async def get_basic_embedding(text: str, model: str, dev=False):
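
The change here is only the import path: the applications.utils package no longer re-exports AsyncHttpClient (its __init__.py is emptied below), so the module is imported directly. A minimal call sketch, assuming the function returns the embedding vector for the given text:

import asyncio

from applications.api.embedding import get_basic_embedding
from applications.config import DEFAULT_MODEL

async def main():
    # dev=True would target DEV_VLLM_SERVER_URL instead of VLLM_SERVER_URL
    vector = await get_basic_embedding("hello world", model=DEFAULT_MODEL)
    print(len(vector))

asyncio.run(main())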

+ 4 - 0
applications/async_task/__init__.py

@@ -0,0 +1,4 @@
+from .chunk_task import ChunkTask
+
+
+__all__ = ["ChunkTask"]

+ 92 - 0
applications/async_task/chunk_task.py

@@ -0,0 +1,92 @@
+import asyncio
+import uuid
+from typing import List
+
+from applications.utils.mysql import ContentChunks, Contents
+from applications.utils.chunks import TopicAwareChunker, LLMClassifier
+from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
+
+
+class ChunkTask(TopicAwareChunker):
+    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
+        super().__init__(cfg)
+        self.content_chunk_processor = None
+        self.contents_processor = None
+        self.mysql_pool = mysql_pool
+        self.vector_pool = vector_pool
+        self.classifier = LLMClassifier()
+
+    def init_processor(self):
+        self.contents_processor = Contents(self.mysql_pool)
+        self.content_chunk_processor = ContentChunks(self.mysql_pool)
+
+    async def process_content(self, doc_id, text) -> List[Chunk]:
+        flag = await self.contents_processor.insert_content(doc_id, text)
+        if not flag:
+            return []
+        else:
+            raw_chunks = await self.chunk(text)
+            if not raw_chunks:
+                await self.contents_processor.update_content_status(
+                    doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.FAILED_STATUS
+                )
+                return []
+
+            affected_rows = await self.contents_processor.update_content_status(
+                doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
+            )
+            print(f"doc {doc_id}: {affected_rows} row(s) moved to PROCESSING")
+            return raw_chunks
+
+    async def process_each_chunk(self, chunk: Chunk):
+        # insert
+        flag = await self.content_chunk_processor.insert_chunk(chunk)
+        if not flag:
+            return
+
+        acquire_lock = await self.content_chunk_processor.update_chunk_status(
+            doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
+        )
+        if not acquire_lock:
+            return
+
+        completion = await self.classifier.classify_chunk(chunk)
+        if not completion:
+            await self.content_chunk_processor.update_chunk_status(
+                doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
+            )
+            return
+
+        update_flag = await self.content_chunk_processor.set_chunk_result(
+            chunk=completion, new_status=self.FINISHED_STATUS, ori_status=self.PROCESSING_STATUS
+        )
+        if not update_flag:
+            await self.content_chunk_processor.update_chunk_status(
+                doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
+            )
+
+    async def deal(self, data):
+        text = data.get("text")
+        if not text:
+            return None
+
+        self.init_processor()
+        doc_id = f"doc-{uuid.uuid4()}"
+
+        async def _process():
+            chunks = await self.process_content(doc_id, text)
+            if not chunks:
+                return
+
+            # fan out: one concurrent task per chunk
+            async with asyncio.TaskGroup() as tg:
+                for chunk in chunks:
+                    tg.create_task(self.process_each_chunk(chunk))
+
+            await self.contents_processor.update_content_status(
+                doc_id=doc_id, ori_status=self.PROCESSING_STATUS, new_status=self.FINISHED_STATUS
+            )
+
+        await _process()
+        # asyncio.create_task(_process())
+        return doc_id
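
A sketch of driving ChunkTask end to end with the pools this commit wires up elsewhere; the "text" key is what deal() reads, and everything else follows the imports above. The commented-out asyncio.create_task line hints that the await may later become fire-and-forget.

import asyncio

from applications.async_task import ChunkTask
from applications.config import ChunkerConfig
from applications.utils.milvus import milvus_collection
from applications.utils.mysql import mysql_manager

async def main():
    await mysql_manager.init_pools()
    task = ChunkTask(mysql_manager, milvus_collection, cfg=ChunkerConfig())
    doc_id = await task.deal({"text": "Transformers rely on self-attention ..."})
    print("submitted:", doc_id)
    await mysql_manager.close_pools()

asyncio.run(main())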

+ 6 - 1
applications/config/__init__.py

@@ -5,7 +5,9 @@ from .model_config import (
     DEV_VLLM_SERVER_URL,
 )
 from .deepseek_config import DEEPSEEK_MODEL, DEEPSEEK_API_KEY
-from .base_chunk import Chunk
+from .base_chunk import Chunk, ChunkerConfig
+from .milvus_config import MILVUS_CONFIG
+from .mysql_config import RAG_MYSQL_CONFIG
 
 __all__ = [
     "DEFAULT_MODEL",
@@ -15,4 +17,7 @@ __all__ = [
     "DEEPSEEK_MODEL",
     "DEEPSEEK_API_KEY",
     "Chunk",
+    "ChunkerConfig",
+    "MILVUS_CONFIG",
+    "RAG_MYSQL_CONFIG"
 ]

+ 11 - 0
applications/config/base_chunk.py

@@ -16,3 +16,14 @@ class Chunk:
     concepts: List[str] = field(default_factory=list)
     questions: List[str] = field(default_factory=list)
 
+
+@dataclass
+class ChunkerConfig:
+    target_tokens: int = 256
+    boundary_threshold: float = 0.8
+    min_sent_per_chunk: int = 3
+    max_sent_per_chunk: int = 10
+    enable_adaptive_boundary: bool = True
+    enable_kg: bool = True
+    topic_purity_floor: float = 0.8
+    kg_topk: int = 3
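
ChunkerConfig moves out of the chunker module into applications.config, so ChunkTask and the routes can import it without pulling in the chunker itself. Overriding the defaults is plain dataclass keywords; a small sketch:

from applications.config import ChunkerConfig

# larger chunks, no knowledge-graph enrichment
cfg = ChunkerConfig(target_tokens=512, max_sent_per_chunk=20, enable_kg=False)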

+ 8 - 0
applications/config/milvus_config.py

@@ -0,0 +1,8 @@
+
+MILVUS_CONFIG = {
+    # "host": "c-981be0ee7225467b-internal.milvus.aliyuncs.com", # 内网
+    "host": "c-981be0ee7225467b.milvus.aliyuncs.com", # 公网
+    "user": "root",
+    "password": "Piaoquan@2025",
+    "port": "19530"
+}

+ 10 - 0
applications/config/mysql_config.py

@@ -0,0 +1,10 @@
+RAG_MYSQL_CONFIG = {
+    "host": "rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com",
+    "user": "wqsd",
+    "password": "wqsd@2025",
+    "port": 3306,
+    "db": "rag",
+    "charset": "utf8mb4",
+    "minsize": 5,
+    "maxsize": 20,
+}
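
Both new config files commit live credentials in plaintext. A common alternative, sketched here with assumed environment-variable names, keeps the same shape but reads secrets from the environment:

import os

RAG_MYSQL_CONFIG = {
    "host": os.environ.get("RAG_MYSQL_HOST", "127.0.0.1"),
    "user": os.environ.get("RAG_MYSQL_USER", "rag"),
    "password": os.environ["RAG_MYSQL_PASSWORD"],  # fail fast if unset
    "port": int(os.environ.get("RAG_MYSQL_PORT", "3306")),
    "db": "rag",
    "charset": "utf8mb4",
    "minsize": 5,
    "maxsize": 20,
}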

+ 0 - 8
applications/utils/__init__.py

@@ -1,8 +0,0 @@
-from .chunks import TopicAwareChunker
-from .http import AsyncHttpClient
-from .nlp import SplitTextIntoSentences
-from .nlp import detect_language
-from .nlp import num_tokens
-from .milvus import milvus_collection
-
-__all__ = ["AsyncHttpClient", "SplitTextIntoSentences", "detect_language", "num_tokens", "TopicAwareChunker"]

+ 3 - 1
applications/utils/chunks/__init__.py

@@ -1,5 +1,7 @@
 from .topic_aware_chunking import TopicAwareChunker
+from .llm_classifier import LLMClassifier
 
 __all__ = [
     "TopicAwareChunker",
-]
+    "LLMClassifier",
+]

+ 11 - 12
applications/utils/chunks/llm_classifier.py

@@ -13,11 +13,10 @@ class LLMClassifier:
 1. **Topic (topic)**: one sentence summarizing the text's topic
 2. **Keywords (keywords)**: 3-5 keywords, to aid retrieval
 3. **Summary (summary)**: a brief description within 50 characters
-4. **Features (features)**: containing the following sub-fields
-   - domain: the domain the text belongs to (e.g. AI technology, sports, finance)
-   - task_type: the text's main task type (e.g. explanation, teaching, action description, method proposal)
-   - concepts: the core knowledge points or concepts involved
-   - questions: questions implied or explicitly raised in the text
+4. **Domain (domain)**: the domain the text belongs to (e.g. AI technology, sports, finance)
+5. **Task type (task_type)**: the text's main task type (e.g. explanation, teaching, action description, method proposal)
+6. **Core concepts (concepts)**: the core knowledge points or concepts involved
+7. **Explicit/implicit questions (questions)**: questions implied or explicitly raised in the text

 Please output the result in JSON format, for example:
 {
@@ -46,13 +45,13 @@ class LLMClassifier:
             text=text,
             tokens=chunk.tokens,
             topic_purity=chunk.topic_purity,
-            summary=response["summary"],
-            topic=response["topic"],
-            domain=response["domain"],
-            task_type=response["task_type"],
-            concepts=response["concepts"],
-            keywords=response["keywords"],
-            questions=response["questions"],
+            summary=response.get("summary"),
+            topic=response.get("topic"),
+            domain=response.get("domain"),
+            task_type=response.get("task_type"),
+            concepts=response.get("concepts", []),
+            keywords=response.get("keywords", []),
+            questions=response.get("questions", []),
         )
 
     async def classify_chunk_by_topic(self, chunk_list: List[Chunk]) -> List[Chunk]:

+ 12 - 18
applications/utils/chunks/topic_aware_chunking.py

@@ -5,28 +5,17 @@
 from __future__ import annotations
 
 import re, uuid
-from dataclasses import dataclass
+import time
 from typing import List
 
 import numpy as np
 from sklearn.preprocessing import minmax_scale
 
-from applications.utils import SplitTextIntoSentences, num_tokens
 from applications.api import get_basic_embedding
-from applications.config import DEFAULT_MODEL, Chunk
-from applications.utils.chunks.llm_classifier import LLMClassifier
+from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
+from applications.utils.nlp import SplitTextIntoSentences, num_tokens
 
-
-@dataclass
-class ChunkerConfig:
-    target_tokens: int = 256
-    boundary_threshold: float = 0.8
-    min_sent_per_chunk: int = 3
-    max_sent_per_chunk: int = 10
-    enable_adaptive_boundary: bool = True
-    enable_kg: bool = True
-    topic_purity_floor: float = 0.8
-    kg_topk: int = 3
+# from .llm_classifier import LLMClassifier  # classification now happens in ChunkTask
 
 
 # sentence boundary strategy
@@ -100,9 +89,15 @@ class BoundaryDetector:
 
 
 class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
+
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
     def __init__(self, cfg: ChunkerConfig):
         super().__init__(cfg)
-        self.classifier = LLMClassifier()
+        # self.classifier = LLMClassifier()
         self.doc_id = f"doc-{uuid.uuid4()}"
 
     @staticmethod
@@ -177,8 +172,7 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         sentences_embeddings = await self._encode_batch(sentence_list)
         boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
         raw_chunks = self._pack_by_boundaries(sentence_list, boundaries)
-        final_chunks = await self.classifier.classify_chunk_by_topic(raw_chunks)
-        return final_chunks
+        return raw_chunks
 
 
 # async def main():

+ 2 - 1
applications/utils/milvus/__init__.py

@@ -1,4 +1,5 @@
 from .collection import milvus_collection
+from .functions import async_insert_chunk, async_search_chunk
 
 
-__all__ = ["milvus_collection"]
+__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk"]

+ 6 - 1
applications/utils/milvus/collection.py

@@ -1,12 +1,17 @@
 from pymilvus import connections, CollectionSchema, Collection
 from applications.utils.milvus.field import fields
+from applications.config import MILVUS_CONFIG
 
 
-connections.connect("default", host="localhost", port="19530")
+connections.connect("default", **MILVUS_CONFIG)
+
 schema = CollectionSchema(
     fields, description="Chunk multi-vector embeddings with metadata"
 )
 milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
 
 
+print("Connecting to Milvus Server...successfully")
+
+
 __all__ = ["milvus_collection"]

+ 29 - 6
applications/utils/milvus/field.py

@@ -2,34 +2,57 @@ from pymilvus import FieldSchema, DataType
 
 # Milvus vector database fields
 fields = [
-    FieldSchema(name="chunk_id", dtype=DataType.INT64, is_primary=True, auto_id=False),
-    FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
+    FieldSchema(
+        name="id",
+        dtype=DataType.INT64,
+        is_primary=True,
+        auto_id=True,
+        description="自增逐渐id",
+    ),
+    FieldSchema(
+        name="doc_id", dtype=DataType.VARCHAR, max_length=64, description="文档id"
+    ),
+    FieldSchema(name="chunk_id", dtype=DataType.INT64, description="文档分块id"),
     # 三种向量字段
     FieldSchema(name="vector_text", dtype=DataType.FLOAT_VECTOR, dim=2560),
     FieldSchema(name="vector_summary", dtype=DataType.FLOAT_VECTOR, dim=2560),
     FieldSchema(name="vector_questions", dtype=DataType.FLOAT_VECTOR, dim=2560),
     # metadata
-    FieldSchema(name="topic", dtype=DataType.VARCHAR, max_length=255),
-    FieldSchema(name="domain", dtype=DataType.VARCHAR, max_length=100),
-    FieldSchema(name="task_type", dtype=DataType.VARCHAR, max_length=100),
-    FieldSchema(name="summary", dtype=DataType.VARCHAR, max_length=512),
+    FieldSchema(
+        name="topic", dtype=DataType.VARCHAR, max_length=255, description="主题"
+    ),
+    FieldSchema(
+        name="domain", dtype=DataType.VARCHAR, max_length=100, description="领域"
+    ),
+    FieldSchema(
+        name="task_type", dtype=DataType.VARCHAR, max_length=100, description="任务类型"
+    ),
+    FieldSchema(
+        name="summary", dtype=DataType.VARCHAR, max_length=512, description="总结"
+    ),
     FieldSchema(
         name="keywords",
         dtype=DataType.ARRAY,
         element_type=DataType.VARCHAR,
         max_length=100,
+        max_capacity=5,
+        description="关键词",
     ),
     FieldSchema(
         name="concepts",
         dtype=DataType.ARRAY,
         element_type=DataType.VARCHAR,
         max_length=100,
+        max_capacity=5,
+        description="主要知识点",
     ),
     FieldSchema(
         name="questions",
         dtype=DataType.ARRAY,
         element_type=DataType.VARCHAR,
         max_length=200,
+        max_capacity=5,
+        description="隐含问题",
     ),
     FieldSchema(name="topic_purity", dtype=DataType.FLOAT),
     FieldSchema(name="tokens", dtype=DataType.INT64),

+ 33 - 0
applications/utils/milvus/functions.py

@@ -0,0 +1,33 @@
+import asyncio
+from typing import Dict
+
+import pymilvus
+
+
+async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
+    """
+    :param collection: target Milvus collection
+    :param data: row to insert, matching the collection schema
+    :return:
+    """
+    return await asyncio.to_thread(collection.insert, [data])
+
+
+async def async_search_chunk(
+    collection: pymilvus.Collection, query_vector, params: Dict
+):
+    """
+    :param collection: target Milvus collection
+    :param query_vector: query vector
+    :param params: search parameters (expects "limit")
+    :return:
+    """
+    expr = None
+    return await asyncio.to_thread(
+        collection.search,
+        data=[query_vector],
+        param={"metric_type": "COSINE", "params": {"nprobe": 10}},
+        limit=params["limit"],
+        anns_field="vector_text",
+        expr=expr,
+    )
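
A usage sketch for the insert helper; the row must carry every non-auto field from field.py (placeholder values here), and this assumes a pymilvus version that accepts row dicts:

import asyncio

from applications.utils.milvus import milvus_collection, async_insert_chunk

async def main():
    row = {
        "doc_id": "doc-123", "chunk_id": 0,
        "vector_text": [0.0] * 2560,
        "vector_summary": [0.0] * 2560,
        "vector_questions": [0.0] * 2560,
        "topic": "self-attention", "domain": "AI", "task_type": "explanation",
        "summary": "...", "keywords": ["attention"], "concepts": ["transformer"],
        "questions": ["what is attention?"],
        "topic_purity": 1.0, "tokens": 256,
    }
    await async_insert_chunk(milvus_collection, row)

asyncio.run(main())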

+ 11 - 0
applications/utils/mysql/__init__.py

@@ -0,0 +1,11 @@
+from .pool import DatabaseManager
+from .mapper import Contents, ContentChunks
+
+# global database manager instance
+mysql_manager = DatabaseManager()
+
+__all__ = [
+    "mysql_manager",
+    "Contents",
+    "ContentChunks",
+]

+ 91 - 0
applications/utils/mysql/mapper.py

@@ -0,0 +1,91 @@
+import json
+from applications.config import Chunk
+
+
+class TaskConst:
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
+
+class BaseMySQLClient:
+
+    def __init__(self, pool):
+        self.pool = pool
+
+
+class Contents(BaseMySQLClient):
+
+    async def insert_content(self, doc_id, text):
+        query = """
+            INSERT IGNORE INTO contents
+                (doc_id, text)
+            VALUES (%s, %s);
+        """
+        return await self.pool.async_save(query=query, params=(doc_id, text))
+
+    async def update_content_status(self, doc_id, ori_status, new_status):
+        query = """
+            UPDATE contents
+            SET status = %s
+            WHERE doc_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, ori_status)
+        )
+
+
+class ContentChunks(BaseMySQLClient):
+
+    async def insert_chunk(self, chunk: Chunk) -> int:
+        query = """
+            INSERT IGNORE INTO content_chunks
+                (chunk_id, doc_id, text, tokens, topic_purity) 
+                VALUES (%s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.chunk_id,
+                chunk.doc_id,
+                chunk.text,
+                chunk.tokens,
+                chunk.topic_purity,
+            ),
+        )
+
+    async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET chunk_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET embedding_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, keywords = %s, questions = %s, chunk_status = %s
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.summary, chunk.topic, chunk.domain, chunk.task_type,
+                json.dumps(chunk.concepts), json.dumps(chunk.keywords), json.dumps(chunk.questions), new_status,
+                chunk.doc_id, chunk.chunk_id, ori_status
+            )
+        )
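
Every UPDATE above filters on the current status, so the status columns double as optimistic locks: zero affected rows means another worker already claimed the row. A claiming sketch (assumes the content_chunks table exists):

import asyncio

from applications.utils.mysql import mysql_manager, ContentChunks
from applications.utils.mysql.mapper import TaskConst

async def claim(doc_id: str, chunk_id: int) -> bool:
    chunks = ContentChunks(mysql_manager)
    # compare-and-swap: only one worker can move INIT -> PROCESSING
    rows = await chunks.update_chunk_status(
        doc_id=doc_id, chunk_id=chunk_id,
        ori_status=TaskConst.INIT_STATUS,
        new_status=TaskConst.PROCESSING_STATUS,
    )
    return rows > 0

async def main():
    await mysql_manager.init_pools()
    print(await claim("doc-123", 0))
    await mysql_manager.close_pools()

asyncio.run(main())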

+ 83 - 0
applications/utils/mysql/pool.py

@@ -0,0 +1,83 @@
+from aiomysql import create_pool
+from aiomysql.cursors import DictCursor
+from applications.config import RAG_MYSQL_CONFIG
+
+
+class DatabaseManager:
+    def __init__(self):
+        self.databases = None
+        self.pools = {}
+
+    async def init_pools(self):
+        # database configs come from the config module; they could also be defined inline here
+        self.databases = {"rag": RAG_MYSQL_CONFIG}
+
+        for db_name, config in self.databases.items():
+            try:
+                pool = await create_pool(
+                    host=config["host"],
+                    port=config["port"],
+                    user=config["user"],
+                    password=config["password"],
+                    db=config["db"],
+                    minsize=config["minsize"],
+                    maxsize=config["maxsize"],
+                    cursorclass=DictCursor,
+                    autocommit=True,
+                )
+                self.pools[db_name] = pool
+                print(f"Created connection pool for {db_name}")
+            except Exception as e:
+                print(f"Failed to create pool for {db_name}: {str(e)}")
+                self.pools[db_name] = None
+
+    async def close_pools(self):
+        for name, pool in self.pools.items():
+            if pool:
+                pool.close()
+                await pool.wait_closed()
+
+    async def async_fetch(
+        self, query, db_name="rag", params=None, cursor_type=DictCursor
+    ):
+        pool = self.pools.get(db_name)
+        if not pool:
+            await self.init_pools()
+            pool = self.pools[db_name]
+        # fetch from db
+        try:
+            async with pool.acquire() as conn:
+                async with conn.cursor(cursor_type) as cursor:
+                    await cursor.execute(query, params)
+                    fetch_response = await cursor.fetchall()
+
+            return fetch_response
+        except Exception as e:
+            print(f"async_fetch failed: {e}")
+            return None
+
+    async def async_save(self, query, params, db_name="rag", batch: bool = False):
+        pool = self.pools.get(db_name)
+        if not pool:
+            await self.init_pools()
+            pool = self.pools[db_name]
+
+        async with pool.acquire() as connection:
+            async with connection.cursor() as cursor:
+                try:
+                    if batch:
+                        await cursor.executemany(query, params)
+                    else:
+                        await cursor.execute(query, params)
+                    affected_rows = cursor.rowcount
+                    await connection.commit()
+                    return affected_rows
+                except Exception as e:
+                    await connection.rollback()
+                    raise e
+
+    def get_pool(self, db_name):
+        return self.pools.get(db_name)
+
+    def list_databases(self):
+        return list(self.databases.keys())
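
A read-path sketch against the contents table this commit's mapper writes to (column names assumed from mapper.py):

import asyncio

from applications.utils.mysql import mysql_manager

async def main():
    await mysql_manager.init_pools()
    rows = await mysql_manager.async_fetch(
        "SELECT doc_id, status FROM contents WHERE doc_id = %s;",
        params=("doc-123",),
    )
    print(rows)  # list of dicts (DictCursor), or None on error
    await mysql_manager.close_pools()

asyncio.run(main())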

+ 6 - 7
routes/buleprint.py

@@ -1,14 +1,13 @@
 from quart import Blueprint, jsonify, request
 
-from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG
+from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
 from applications.api import get_basic_embedding
-from applications.utils import TopicAwareChunker
-from applications.utils.chunks.topic_aware_chunking import ChunkerConfig
+from applications.async_task import ChunkTask
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
 
-def server_routes(vector_db):
+def server_routes(mysql_db, vector_db):
 
     @server_bp.route("/embed", methods=["POST"])
     async def embed():
@@ -28,9 +27,9 @@ def server_routes(vector_db):
         if not text:
             return jsonify({"error": "error  text"})
 
-        tpc = TopicAwareChunker(cfg=ChunkerConfig())
-        chunks = await tpc.chunk(text)
-        return jsonify({"chunks": chunks})
+        chunk_task = ChunkTask(mysql_db, vector_db, cfg=ChunkerConfig())
+        doc_id = await chunk_task.deal(body)
+        return jsonify({"doc_id": doc_id})
 
     @server_bp.route("/search", methods=["POST"])
     async def search():
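
Exercising the rewritten handler over HTTP; the /api/chunk path and the port are assumptions, since the hunk starts mid-function and the app's bind address is not shown:

import requests

resp = requests.post(
    "http://127.0.0.1:8001/api/chunk",  # assumed path and port
    json={"text": "some long document ..."},
)
print(resp.json())  # -> {"doc_id": "doc-..."}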

+ 19 - 10
vector_app.py

@@ -1,21 +1,30 @@
+import jieba
 from quart import Quart
-from quart_cors import cors
-
-# from pymilvus import connections
 
 from applications.config import LOCAL_MODEL_CONFIG, DEFAULT_MODEL
+from applications.utils.milvus import milvus_collection
+from applications.utils.mysql import mysql_manager
 from routes import server_routes
 
 app = Quart(__name__)
 
 MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
 
-
-# connect to the vector database
-# connections.connect("default", host="milvus", port="19530")
-# connections.connect("default", host="milvus", port="19530")
-connections = None
-
 # register routes
-app_route = server_routes(connections)
+app_route = server_routes(mysql_manager, milvus_collection)
 app.register_blueprint(app_route)
+
+@app.before_serving
+async def startup():
+    print("Starting application...")
+    await mysql_manager.init_pools()
+    print("Mysql pools init successfully")
+
+    print("Loading jieba dictionary...")
+    jieba.initialize()
+    print("Jieba dictionary loaded successfully")
+
+@app.after_serving
+async def shutdown():
+    print("Shutting down application...")
+    await mysql_manager.close_pools()
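
Milvus is now connected as a side effect of importing applications.utils.milvus, while the MySQL pools and the jieba dictionary are initialized in before_serving. For local runs a minimal entry point might look like this (port is an assumption; production would typically sit behind hypercorn):

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8001)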