
Merge branch 'luojunhui/feature/20251013-add-pdf-chunking' of Server/rag_server into master

luojunhui 1 day ago
Parent
Commit
4fd8379935

+ 8 - 1
applications/async_task/__init__.py

@@ -1,7 +1,14 @@
 from .chunk_task import ChunkEmbeddingTask
+from .chunk_task import ChunkBooksTask
 from .delete_task import DeleteTask
 from .auto_rechunk_task import AutoRechunkTask
 from .build_graph import BuildGraph
 
 
-__all__ = ["ChunkEmbeddingTask", "DeleteTask", "AutoRechunkTask", "BuildGraph"]
+__all__ = [
+    "ChunkEmbeddingTask",
+    "DeleteTask",
+    "AutoRechunkTask",
+    "BuildGraph",
+    "ChunkBooksTask",
+]

+ 101 - 1
applications/async_task/chunk_task.py

@@ -6,7 +6,7 @@ from applications.api import get_basic_embedding
 from applications.utils.async_utils import run_tasks_with_asyncio_task_group
 from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
 from applications.utils.milvus import async_insert_chunk
-from applications.utils.mysql import ContentChunks, Contents
+from applications.utils.mysql import Books, ContentChunks, Contents
 from applications.utils.nlp import num_tokens
 from applications.config import Chunk, DEFAULT_MODEL
 from applications.config import ELASTIC_SEARCH_INDEX
@@ -17,6 +17,7 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         super().__init__(doc_id)
         self.chunk_manager = None
         self.content_manager = None
+        self.book_manager = None
         self.mysql_client = resource.mysql_client
         self.milvus_client = resource.milvus_client
         self.es_client = resource.es_client
@@ -29,6 +30,7 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
     def init_processer(self):
         self.content_manager = Contents(self.mysql_client)
         self.chunk_manager = ContentChunks(self.mysql_client)
+        self.book_manager = Books(self.mysql_client)
 
     async def _chunk_each_content(self, doc_id: str, data: dict) -> List[Chunk]:
         title, text = data.get("title", "").strip(), data["text"].strip()
@@ -260,3 +262,101 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
 
         asyncio.create_task(_process())
         return self.doc_id
+
+
+class ChunkBooksTask(ChunkEmbeddingTask):
+    """图书类型分块任务"""
+
+    BOOK_PDF_DATASET_ID = 21
+    BOOK_PDF_TYPE = 3
+
+    async def _process_each_book(self, book_id):
+        result = await self.book_manager.get_book_extract_detail(book_id=book_id)
+        extract_result = result[0]["extract_result"]
+        book_name = result[0]["book_name"]
+        book_oss_path = result[0]["book_oss_path"]
+        book_texts = [
+            i["text"] for i in json.loads(extract_result) if i["type"] == "text"
+        ]
+
+        # first insert into contents
+        flag = await self.content_manager.insert_content(
+            self.doc_id,
+            book_oss_path,
+            self.BOOK_PDF_TYPE,
+            book_name,
+            self.BOOK_PDF_DATASET_ID,
+            ext=None,
+        )
+        if not flag:
+            return []
+        else:
+            raw_chunks = await self.chunk_books(
+                sentence_list=book_texts,
+                text_type=self.BOOK_PDF_TYPE,
+                dataset_id=self.BOOK_PDF_DATASET_ID,
+            )
+            if not raw_chunks:
+                await self.content_manager.update_content_status(
+                    doc_id=self.doc_id,
+                    ori_status=self.INIT_STATUS,
+                    new_status=self.FAILED_STATUS,
+                )
+                return []
+
+            await self.content_manager.update_content_status(
+                doc_id=self.doc_id,
+                ori_status=self.INIT_STATUS,
+                new_status=self.PROCESSING_STATUS,
+            )
+            return raw_chunks
+
+    async def deal(self, data):
+        book_id = data.get("book_id", None)
+        if not book_id:
+            return {"error": "Book id should not be None"}
+
+        self.init_processer()
+        # LOCK
+        acquire_lock = await self.book_manager.update_book_chunk_status(
+            book_id=book_id,
+            ori_status=self.INIT_STATUS,
+            new_status=self.PROCESSING_STATUS,
+        )
+        if not acquire_lock:
+            return {"info": "book is processing or processed"}
+
+        async def _process():
+            chunks = await self._process_each_book(book_id)
+            if not chunks:
+                return
+
+            # # dev
+            # for chunk in tqdm(chunks):
+            #     await self.save_each_chunk(chunk)
+
+            await run_tasks_with_asyncio_task_group(
+                task_list=chunks,
+                handler=self.save_each_chunk,
+                description="process book chunks",
+                unit="chunk",
+                max_concurrency=10,
+            )
+
+            await self.content_manager.update_content_status(
+                doc_id=self.doc_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
+            )
+
+            await self.book_manager.update_book_chunk_status(
+                book_id=book_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
+            )
+
+        asyncio.create_task(_process())
+        return self.doc_id
+
+
+__all__ = ["ChunkEmbeddingTask", "ChunkBooksTask"]
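
The new ChunkBooksTask is driven by a {"book_id": ...} payload and reports back either the generated doc_id or an info/error dict. A minimal sketch of dispatching the task directly, assuming the same resource-manager object the routes use (not part of this commit):

import uuid
from applications.async_task import ChunkBooksTask

async def enqueue_book_chunking(resource, book_id: int):
    # sketch only: mirrors what the /chunk_book route added below does
    task = ChunkBooksTask(doc_id=f"doc-{uuid.uuid4()}", resource=resource)
    result = await task.deal({"book_id": book_id})
    # deal() returns the doc_id string after scheduling the background job,
    # or a dict such as {"info": "book is processing or processed"} when the
    # chunk_status lock could not be acquired
    return result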

+ 1 - 1
applications/config/base_chunk.py

@@ -26,7 +26,7 @@ class Chunk:
 class ChunkerConfig:
     target_tokens: int = 256
     max_tokens: int = 2048
-    min_tokens: int = 64
+    min_tokens: int = 16
     boundary_threshold: float = 0.8
     min_sent_per_chunk: int = 3
     max_sent_per_chunk: int = 10

+ 1 - 0
applications/prompts/__init__.py

@@ -0,0 +1 @@
+from .build_graph import extract_entity_and_graph

+ 63 - 0
applications/prompts/build_graph.py

@@ -0,0 +1,63 @@
+import json
+
+
+def extract_entity_and_graph(text: str) -> str:
+    """
+    Generic knowledge-extraction prompt builder.
+    Extracts entities, relations, and concepts from arbitrary input text.
+    The prompt requests JSON output and is safe to pass to an LLM without format conflicts or escaping issues.
+    """
+    safe_text = json.dumps(text, ensure_ascii=False)
+    prompt = f"""
+### Role
+You are a professional knowledge-extraction assistant. Identify the key information in the input text and return a structured result.
+
+### Input text:
+{safe_text}
+
+### Extraction targets
+Extract the following three kinds of information:
+1. **entities**
+   - Concrete objects, people, organizations, locations, technologies, products, events, and similar items mentioned in the text.
+   - Each entity must include: name, type, aliases, description.
+   - Only extract entities from the main, essential content; do not extract entities from examples, samples, or other supplementary material.
+
+2. **relations**
+   - Semantic links between entities.
+   - Common relation types include (but are not limited to): belongs to, affiliated with, depends on, controls, cooperates with, located in, applied to, develops, influences, defined by, used for.
+   - Each relation must include: source, target, relation_type, evidence, confidence.
+3. **concepts**
+   - Topics, core technologies, ideas, subject areas, and topical keywords covered by the text.
+
+### Output format
+Output exactly the following JSON structure. Do not output any explanatory text, comments, or Markdown code blocks.
+
+{{
+  "entities": [
+    {{
+      "name": "string", "type": "string", "aliases": ["string"], "description": "string"
+    }}
+  ],
+  "relations": [
+    {{
+    "source": "string", "target": "string", "relation_type": "string", "evidence": "string", "confidence": 0.0
+    }}
+  ],
+  "concepts": ["string"]
+}}
+
+### Output rules
+1. Output strictly valid JSON that can be parsed directly;
+2. Every field must be present, even as an empty array;
+3. If nothing is detected, output:
+   {{
+     "entities": [],
+     "relations": [],
+     "concepts": []
+   }};
+4. `confidence` is a decimal between 0.0 and 1.0;
+5. Do not infer entities or relations that do not appear in the text;
+6. If information is ambiguous, keep descriptions neutral;
+7. The output must not contain explanatory text, comments, examples, or Markdown.
+"""
+    return prompt
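
extract_entity_and_graph only builds the prompt string; how it is fed to a model is outside this diff. A hedged usage sketch, assuming an OpenAI-compatible async client and a placeholder model name rather than the project's own LLM wrapper:

import json
from openai import AsyncOpenAI  # assumed client, not part of this repository

async def extract_graph(text: str) -> dict:
    client = AsyncOpenAI()
    prompt = extract_entity_and_graph(text)
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": prompt}],
    )
    raw = response.choices[0].message.content
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # the prompt forbids Markdown fences, but strip them defensively
        cleaned = raw.strip().removeprefix("```json").removeprefix("```").removesuffix("```")
        return json.loads(cleaned)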

+ 42 - 0
applications/utils/chunks/topic_aware_chunking.py

@@ -52,6 +52,15 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
             "embeddings": sentences_embeddings,
         }
 
+    async def _book_chunk(self, sentence_list: List[str]) -> Dict[str, Any]:
+        sentences_embeddings = await self._encode_batch(sentence_list)
+        boundaries = self.detect_boundaries_v2(sentence_list, sentences_embeddings)
+        return {
+            "sentence_list": sentence_list,
+            "boundaries": boundaries,
+            "embeddings": sentences_embeddings,
+        }
+
 
 class TopicAwarePackerV1(TopicAwareChunker):
     def _pack_v1(
@@ -175,3 +184,36 @@ class TopicAwarePackerV2(TopicAwareChunker):
             text_type=text_type,
             dataset_id=dataset_id,
         )
+
+    async def chunk_books(
+        self, sentence_list: List[str], text_type: int, dataset_id: int
+    ) -> List[Chunk]:
+        raw_info = await self._book_chunk(sentence_list=sentence_list)
+        if not raw_info:
+            return []
+
+        return self._pack_v2(
+            sentence_list=raw_info["sentence_list"],
+            boundaries=raw_info["boundaries"],
+            embeddings=raw_info["embeddings"],
+            text_type=text_type,
+            dataset_id=dataset_id,
+        )
+
+    async def chunk_books_raw(
+        self, sentence_list: List[str], text_type: int, dataset_id: int
+    ):
+        chunks = []
+        for index, text in enumerate(sentence_list, 1):
+            chunks.append(
+                Chunk(
+                    doc_id=self.doc_id,
+                    dataset_id=dataset_id,
+                    text=text,
+                    chunk_id=index,
+                    tokens=num_tokens(text),
+                    text_type=text_type,
+                    status=1,
+                )
+            )
+        return chunks
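
chunk_books runs the boundary detector over the extracted text blocks and packs them into 3-10 sentence chunks, while chunk_books_raw keeps a one-chunk-per-block mapping with no boundary detection. A hedged sketch contrasting the two paths (the packer instance and text list are illustrative):

from applications.utils.chunks import TopicAwarePackerV2

async def demo_book_chunking(packer: TopicAwarePackerV2, texts: list[str]):
    # semantic path: embed sentences, detect boundaries, pack into Chunk objects
    semantic_chunks = await packer.chunk_books(
        sentence_list=texts, text_type=3, dataset_id=21
    )
    # raw path: one Chunk per extracted text block
    raw_chunks = await packer.chunk_books_raw(
        sentence_list=texts, text_type=3, dataset_id=21
    )
    return semantic_chunks, raw_chunks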

+ 9 - 1
applications/utils/mysql/__init__.py

@@ -1,7 +1,15 @@
+from .books import Books
 from .pool import DatabaseManager
 from .mapper import Dataset, ChatResult
 from .content_chunks import ContentChunks
 from .contents import Contents
 
 
-__all__ = ["Contents", "ContentChunks", "DatabaseManager", "Dataset", "ChatResult"]
+__all__ = [
+    "Contents",
+    "ContentChunks",
+    "DatabaseManager",
+    "Dataset",
+    "ChatResult",
+    "Books",
+]

+ 25 - 0
applications/utils/mysql/books.py

@@ -0,0 +1,25 @@
+from .base import BaseMySQLClient
+
+
+class Books(BaseMySQLClient):
+    async def get_books(self):
+        query = """
+            SELECT book_id, book_name, book_oss_path, extract_status
+            FROM books
+            WHERE status = 1;
+        """
+        return await self.pool.async_fetch(query=query)
+
+    async def get_book_extract_detail(self, book_id):
+        query = """
+            SELECT book_name, book_oss_path, extract_result FROM books WHERE book_id = %s;
+        """
+        return await self.pool.async_fetch(query=query, params=(book_id,))
+
+    async def update_book_chunk_status(self, book_id, ori_status, new_status):
+        query = """
+            UPDATE books SET chunk_status = %s WHERE book_id = %s and chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, book_id, ori_status)
+        )
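
update_book_chunk_status doubles as an optimistic lock: the WHERE clause only matches when chunk_status still holds the expected value, so only one caller wins the transition. A sketch of the guard, assuming async_save returns the affected row count and using placeholder numeric statuses (the task classes define the real INIT/PROCESSING constants):

from applications.utils.mysql import Books

async def try_lock_book(books: Books, book_id: int) -> bool:
    affected = await books.update_book_chunk_status(
        book_id=book_id,
        ori_status=0,  # assumed INIT_STATUS value
        new_status=1,  # assumed PROCESSING_STATUS value
    )
    # a concurrent second caller matches zero rows and therefore skips the book
    return bool(affected)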

+ 105 - 0
applications/utils/nlp/boundary_detector.py

@@ -68,3 +68,108 @@ class BoundaryDetector(ChunkerConfig):
                 )
 
         return boundaries
+
+    def detect_boundaries_v2(
+        self, sentence_list: List[str], embs: np.ndarray, debug: bool = False
+    ) -> List[int]:
+        """
+        Constraint: the number of sentences between adjacent boundaries (including start to first boundary) lies in [3, 10].
+        A boundary is the index of a segment-final sentence (same meaning as b during packing).
+        """
+        n = len(sentence_list)
+        if n <= 1 or embs is None or len(embs) != n:
+            return []
+
+        # --- base scoring ---
+        sims = np.array([self.cosine_sim(embs[i], embs[i + 1]) for i in range(n - 1)])
+        cut_scores = 1 - sims
+        cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else np.array([])
+
+        # combined signals: topic-shift cues / figure numbering, etc.
+        adj_scores = np.zeros_like(cut_scores)
+        for i in range(len(cut_scores)):
+            sent_to_check = sentence_list[i] if i < n else sentence_list[-1]
+            snippet = sent_to_check[-20:] if sent_to_check else ""
+            adj_scores[i] = (
+                cut_scores[i]
+                + self.turn_signal(snippet)
+                + self.figure_signal(sent_to_check)
+            )
+
+        # --- hard-constrained splitting: 3-10 sentences per segment ---
+        MIN_SIZE = self.min_sent_per_chunk
+        MAX_SIZE = self.max_sent_per_chunk
+        thr = getattr(self, "boundary_threshold", 0.5)
+
+        boundaries: List[int] = []
+        last_boundary = -1  # index of the previous segment-final sentence (-1 before the start)
+
+        best_idx = None  # best-scoring cut point in the current window (once MIN_SIZE is reached)
+        best_score = -1e9
+
+        for i in range(n - 1):  # i is a candidate segment-final sentence index
+            seg_len = i - last_boundary  # sentences in this segment if we cut at i
+
+            # update the best candidate in the current window (only eligible once the minimum length is reached)
+            if seg_len >= MIN_SIZE:
+                if adj_scores[i] > best_score:
+                    best_score = float(adj_scores[i])
+                    best_idx = i
+
+            cut_now = False
+            cut_at = None
+
+            if seg_len < MIN_SIZE:
+                # fewer than MIN_SIZE sentences: never cut
+                pass
+            elif adj_scores[i] >= thr and seg_len <= MAX_SIZE:
+                # within [MIN_SIZE, MAX_SIZE] and above threshold: cut here
+                cut_now = True
+                cut_at = i
+            elif seg_len == MAX_SIZE:
+                # reached MAX_SIZE sentences, must cut: prefer the best-scoring position in the window
+                cut_now = True
+                cut_at = best_idx if best_idx is not None else i
+
+            if cut_now:
+                boundaries.append(cut_at)
+                last_boundary = cut_at
+                best_idx = None
+                best_score = -1e9
+
+            if debug:
+                print(
+                    f"[{i}] sim={sims[i]:.3f}, cut={cut_scores[i]:.3f}, "
+                    f"adj={adj_scores[i]:.3f}, len={seg_len}, "
+                    f"cut={'Y@' + str(cut_at) if cut_now else 'N'}"
+                )
+
+        # --- tail fix: avoid a final segment shorter than MIN_SIZE sentences ---
+        # packing appends n-1 as the final boundary, so the tail length is (n-1) - last_boundary
+        tail_len = (n - 1) - last_boundary
+        if tail_len < MIN_SIZE and boundaries:
+            # nudge the last boundary into a feasible position
+            prev_last = boundaries[-2] if len(boundaries) >= 2 else -1
+            # the new last cut point j must satisfy:
+            # 1) previous segment length in [3,10] => j ∈ [prev_last+3, prev_last+10]
+            # 2) tail segment length in [3,10] => j ∈ [n-1-10, n-1-3]
+            lower = max(prev_last + MIN_SIZE, (n - 1) - MAX_SIZE)
+            upper = min(prev_last + MAX_SIZE, (n - 1) - MIN_SIZE)
+
+            if lower <= upper:
+                # pick the position with the highest adj_scores within the allowed range
+                window = adj_scores[lower : upper + 1]
+                j = int(np.argmax(window)) + lower
+                if j != boundaries[-1]:
+                    boundaries[-1] = j
+                    if debug:
+                        print(
+                            f"[fix-tail] move last boundary -> {j}, tail_len={n - 1 - j}"
+                        )
+            else:
+                # no feasible position: fall back to merging the tail (drop the last boundary)
+                dropped = boundaries.pop()
+                if debug:
+                    print(f"[fix-tail] drop last boundary {dropped} to avoid tiny tail")
+
+        return boundaries
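
Each boundary is the index of a segment-final sentence, so consecutive boundaries carve the sentence list into windows of 3-10 sentences; _pack_v2 also treats the last sentence as an implicit final boundary. A small illustrative helper (not part of this commit) showing how those indices map to sentence groups:

from typing import List

def group_by_boundaries(sentences: List[str], boundaries: List[int]) -> List[List[str]]:
    # illustrative only: _pack_v2 performs the real packing with token limits
    segments, start = [], 0
    for b in list(boundaries) + [len(sentences) - 1]:  # implicit final boundary
        if b >= start:
            segments.append(sentences[start : b + 1])
            start = b + 1
    return segments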

+ 11 - 1
routes/blueprint.py

@@ -10,7 +10,7 @@ from quart_cors import cors
 from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
 from applications.async_task import AutoRechunkTask, BuildGraph
-from applications.async_task import ChunkEmbeddingTask, DeleteTask
+from applications.async_task import ChunkEmbeddingTask, DeleteTask, ChunkBooksTask
 from applications.config import (
     DEFAULT_MODEL,
     LOCAL_MODEL_CONFIG,
@@ -85,6 +85,16 @@ async def chunk():
     return jsonify({"doc_id": doc_id})
 
 
+@server_bp.route("/chunk_book", methods=["POST"])
+async def chunk_book():
+    body = await request.get_json()
+    resource = get_resource_manager()
+    doc_id = f"doc-{uuid.uuid4()}"
+    chunk_task = ChunkBooksTask(doc_id=doc_id, resource=resource)
+    doc_id = await chunk_task.deal(body)
+    return jsonify({"doc_id": doc_id})
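
The /chunk_book endpoint takes the book_id in the JSON body and responds immediately with the generated doc_id while chunking continues in the background. A hedged client example; the host, port, blueprint URL prefix, and book_id value are placeholders:

import asyncio
import aiohttp

async def request_book_chunking(book_id: int) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8080/chunk_book", json={"book_id": book_id}
        ) as resp:
            # the server returns {"doc_id": ...} right away; the heavy work
            # runs in the background via asyncio.create_task
            return await resp.json()

if __name__ == "__main__":
    print(asyncio.run(request_book_chunking(1)))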
+
+
 @server_bp.route("/search", methods=["POST"])
 async def search():
     """