luojunhui 3 weeks ago
parent
revision
3d4d2e64b1

+ 15 - 7
applications/async_task/chunk_task.py

@@ -11,8 +11,8 @@ from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
 
 
 class ChunkEmbeddingTask(TopicAwareChunker):
-    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
-        super().__init__(cfg)
+    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
+        super().__init__(cfg, doc_id)
         self.content_chunk_processor = None
         self.contents_processor = None
         self.mysql_pool = mysql_pool
@@ -52,6 +52,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
         # insert
         flag = await self.content_chunk_processor.insert_chunk(chunk)
         if not flag:
+            print("插入文本失败")
             return
 
         acquire_lock = await self.content_chunk_processor.update_chunk_status(
@@ -61,6 +62,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
             new_status=self.PROCESSING_STATUS,
         )
         if not acquire_lock:
+            print("抢占文本分块锁失败")
             return
 
         completion = await self.classifier.classify_chunk(chunk)
@@ -71,6 +73,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
                 ori_status=self.PROCESSING_STATUS,
                 new_status=self.FAILED_STATUS,
             )
+            print("从deepseek获取信息失败")
             return
 
         update_flag = await self.content_chunk_processor.set_chunk_result(
@@ -118,6 +121,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
                 "task_type": chunk.task_type,
                 "summary": chunk.summary,
                 "keywords": chunk.keywords,
+                "entities": chunk.entities,
                 "concepts": chunk.concepts,
                 "questions": chunk.questions,
                 "topic_purity": chunk.topic_purity,
@@ -140,18 +144,22 @@ class ChunkEmbeddingTask(TopicAwareChunker):
             print(f"存入向量数据库失败", e)
 
     async def deal(self, data):
-        text = data.get("text")
+        text = data.get("text", "")
+        text = text.strip()
         if not text:
             return None
 
         self.init_processer()
-        doc_id = f"doc-{uuid.uuid4()}"
 
         async def _process():
-            chunks = await self.process_content(doc_id, text)
+            chunks = await self.process_content(self.doc_id, text)
             if not chunks:
                 return
 
+            # # dev
+            # for chunk in chunks:
+            #     await self.process_each_chunk(chunk)
+
             await run_tasks_with_asyncio_task_group(
                 task_list=chunks,
                 handler=self.process_each_chunk,
@@ -161,10 +169,10 @@ class ChunkEmbeddingTask(TopicAwareChunker):
             )
 
             await self.contents_processor.update_content_status(
-                doc_id=doc_id,
+                doc_id=self.doc_id,
                 ori_status=self.PROCESSING_STATUS,
                 new_status=self.FINISHED_STATUS,
             )
 
         asyncio.create_task(_process())
-        return doc_id
+        return self.doc_id
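
Review note: process_each_chunk leans on a conditional UPDATE as an optimistic lock, so only one worker can move a chunk from INIT to PROCESSING. A minimal sketch of that pattern, assuming a hypothetical pool whose execute() returns the affected-row count (the repo's real helper is ContentChunks.update_chunk_status):

    async def update_chunk_status(pool, doc_id, chunk_id, ori_status, new_status) -> bool:
        query = """
            UPDATE content_chunks
            SET chunk_status = %s
            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
        """
        # Zero affected rows means another worker already flipped the status,
        # i.e. the lock was not acquired.
        affected = await pool.execute(query, (new_status, doc_id, chunk_id, ori_status))
        return affected > 0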

+ 1 - 0
applications/config/base_chunk.py

@@ -15,6 +15,7 @@ class Chunk:
     keywords: List[str] = field(default_factory=list)
     concepts: List[str] = field(default_factory=list)
     questions: List[str] = field(default_factory=list)
+    entities: List[str] = field(default_factory=list)
 
 
 @dataclass
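
Review note: a usage sketch for the extended dataclass; chunk_id, doc_id, and text are assumed to be the other fields seen elsewhere in this commit, and the values are illustrative:

    chunk = Chunk(
        chunk_id=1,
        doc_id="doc-1234",
        text="RAG combines retrieval with generation.",
        entities=["RAG"],  # new field; defaults to [] when omitted
    )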

+ 6 - 5
applications/utils/chunks/llm_classifier.py

@@ -16,7 +16,8 @@ class LLMClassifier:
 4. **Domain (domain)**: the domain the text belongs to (e.g. AI technology, sports, finance)
 5. **Task type (task_type)**: the text's main task type (e.g. explanation, teaching, action description, method proposal)
 6. **Core concepts (concepts)**: the core knowledge points or concepts involved
-7. **Explicit/implicit questions (questions)**: questions stated or implied in the text  
+7. **Explicit/implicit questions (questions)**: questions stated or implied in the text
+8. **Entities (entities)**: named entities mentioned in the text
 
 Output the result as JSON, for example:
 {
@@ -27,6 +28,7 @@ class LLMClassifier:
     "keywords": ["RAG", "检索增强", "文本分块", "知识图谱"],
     "concepts": ["RAG", "文本分块", "知识图谱"],
     "questions": ["如何提升RAG的检索效果?"]
+    "entities": ["entity1"]
 }
 
 Here is the text:
@@ -39,6 +41,7 @@ class LLMClassifier:
         response = await fetch_deepseek_completion(
             model="DeepSeek-V3", prompt=prompt, output_type="json"
         )
+        print(response)
         return Chunk(
             chunk_id=chunk.chunk_id,
             doc_id=chunk.doc_id,
@@ -52,7 +55,5 @@ class LLMClassifier:
             concepts=response.get("concepts", []),
             keywords=response.get("keywords", []),
             questions=response.get("questions", []),
-        )
-
-    async def classify_chunk_by_topic(self, chunk_list: List[Chunk]) -> List[Chunk]:
-        return [await self.classify_chunk(chunk) for chunk in chunk_list]
+            entities=response.get("entities", []),
+        )
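
Review note: a hypothetical smoke test for the new entities field, assuming the module layout in this commit, a no-arg LLMClassifier constructor, and that Chunk accepts chunk_id/doc_id/text as keyword arguments:

    import asyncio

    from applications.config.base_chunk import Chunk
    from applications.utils.chunks.llm_classifier import LLMClassifier

    async def main():
        chunk = Chunk(chunk_id=1, doc_id="doc-demo", text="OpenAI released GPT-4 in 2023.")
        enriched = await LLMClassifier().classify_chunk(chunk)
        print(enriched.entities)  # e.g. ["OpenAI", "GPT-4"], depending on the model

    asyncio.run(main())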

+ 3 - 4
applications/utils/chunks/topic_aware_chunking.py

@@ -4,8 +4,7 @@
 
 from __future__ import annotations
 
-import re, uuid
-import time
+import re
 from typing import List
 
 import numpy as np
@@ -95,10 +94,10 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
     FINISHED_STATUS = 2
     FAILED_STATUS = 3
 
-    def __init__(self, cfg: ChunkerConfig):
+    def __init__(self, cfg: ChunkerConfig, doc_id: str):
         super().__init__(cfg)
         # self.classifier = LLMClassifier()
-        self.doc_id = f"doc-{uuid.uuid4()}"
+        self.doc_id = doc_id
 
     @staticmethod
     async def _encode_batch(texts: List[str]) -> np.ndarray:
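
Review note: doc_id is now minted once by the caller and threaded through the constructors instead of being generated inside the chunker, so the route, the task, and the chunker all agree on the same id. A flow sketch using names from this diff (mysql_db and vector_db as in the route):

    import uuid

    doc_id = f"doc-{uuid.uuid4()}"  # generated once, at the API boundary
    task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
    assert task.doc_id == doc_id  # TopicAwareChunker.__init__ stores it verbatim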

+ 12 - 1
applications/utils/milvus/collection.py

@@ -10,8 +10,19 @@ schema = CollectionSchema(
 )
 milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
 
+# create index (M / efConstruction are HNSW build params; IVF_FLAT would take nlist)
+vector_index_params = {
+    "index_type": "HNSW",
+    "metric_type": "COSINE",
+    "params": {"M": 16, "efConstruction": 200},
+}
 
-print("Connecting to Milvus Server...successfully")
+milvus_collection.create_index("vector_text", vector_index_params)
+milvus_collection.create_index("vector_summary", vector_index_params)
+milvus_collection.create_index("vector_questions", vector_index_params)
 
+milvus_collection.load()
 
 __all__ = ["milvus_collection"]

+ 8 - 0
applications/utils/milvus/field.py

@@ -54,6 +54,14 @@ fields = [
         max_capacity=5,
         description="隐含问题",
     ),
+FieldSchema(
+        name="entities",
+        dtype=DataType.ARRAY,
+        element_type=DataType.VARCHAR,
+        max_length=200,
+        max_capacity=5,
+        description="命名实体",
+    ),
     FieldSchema(name="topic_purity", dtype=DataType.FLOAT),
     FieldSchema(name="tokens", dtype=DataType.INT64),
 ]
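
Review note: the new ARRAY field caps entities at 5 items of at most 200 characters each, and Milvus rejects inserts that exceed either bound. A defensive clamp before insert (the helper name is hypothetical):

    def clamp_entities(entities: list[str], max_items: int = 5, max_len: int = 200) -> list[str]:
        # Truncate each entity string and drop extras so the row fits the schema.
        return [e[:max_len] for e in entities[:max_items]]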

+ 2 - 1
applications/utils/milvus/functions.py

@@ -10,7 +10,8 @@ async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
     :param data: insert data
     :return:
     """
-    return await asyncio.to_thread(collection.insert, [data])
+    res = await asyncio.to_thread(collection.insert, [data])
+    print(res)
+    return res
 
 
 async def async_search_chunk(
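
Review note: pymilvus Collection.insert returns a MutationResult, whose insert_count reports how many rows were written; returning it (in addition to printing) lets callers verify the write, as the docstring's :return: promises. A usage sketch, assuming the package re-exports milvus_collection and the payload matches the collection schema:

    from applications.utils.milvus import milvus_collection

    res = await async_insert_chunk(milvus_collection, payload)
    assert res.insert_count == 1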

+ 4 - 2
applications/utils/mysql/mapper.py

@@ -78,7 +78,8 @@ class ContentChunks(BaseMySQLClient):
     async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
         query = """
             UPDATE content_chunks
-            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, keywords = %s, questions = %s, chunk_status = %s
+            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, 
+                keywords = %s, questions = %s, chunk_status = %s, entities = %s
             WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
         """
         return await self.pool.async_save(
@@ -92,8 +93,9 @@ class ContentChunks(BaseMySQLClient):
                 json.dumps(chunk.keywords),
                 json.dumps(chunk.questions),
                 new_status,
+                json.dumps(chunk.entities),
                 chunk.doc_id,
                 chunk.chunk_id,
-                ori_status,
+                ori_status
             ),
         )
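
Review note: positional %s placeholders make the column/parameter ordering fragile; entities was appended to the SET clause, so its value has to sit exactly between new_status and doc_id in the tuple. A sketch of the same update with named placeholders, assuming the pool is backed by a PyMySQL-compatible driver that supports %(name)s:

    query = """
        UPDATE content_chunks
        SET entities = %(entities)s, chunk_status = %(new_status)s
        WHERE doc_id = %(doc_id)s AND chunk_id = %(chunk_id)s AND chunk_status = %(ori_status)s;
    """
    params = {
        "entities": json.dumps(chunk.entities),
        "new_status": new_status,
        "doc_id": chunk.doc_id,
        "chunk_id": chunk.chunk_id,
        "ori_status": ori_status,
    }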

+ 6 - 3
routes/buleprint.py

@@ -1,3 +1,5 @@
+import uuid
+
 from quart import Blueprint, jsonify, request
 
 from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
@@ -23,11 +25,12 @@ def server_routes(mysql_db, vector_db):
     @server_bp.route("/chunk", methods=["POST"])
     async def chunk():
         body = await request.get_json()
-        text = body.get("text")
+        text = body.get("text", "")
+        text = text.strip()
         if not text:
             return jsonify({"error": "error  text"})
-
-        chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig())
+        doc_id = f"doc-{uuid.uuid4()}"
+        chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
         doc_id = await chunk_task.deal(body)
         return jsonify({"doc_id": doc_id})