Browse Source

Merge branch 'feature/luojunhui/2025-09-12-llm-chunks' of Server/llm_vector_server into master

luojunhui 3 weeks ago
parent
commit
eb49b08659

+ 5 - 7
Dockerfile

@@ -9,23 +9,21 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     PYTHONUNBUFFERED=1 \
     PIP_DISABLE_PIP_VERSION_CHECK=on \
     TZ=Asia/Shanghai \
-    PATH="/root/.local/bin:$PATH"
+    PATH="/root/.local/bin:$PATH" \
+    DEBIAN_FRONTEND=noninteractive
 
-# 安装系统依赖(构建 wheel、时区等
+# 安装系统依赖(如果 requirements 里需要编译 C 扩展
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         curl \
         tzdata \
     && rm -rf /var/lib/apt/lists/*
 
-# 设置时区
-RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
-
-# 复制 requirements 并安装依赖
+# 先复制 requirements 并安装依赖(利用缓存)
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
 
-# 复制项目文件
+# 复制项目文件
 COPY . .
 
 # 暴露端口

+ 2 - 1
applications/api/__init__.py

@@ -1,3 +1,4 @@
+from .deepseek import fetch_deepseek_completion
 from .embedding import get_basic_embedding
 
-__all__ = ["get_basic_embedding"]
+__all__ = ["get_basic_embedding", "fetch_deepseek_completion"]

+ 55 - 0
applications/api/deepseek.py

@@ -0,0 +1,55 @@
+"""
+@author: luojunhui
+@description: deepseek 官方api (async版)
+"""
+
+from __future__ import annotations
+
+import json
+from typing import Dict, List, Optional
+from openai import AsyncOpenAI
+
+from applications.config import DEEPSEEK_MODEL
+from applications.config import DEEPSEEK_API_KEY
+
+
+async def fetch_deepseek_completion(
+    model: str,
+    prompt: str,
+    output_type: str = "text",
+    tool_calls: bool = False,
+    tools: List[Dict] = None,
+) -> Optional[Dict | List]:
+    messages = [{"role": "user", "content": prompt}]
+    kwargs = {
+        "model": DEEPSEEK_MODEL.get(model, "deepseek-chat"),
+        "messages": messages,
+    }
+
+    # add tool calls
+    if tool_calls and tools:
+        kwargs["tools"] = tools
+        kwargs["tool_choice"] = "auto"
+
+    client = AsyncOpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+
+    if output_type == "json":
+        kwargs["response_format"] = {"type": "json_object"}
+
+    try:
+        response = await client.chat.completions.create(**kwargs)
+        choice = response.choices[0]
+
+        if output_type == "text":
+            return choice.message.content  # 只返回文本
+        elif output_type == "json":
+            return json.loads(choice.message.content)
+        else:
+            raise ValueError(f"Invalid output_type: {output_type}")
+
+    except Exception as e:
+        print(f"[ERROR] fetch_deepseek_completion failed: {e}")
+        return None
+
+
+__all__ = ["fetch_deepseek_completion"]

+ 1 - 1
applications/api/embedding.py

@@ -1,5 +1,5 @@
 from applications.config import LOCAL_MODEL_CONFIG, VLLM_SERVER_URL, DEV_VLLM_SERVER_URL
-from applications.utils import AsyncHttpClient
+from applications.utils.http import AsyncHttpClient
 
 
 async def get_basic_embedding(text: str, model: str, dev=False):

+ 4 - 0
applications/async_task/__init__.py

@@ -0,0 +1,4 @@
+from .chunk_task import ChunkEmbeddingTask
+
+
+__all__ = ["ChunkEmbeddingTask"]

+ 179 - 0
applications/async_task/chunk_task.py

@@ -0,0 +1,179 @@
+import asyncio
+import uuid
+from typing import List
+
+from applications.api import get_basic_embedding
+from applications.utils.async_utils import run_tasks_with_asyncio_task_group
+from applications.utils.mysql import ContentChunks, Contents
+from applications.utils.chunks import TopicAwareChunker, LLMClassifier
+from applications.utils.milvus import async_insert_chunk
+from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
+
+
+class ChunkEmbeddingTask(TopicAwareChunker):
+    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
+        super().__init__(cfg, doc_id)
+        self.content_chunk_processor = None
+        self.contents_processor = None
+        self.mysql_pool = mysql_pool
+        self.vector_pool = vector_pool
+        self.classifier = LLMClassifier()
+
+    @staticmethod
+    async def get_embedding_list(text: str) -> List:
+        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)
+
+    def init_processer(self):
+        self.contents_processor = Contents(self.mysql_pool)
+        self.content_chunk_processor = ContentChunks(self.mysql_pool)
+
+    async def process_content(self, doc_id: str, text: str, text_type: int) -> List[Chunk]:
+        flag = await self.contents_processor.insert_content(doc_id, text, text_type)
+        if not flag:
+            return []
+        else:
+            raw_chunks = await self.chunk(text, text_type)
+            if not raw_chunks:
+                await self.contents_processor.update_content_status(
+                    doc_id=doc_id,
+                    ori_status=self.INIT_STATUS,
+                    new_status=self.FAILED_STATUS,
+                )
+                return []
+
+            await self.contents_processor.update_content_status(
+                doc_id=doc_id,
+                ori_status=self.INIT_STATUS,
+                new_status=self.PROCESSING_STATUS,
+            )
+            return raw_chunks
+
+    async def process_each_chunk(self, chunk: Chunk):
+        # insert
+        flag = await self.content_chunk_processor.insert_chunk(chunk)
+        if not flag:
+            print("插入文本失败")
+            return
+
+        acquire_lock = await self.content_chunk_processor.update_chunk_status(
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            ori_status=self.INIT_STATUS,
+            new_status=self.PROCESSING_STATUS,
+        )
+        if not acquire_lock:
+            print("抢占文本分块锁失败")
+            return
+
+        completion = await self.classifier.classify_chunk(chunk)
+        if not completion:
+            await self.content_chunk_processor.update_chunk_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
+            )
+            print("从deepseek获取信息失败")
+            return
+
+        update_flag = await self.content_chunk_processor.set_chunk_result(
+            chunk=completion,
+            ori_status=self.PROCESSING_STATUS,
+            new_status=self.FINISHED_STATUS,
+        )
+        if not update_flag:
+            await self.content_chunk_processor.update_chunk_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
+            )
+            return
+
+        await self.save_to_milvus(completion)
+
+    async def save_to_milvus(self, chunk: Chunk):
+        """
+        :param chunk: each single chunk
+        :return:
+        """
+        # 抢锁
+        acquire_lock = await self.content_chunk_processor.update_embedding_status(
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            new_status=self.PROCESSING_STATUS,
+            ori_status=self.INIT_STATUS,
+        )
+        if not acquire_lock:
+            print(f"抢占-{chunk.doc_id}-{chunk.chunk_id}分块-embedding处理锁失败")
+            return
+        try:
+            data = {
+                "doc_id": chunk.doc_id,
+                "chunk_id": chunk.chunk_id,
+                "vector_text": await self.get_embedding_list(chunk.text),
+                "vector_summary": await self.get_embedding_list(chunk.summary),
+                "vector_questions": await self.get_embedding_list(
+                    ",".join(chunk.questions)
+                ),
+                "topic": chunk.topic,
+                "domain": chunk.domain,
+                "task_type": chunk.task_type,
+                "summary": chunk.summary,
+                "keywords": chunk.keywords,
+                "entities": chunk.entities,
+                "concepts": chunk.concepts,
+                "questions": chunk.questions,
+                "topic_purity": chunk.topic_purity,
+                "tokens": chunk.tokens,
+            }
+            await async_insert_chunk(self.vector_pool, data)
+            await self.content_chunk_processor.update_embedding_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
+            )
+        except Exception as e:
+            await self.content_chunk_processor.update_embedding_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
+            )
+            print(f"存入向量数据库失败", e)
+
+    async def deal(self, data):
+        text = data.get("text", "")
+        text = text.strip()
+        text_type = data.get("text_type", 1)
+        if not text:
+            return None
+
+        self.init_processer()
+
+        async def _process():
+            chunks = await self.process_content(self.doc_id, text, text_type)
+            if not chunks:
+                return
+
+            # # dev
+            # for chunk in chunks:
+            #     await self.process_each_chunk(chunk)
+
+            await run_tasks_with_asyncio_task_group(
+                task_list=chunks,
+                handler=self.process_each_chunk,
+                description="处理单篇文章分块",
+                unit="chunk",
+                max_concurrency=10,
+            )
+
+            await self.contents_processor.update_content_status(
+                doc_id=self.doc_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
+            )
+
+        asyncio.create_task(_process())
+        return self.doc_id

+ 0 - 0
applications/chunks/__init__.py


+ 0 - 452
applications/chunks/topic_aware_chunking.py

@@ -1,452 +0,0 @@
-"""
-主题感知分块
-"""
-
-from __future__ import annotations
-
-import re, uuid, math
-from dataclasses import dataclass, field, asdict
-from typing import List, Dict, Any, Tuple, Optional
-
-import optuna
-import numpy as np
-
-from sentence_transformers import SentenceTransformer, util
-
-from applications.utils import SplitTextIntoSentences, detect_language, num_tokens
-
-
-# ---------- Utilities ----------
-def simple_sent_tokenize(text: str) -> List[str]:
-    text = re.sub(r"\n{2,}", "\n", text)
-    parts = re.split(r"([。!?!?;;]+)\s*|\n+", text)
-    sents, buf = [], ""
-    for p in parts:
-        if p is None:
-            continue
-        if re.match(r"[。!?!?;;]+", p or ""):
-            buf += p or ""
-            if buf.strip():
-                sents.append(buf.strip())
-            buf = ""
-        elif p.strip() == "":
-            if buf.strip():
-                sents.append(buf.strip())
-                buf = ""
-        else:
-            buf += p or ""
-    if buf.strip():
-        sents.append(buf.strip())
-
-    merged = []
-    for s in sents:
-        if merged and (len(s) < 10 or len(merged[-1]) < 10):
-            merged[-1] += " " + s
-        else:
-            merged.append(s)
-    return [s for s in merged if s.strip()]
-
-
-def approx_tokens(text: str) -> int:
-    """Cheap token estimator (≈4 chars/token for zh, ≈0.75 words/token for en)."""
-    # This is a heuristic; replace with tiktoken if desired.
-    cjk = re.findall(r"[\u4e00-\u9fff]", text)
-    others = re.sub(r"[\u4e00-\u9fff]", " ", text).split()
-    return max(1, int(len(cjk) / 2.5 + len(others) / 0.75))
-
-
-# ---------- Knowledge Graph Stub ----------
-class KGClassifier:
-    """
-    Hierarchical topic classifier using embedding prototypes per node.
-    Replace `nodes` with your KG; each node keeps a centroid embedding.
-    """
-
-    def __init__(self, model: SentenceTransformer, kg_spec: Dict[str, Any]):
-        """
-        kg_spec example:
-        {
-          "root": {
-            "name": "root",
-            "children": [
-              {"name": "Computer Science", "children":[
-                  {"name":"NLP", "children":[{"name":"RAG", "children":[]}]}]},
-              {"name": "Finance", "children":[{"name":"AP/AR", "children":[]}]}]}
-        }
-        """
-        self.model = model
-        self.root = kg_spec["root"]
-        self._embed_cache = {}  # name -> vector
-
-        def build_centroid(node):
-            name = node["name"]
-            if name not in self._embed_cache:
-                self._embed_cache[name] = self.model.encode(
-                    name, normalize_embeddings=True
-                )
-            for ch in node.get("children", []):
-                build_centroid(ch)
-
-        build_centroid(self.root)
-
-    def classify(self, text_emb: np.ndarray, topk: int = 3) -> Tuple[List[str], float]:
-        """
-        Return (topic_path, purity). Purity is soft max margin across levels.
-        """
-        path, purities = [], []
-        node = self.root
-        while True:
-            # score current node children
-            children = node.get("children", [])
-            if not children:
-                break
-            scores = []
-            for ch in children:
-                vec = self._embed_cache[ch["name"]]
-                scores.append((ch, float(util.cos_sim(text_emb, vec).item())))
-            scores.sort(key=lambda x: x[1], reverse=True)
-            best, second = scores[0], (scores[1] if len(scores) > 1 else (None, -1.0))
-            path.append(best[0]["name"])
-            margin = max(0.0, (best[1] - max(second[1], -1.0)))
-            purities.append(1 / (1 + math.exp(-5 * margin)))  # squash to (0,1)
-            node = best[0]
-        purity = float(np.mean(purities)) if purities else 1.0
-        return path, purity
-
-
-# ---------- Core Chunker ----------
-@dataclass
-class Chunk:
-    id: str
-    text: str
-    tokens: int
-    topics: List[str] = field(default_factory=list)
-    topic_purity: float = 1.0
-    meta: Dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class ChunkerConfig:
-    model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
-    target_tokens: int = 80
-    max_tokens: int = 80
-    overlap_ratio: float = 0.12
-    boundary_threshold: float = 0.50  # similarity drop boundary (lower -> more cuts)
-    min_sent_per_chunk: int = 1
-    max_sent_per_chunk: int = 8
-    enable_adaptive_boundary: bool = True
-    enable_kg: bool = True
-    topic_purity_floor: float = 0.65
-    kg_topk: int = 3
-
-
-class TopicAwareChunker:
-    def __init__(self, cfg: ChunkerConfig, kg_spec: Optional[Dict[str, Any]] = None):
-        self.cfg = cfg
-        self.model = SentenceTransformer(
-            cfg.model_name, device="cpu"
-        )  # set gpu if available
-        self.model.max_seq_length = 512
-        self.kg = (
-            KGClassifier(self.model, kg_spec) if (cfg.enable_kg and kg_spec) else None
-        )
-
-    # ---------- Public API ----------
-    def chunk(self, text: str) -> List[Chunk]:
-        sents = simple_sent_tokenize(text)
-        if not sents:
-            return []
-        sent_embs = self.model.encode(sents, normalize_embeddings=True)
-        boundaries = self._detect_boundaries(sents, sent_embs)
-        raw_chunks = self._pack_by_boundaries(sents, sent_embs, boundaries)
-        final_chunks = self._classify_and_refine(raw_chunks)
-        return final_chunks
-
-    # ---------- Boundary detection ----------
-    def _detect_boundaries(self, sents: List[str], embs: np.ndarray) -> List[int]:
-        sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().reshape(-1)
-        cut_scores = 1 - sims  # higher means more likely boundary
-
-        # use np.ptp instead of ndarray.ptp (NumPy 2.0 compatibility)
-        rng = np.ptp(cut_scores) if np.ptp(cut_scores) > 0 else 1e-6
-        cut_scores = (cut_scores - cut_scores.min()) / (rng + 1e-6)
-
-        boundaries = []
-        for i, score in enumerate(cut_scores):
-            # 对应的是句对 (i, i+1),这里可以检查 sents[i] 或 sents[i+1]
-            sent_to_check = sents[i] if i < len(sents) else sents[-1]
-            # 防御性写法,避免越界
-            snippet = sent_to_check[-20:] if sent_to_check else ""
-
-            turn = (
-                0.1
-                if re.search(
-                    r"(因此|但是|综上|然而|另一方面|In conclusion|However|Therefore)",
-                    snippet,
-                )
-                else 0.0
-            )
-            fig = (
-                0.1
-                if re.search(
-                    r"(见下图|如表|表\s*\d+|图\s*\d+|Figure|Table)", sent_to_check
-                )
-                else 0.0
-            )
-
-            adj_score = score + turn + fig
-            if adj_score >= self.cfg.boundary_threshold:
-                boundaries.append(i)
-
-        return boundaries
-
-    # ---------- Packing ----------
-    def _pack_by_boundaries(
-        self, sents: List[str], embs: np.ndarray, boundaries: List[int]
-    ) -> List[Chunk]:
-        """Greedy pack around boundaries to meet target length & sentence counts."""
-        boundary_set = set(boundaries)
-        chunks: List[Chunk] = []
-        start = 0
-        n = len(sents)
-        while start < n:
-            end = start
-            cur_tokens = 0
-            sent_count = 0
-            last_boundary = start - 1
-            while end < n and sent_count < self.cfg.max_sent_per_chunk:
-                cur_tokens = approx_tokens(" ".join(sents[start : end + 1]))
-                sent_count += 1
-                if cur_tokens >= self.cfg.target_tokens:
-                    # try to cut at nearest boundary to 'end'
-                    cut = end
-                    # search backward to nearest boundary within window
-                    for b in range(end, start - 1, -1):
-                        if b in boundary_set:
-                            cut = b
-                            break
-                    # avoid too small chunks
-                    if cut - start + 1 >= self.cfg.min_sent_per_chunk:
-                        end = cut
-                    break
-                end += 1
-
-            # finalize chunk
-            text = " ".join(sents[start : end + 1]).strip()
-            tokens = approx_tokens(text)
-            chunk = Chunk(id=str(uuid.uuid4()), text=text, tokens=tokens)
-            chunks.append(chunk)
-
-            # soft overlap (append tail sentences of current as head of next)
-            if self.cfg.overlap_ratio > 0 and end + 1 < n:
-                overlap_tokens = int(tokens * self.cfg.overlap_ratio)
-                # approximate by sentences
-                overlap_sents = []
-                t = 0
-                for s in reversed(sents[start : end + 1]):
-                    t += approx_tokens(s)
-                    overlap_sents.append(s)
-                    if t >= overlap_tokens:
-                        break
-                # prepend to next start by reducing start index backward (not altering original sents)
-            start = end + 1
-        return chunks
-
-    # ---------- KG classify & refine ----------
-    def _classify_and_refine(self, chunks: List[Chunk]) -> List[Chunk]:
-        if not self.kg:
-            return chunks
-        refined: List[Chunk] = []
-        for ch in chunks:
-            emb = self.model.encode(ch.text, normalize_embeddings=True)
-            topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
-            ch.topics, ch.topic_purity = topics, purity
-            # If purity is low, try a secondary split inside the chunk
-            if purity < self.cfg.topic_purity_floor:
-                sub = self._refine_chunk_by_topic(ch)
-                refined.extend(sub)
-            else:
-                refined.append(ch)
-        return refined
-
-    def _refine_chunk_by_topic(self, chunk: Chunk) -> List[Chunk]:
-        """Second-pass split inside a low-purity chunk."""
-        sents = simple_sent_tokenize(chunk.text)
-        if len(sents) <= self.cfg.min_sent_per_chunk * 2:
-            return [chunk]
-        embs = self.model.encode(sents, normalize_embeddings=True)
-        # force more boundaries by lowering threshold a bit
-        orig = self.cfg.boundary_threshold
-        try:
-            self.cfg.boundary_threshold = max(0.3, orig - 0.1)
-            boundaries = self._detect_boundaries(sents, embs)
-            sub_chunks = self._pack_by_boundaries(sents, embs, boundaries)
-            # inherit topics again
-            final = []
-            for ch in sub_chunks:
-                emb = self.model.encode(ch.text, normalize_embeddings=True)
-                topics, purity = self.kg.classify(emb, topk=self.cfg.kg_topk)
-                ch.topics, ch.topic_purity = topics, purity
-                final.append(ch)
-            return final
-        finally:
-            self.cfg.boundary_threshold = orig
-
-
-# ---------- Auto-tuning (unsupervised objective) ----------
-class UnsupervisedEvaluator:
-    """
-    Build a score: higher is better.
-    - Intra-chunk coherence (avg similarity of neighboring sentences)
-    - Inter-chunk separation (low similarity of chunk medoids to neighbors)
-    - Length penalty (deviation from target_tokens)
-    - Topic purity reward (if KG is enabled)
-    """
-
-    def __init__(
-        self, model: SentenceTransformer, target_tokens: int, kg_weight: float = 0.5
-    ):
-        self.model = model
-        self.target = target_tokens
-        self.kg_weight = kg_weight
-
-    def score(self, chunks: List[Chunk], kg_present: bool = True) -> float:
-        if not chunks:
-            return -1e6
-        # Intra coherence: reward high
-        intra = []
-        for ch in chunks:
-            sents = simple_sent_tokenize(ch.text)
-            if len(sents) < 2:
-                continue
-            embs = self.model.encode(sents, normalize_embeddings=True)
-            sims = util.cos_sim(embs[:-1], embs[1:]).cpu().numpy().reshape(-1)
-            intra.append(float(np.mean(sims)))
-        intra_score = float(np.mean(intra)) if intra else 0.0
-
-        # Inter separation: penalize adjacent chunk similarity
-        if len(chunks) > 1:
-            reps = self.model.encode(
-                [c.text for c in chunks], normalize_embeddings=True
-            )
-            adj = []
-            for i in range(len(chunks) - 1):
-                adj.append(float(util.cos_sim(reps[i], reps[i + 1]).item()))
-            inter_penalty = float(np.mean(adj))
-        else:
-            inter_penalty = 0.0
-
-        # Length penalty
-        dev = [abs(c.tokens - self.target) / max(1, self.target) for c in chunks]
-        len_penalty = float(np.mean(dev))
-
-        # Topic purity
-        if kg_present:
-            pur = [c.topic_purity for c in chunks]
-            purity = float(np.mean(pur))
-        else:
-            purity = 0.0
-
-        # Final score
-        return (
-            intra_score
-            - 0.6 * inter_penalty
-            - 0.4 * len_penalty
-            + self.kg_weight * purity
-        )
-
-
-def auto_tune_params(
-    raw_texts: List[str],
-    kg_spec: Optional[Dict[str, Any]] = None,
-    n_trials: int = 20,
-    seed: int = 42,
-) -> ChunkerConfig:
-    """Bayesian-like search with Optuna to find a good config on your corpus."""
-
-    def objective(trial: optuna.Trial):
-        cfg = ChunkerConfig(
-            target_tokens=trial.suggest_int("target_tokens", 30, 400, step=10),
-            max_tokens=trial.suggest_int("max_tokens", 30, 520, step=10),
-            overlap_ratio=trial.suggest_float("overlap_ratio", 0.05, 0.25, step=0.05),
-            boundary_threshold=trial.suggest_float(
-                "boundary_threshold", 0.45, 0.75, step=0.05
-            ),
-            min_sent_per_chunk=trial.suggest_int("min_sent_per_chunk", 2, 4),
-            max_sent_per_chunk=trial.suggest_int("max_sent_per_chunk", 8, 16),
-            enable_adaptive_boundary=True,
-            enable_kg=(kg_spec is not None),
-            topic_purity_floor=trial.suggest_float(
-                "topic_purity_floor", 0.55, 0.8, step=0.05
-            ),
-        )
-        chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
-        evaluator = UnsupervisedEvaluator(
-            chunker.model, cfg.target_tokens, kg_weight=0.5 if kg_spec else 0.0
-        )
-
-        # Evaluate across a small sample
-        scores = []
-        for t in raw_texts:
-            chunks = chunker.chunk(t)
-            s = evaluator.score(chunks, kg_present=(kg_spec is not None))
-            scores.append(s)
-        return float(np.mean(scores))
-
-    sampler = optuna.samplers.TPESampler(seed=seed)
-    study = optuna.create_study(direction="maximize", sampler=sampler)
-    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
-    best_params = study.best_params
-
-    return ChunkerConfig(
-        target_tokens=best_params["target_tokens"],
-        max_tokens=best_params["max_tokens"],
-        overlap_ratio=best_params["overlap_ratio"],
-        boundary_threshold=best_params["boundary_threshold"],
-        min_sent_per_chunk=best_params["min_sent_per_chunk"],
-        max_sent_per_chunk=best_params["max_sent_per_chunk"],
-        enable_adaptive_boundary=True,
-        enable_kg=(kg_spec is not None),
-        topic_purity_floor=best_params["topic_purity_floor"],
-    )
-
-
-# ---------- Example usage ----------
-if __name__ == "__main__":
-    sample_text = """
-    RAG(Retrieval-Augmented Generation)是一种增强生成的技术。
-    在复杂知识问答中,RAG 通过检索相关文档片段来改善答案质量。
-    然而,分块策略会显著影响检索召回与可引用性。
-    因此,我们提出一种主题感知的分块方法,结合 Transformer 边界探测与知识图谱层次分类。
-    然后,我们讲一个新的主题,篮球
-    这个也就是罚球动作。一般原地动作分为两种。
-    第一种原地投篮动作是先下蹲,做好投篮的发力前上举动作,然后竖直向上伸直身体,右臂顺势在身体向上的过程中竖直向上将球向上投出。这种原地投篮的好处是,发力轻松,可以借助身体向上竖直的这个力度的趋势,帮助投篮发力,会让投篮的力气减少很多。尤其是在比赛后半程体力不好的时候,依然可以做到很高的命中略。这种投篮的要领是:主动的竖直向上的意识。我们以前就经常强调竖直起跳和竖直的概念,但是,同样看起来是竖直,但是用出来的效果却很不同,这主要就是技巧的关系了。这个技巧的精髓就在于“主动意识”。在你练习这种投篮的时候,每一次,都要在下蹲以后,明确的在脑子里想着,要竖直向上发力。双腿要竖直向上用力,整个身体也是这样,而且,最为重要的是,你一定要在练习的时候每次都要主动的去想,然后刻意的去竖直向上。这样,长久下去,养成习惯,你的这种投篮才会稳定。这里我们要顺便强调之前的一篇文章,就是录像纠错法,我们这里之所以一再强调要主动意识的竖直上起,就是因为,在录像上,未必能看得出来这个问题。也就是说,你的录像虽然看起来你是竖直起跳的,但是你没有一个主动的也就是刻意的竖直起跳的意识的话,这个球也不是竖直起跳。另外,相反的,如果你在视频上看到自己不是竖直起跳,但是实际上这个球是你使用了竖直起跳的主动意识来发力的。那么,尽管看起来不是很竖直,却依然可以很稳定。也就是说,眼睛会欺骗你,一定要注重你的意识。
-    """
-    kg_spec = {
-        "root": {
-            "name": "root",
-            "children": [
-                {
-                    "name": "Computer Science",
-                    "children": [
-                        {"name": "NLP", "children": [{"name": "RAG", "children": []}]}
-                    ],
-                },
-                {"name": "Finance", "children": [{"name": "AP/AR", "children": []}]},
-                {
-                    "name": "体育",
-                    "children": [
-                        {"name": "篮球", "children": [{"name": "投篮", "children": []}]}
-                    ],
-                },
-            ],
-        }
-    }
-    cfg = auto_tune_params([sample_text], kg_spec=kg_spec, n_trials=10, seed=42)
-    chunker = TopicAwareChunker(cfg, kg_spec=kg_spec)
-    chunks = chunker.chunk(sample_text)
-    for i, ch in enumerate(chunks, 1):
-        print(f"\n== Chunk {i} ==")
-        print("Tokens:", ch.tokens)
-        print("Topics:", " / ".join(ch.topics), "Purity:", round(ch.topic_purity, 3))
-        print(ch.text)

+ 22 - 2
applications/config/__init__.py

@@ -1,3 +1,23 @@
-from .model_config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, VLLM_SERVER_URL, DEV_VLLM_SERVER_URL
+from .model_config import (
+    DEFAULT_MODEL,
+    LOCAL_MODEL_CONFIG,
+    VLLM_SERVER_URL,
+    DEV_VLLM_SERVER_URL,
+)
+from .deepseek_config import DEEPSEEK_MODEL, DEEPSEEK_API_KEY
+from .base_chunk import Chunk, ChunkerConfig
+from .milvus_config import MILVUS_CONFIG
+from .mysql_config import RAG_MYSQL_CONFIG
 
-__all__ = ["DEFAULT_MODEL", "LOCAL_MODEL_CONFIG", "VLLM_SERVER_URL", "DEV_VLLM_SERVER_URL"]
+__all__ = [
+    "DEFAULT_MODEL",
+    "LOCAL_MODEL_CONFIG",
+    "VLLM_SERVER_URL",
+    "DEV_VLLM_SERVER_URL",
+    "DEEPSEEK_MODEL",
+    "DEEPSEEK_API_KEY",
+    "Chunk",
+    "ChunkerConfig",
+    "MILVUS_CONFIG",
+    "RAG_MYSQL_CONFIG"
+]

+ 31 - 0
applications/config/base_chunk.py

@@ -0,0 +1,31 @@
+from typing import List, Dict, Any
+from dataclasses import dataclass, field, asdict
+
+@dataclass
+class Chunk:
+    chunk_id: int
+    doc_id: str
+    text: str
+    tokens: int
+    topic: str = ""
+    domain: str = ""
+    task_type: str = ""
+    topic_purity: float = 1.0
+    text_type: int = 1
+    summary: str = ""
+    keywords: List[str] = field(default_factory=list)
+    concepts: List[str] = field(default_factory=list)
+    questions: List[str] = field(default_factory=list)
+    entities: List[str] = field(default_factory=list)
+
+
+@dataclass
+class ChunkerConfig:
+    target_tokens: int = 256
+    boundary_threshold: float = 0.8
+    min_sent_per_chunk: int = 3
+    max_sent_per_chunk: int = 10
+    enable_adaptive_boundary: bool = True
+    enable_kg: bool = True
+    topic_purity_floor: float = 0.8
+    kg_topk: int = 3

+ 7 - 0
applications/config/deepseek_config.py

@@ -0,0 +1,7 @@
+# deepseek official api
+DEEPSEEK_API_KEY = "sk-cfd2df92c8864ab999d66a615ee812c5"
+
+DEEPSEEK_MODEL = {
+    "DeepSeek-R1": "deepseek-reasoner",
+    "DeepSeek-V3": "deepseek-chat",
+}

+ 8 - 0
applications/config/milvus_config.py

@@ -0,0 +1,8 @@
+
+MILVUS_CONFIG = {
+    # "host": "c-981be0ee7225467b-internal.milvus.aliyuncs.com", # 内网
+    "host": "c-981be0ee7225467b.milvus.aliyuncs.com", # 公网
+    "user": "root",
+    "password": "Piaoquan@2025",
+    "port": "19530"
+}

+ 10 - 0
applications/config/mysql_config.py

@@ -0,0 +1,10 @@
+RAG_MYSQL_CONFIG = {
+    "host": "rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com",
+    "user": "wqsd",
+    "password": "wqsd@2025",
+    "port": 3306,
+    "db": "rag",
+    "charset": "utf8mb4",
+    "minsize": 5,
+    "maxsize": 20,
+}

+ 0 - 6
applications/utils/__init__.py

@@ -1,6 +0,0 @@
-from .http import AsyncHttpClient
-from .nlp import SplitTextIntoSentences
-from .nlp import detect_language
-from .nlp import num_tokens
-
-__all__ = ["AsyncHttpClient", "SplitTextIntoSentences", "detect_language", "num_tokens"]

+ 3 - 0
applications/utils/async_utils/__init__.py

@@ -0,0 +1,3 @@
+from .group_task import run_tasks_with_asyncio_task_group
+
+__all__ = ["run_tasks_with_asyncio_task_group"]

+ 51 - 0
applications/utils/async_utils/group_task.py

@@ -0,0 +1,51 @@
+import asyncio
+from typing import Callable, Coroutine, List, Any, Dict
+
+from tqdm.asyncio import tqdm
+
+
+# 使用asyncio.TaskGroup 来高效处理I/O密集型任务
+async def run_tasks_with_asyncio_task_group(
+    task_list: List[Any],
+    handler: Callable[[Any], Coroutine[Any, Any, None]],
+    *,
+    description: str = None,  # 任务介绍
+    unit: str,
+    max_concurrency: int = 20,  # 最大并发数
+    fail_fast: bool = False,  # 是否遇到错误就退出整个tasks
+) -> Dict[str, Any]:
+    """using asyncio.TaskGroup to process I/O-intensive tasks"""
+    if not task_list:
+        return {"total_task": 0, "processed_task": 0, "errors": []}
+
+    processed_task = 0
+    total_task = len(task_list)
+    errors: List[tuple[int, Any, Exception]] = []
+    semaphore = asyncio.Semaphore(max_concurrency)
+    processing_bar = tqdm(total=total_task, unit=unit, desc=description)
+
+    async def _run_single_task(task_obj: Any, idx: int):
+        nonlocal processed_task
+        async with semaphore:
+            try:
+                await handler(task_obj)
+                processed_task += 1
+            except Exception as e:
+                if fail_fast:
+                    raise e
+                errors.append((idx, task_obj, e))
+            finally:
+                processing_bar.update()
+
+    async with asyncio.TaskGroup() as task_group:
+        for index, task in enumerate(task_list, start=1):
+            task_group.create_task(
+                _run_single_task(task, index), name=f"processing {description}-{index}"
+            )
+
+    processing_bar.close()
+    return {
+        "total_task": total_task,
+        "processed_task": processed_task,
+        "errors": errors,
+    }

+ 7 - 0
applications/utils/chunks/__init__.py

@@ -0,0 +1,7 @@
+from .topic_aware_chunking import TopicAwareChunker
+from .llm_classifier import LLMClassifier
+
+__all__ = [
+    "TopicAwareChunker",
+    "LLMClassifier",
+]

+ 2 - 2
applications/chunks/kg_classifier.py → applications/utils/chunks/kg_classifier.py

@@ -43,7 +43,7 @@ class KGClassifier:
         """
         调用 HTTP embedding 服务,返回向量
         """
-        embedding = await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)
+        embedding = await get_basic_embedding(text=text, model=DEFAULT_MODEL)
         return np.array(embedding, dtype=np.float32)
 
     async def classify(
@@ -83,4 +83,4 @@ class KGClassifier:
             node = best[0]
 
         purity = float(np.mean(purities)) if purities else 1.0
-        return path, purity
+        return path, purity

+ 59 - 0
applications/utils/chunks/llm_classifier.py

@@ -0,0 +1,59 @@
+from typing import List
+
+from applications.config import Chunk
+from applications.api import fetch_deepseek_completion
+
+
+class LLMClassifier:
+    @staticmethod
+    def generate_prompt(chunk_text: str) -> str:
+        raw_prompt = """
+你是一个文本分析助手。  
+我会给你一段文本,请你输出以下信息:  
+1. **主题标签 (topic)**:一句话概括文本主题  
+2. **关键词 (keywords)**:3-5 个,便于检索  
+3. **摘要 (summary)**:50字以内简要说明  
+4. **领域 (domain)**:该文本所属领域(如:AI 技术、体育、金融)
+5. **任务类型 (task_type)**:文本主要任务类型(如:解释、教学、动作描述、方法提出)  
+6. **核心知识点 (concepts)**:涉及的核心知识点或概念  
+7. **显示/隐式问题 (questions)**:文本中隐含或显式的问题
+8. **实体(entities)**: 文本中的提到的命名实体
+
+请用 JSON 格式输出,例如:
+{
+    "topic": "RAG 技术与分块策略",
+    "summary": "介绍RAG技术并提出主题感知的分块方法。", 
+    "domain": "AI 技术",
+    "task_type": "方法提出",
+    "keywords": ["RAG", "检索增强", "文本分块", "知识图谱"],
+    "concepts": ["RAG", "文本分块", "知识图谱"],
+    "questions": ["如何提升RAG的检索效果?"]
+    "entities": ["entity1"]
+}
+
+下面是文本:
+        """
+        return raw_prompt.strip() + chunk_text
+
+    async def classify_chunk(self, chunk: Chunk) -> Chunk:
+        text = chunk.text.strip()
+        prompt = self.generate_prompt(text)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt, output_type="json"
+        )
+        print(response)
+        return Chunk(
+            chunk_id=chunk.chunk_id,
+            doc_id=chunk.doc_id,
+            text=text,
+            tokens=chunk.tokens,
+            topic_purity=chunk.topic_purity,
+            summary=response.get("summary"),
+            topic=response.get("topic"),
+            domain=response.get("domain"),
+            task_type=response.get("task_type"),
+            concepts=response.get("concepts", []),
+            keywords=response.get("keywords", []),
+            questions=response.get("questions", []),
+            entities=response.get("entities", []),
+        )

+ 197 - 0
applications/utils/chunks/topic_aware_chunking.py

@@ -0,0 +1,197 @@
+"""
+主题感知分块
+"""
+
+from __future__ import annotations
+
+import re
+from typing import List
+
+import numpy as np
+from sklearn.preprocessing import minmax_scale
+
+from applications.api import get_basic_embedding
+from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
+from applications.utils.nlp import SplitTextIntoSentences, num_tokens
+
+# from .llm_classifier import LLMClassifier
+
+
+# sentence boundary strategy
+class BoundaryDetector:
+    def __init__(self, cfg: ChunkerConfig, debug: bool = False):
+        self.cfg = cfg
+        self.debug = debug
+        # 信号增强因子
+        self.signal_boost_turn = 0.20
+        self.signal_boost_fig = 0.20
+        self.min_gap = 1
+
+    @staticmethod
+    def cosine_sim(u: np.ndarray, v: np.ndarray) -> float:
+        """计算余弦相似度"""
+        return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))
+
+    def detect_boundaries(
+        self, sentence_list: List[str], embs: np.ndarray
+    ) -> List[int]:
+        # 1. 相邻句子相似度
+        sims = np.array(
+            [self.cosine_sim(embs[i], embs[i + 1]) for i in range(len(embs) - 1)]
+        )
+        cut_scores = 1 - sims
+
+        # 2. 归一化 cut_scores 到 [0,1]
+        cut_scores = minmax_scale(cut_scores) if len(cut_scores) > 0 else []
+
+        boundaries = []
+        last_boundary = -999
+        for index, base_score in enumerate(cut_scores):
+            sent_to_check = (
+                sentence_list[index]
+                if index < len(sentence_list)
+                else sentence_list[-1]
+            )
+            snippet = sent_to_check[-20:] if sent_to_check else ""
+
+            turn = (
+                self.signal_boost_turn
+                if re.search(
+                    r"(因此|但是|综上|然而|另一方面|In conclusion|However|Therefore)",
+                    snippet,
+                )
+                else 0.0
+            )
+            fig = (
+                self.signal_boost_fig
+                if re.search(
+                    r"(见下图|如表|表\s*\d+|图\s*\d+|Figure|Table)", sent_to_check
+                )
+                else 0.0
+            )
+
+            adj_score = base_score + turn + fig
+
+            if adj_score >= self.cfg.boundary_threshold and (
+                index - last_boundary >= self.min_gap
+            ):
+                boundaries.append(index)
+                last_boundary = index
+
+            # Debug 输出
+            if self.debug:
+                print(
+                    f"[{index}] sim={sims[index]:.3f}, cut={base_score:.3f}, adj={adj_score:.3f}, boundary={index in boundaries}"
+                )
+
+        return boundaries
+
+
+class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
+
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
+    def __init__(self, cfg: ChunkerConfig, doc_id: str):
+        super().__init__(cfg)
+        # self.classifier = LLMClassifier()
+        self.doc_id = doc_id
+
+    @staticmethod
+    async def _encode_batch(texts: List[str]) -> np.ndarray:
+        embs = []
+        for t in texts:
+            e = await get_basic_embedding(t, model=DEFAULT_MODEL)
+            embs.append(np.array(e, dtype=np.float32))
+        return np.stack(embs)
+
+    def _pack_by_boundaries(
+        self, sentence_list: List[str], boundaries: List[int], text_type: int
+    ) -> List[Chunk]:
+        boundary_set = set(boundaries)
+        chunks: List[Chunk] = []
+        start = 0
+        n = len(sentence_list)
+        chunk_id = 0
+        while start < n:
+            end = start
+            sent_count = 0
+            while end < n and sent_count < self.cfg.max_sent_per_chunk:
+                cur_tokens = num_tokens(" ".join(sentence_list[start : end + 1]))
+                sent_count += 1
+                if cur_tokens >= self.cfg.target_tokens:
+                    cut = end
+                    for b in range(end, start - 1, -1):
+                        if b in boundary_set:
+                            cut = b
+                            break
+                    if cut - start + 1 >= self.cfg.min_sent_per_chunk:
+                        end = cut
+                    break
+                end += 1
+
+            text = " ".join(sentence_list[start : end + 1]).strip()
+            tokens = num_tokens(text)
+            chunk_id += 1
+            chunk = Chunk(
+                doc_id=self.doc_id, chunk_id=chunk_id, text=text, tokens=tokens, text_type=text_type
+            )
+            chunks.append(chunk)
+            start = end + 1
+        return chunks
+
+    async def _refine_chunk_by_topic(self, chunk: Chunk) -> List[Chunk]:
+        sentence_list = self.jieba_sent_tokenize(chunk.text)
+        if len(sentence_list) <= self.cfg.min_sent_per_chunk * 2:
+            return [chunk]
+
+        embs = await self._encode_batch(sentence_list)
+        orig = self.cfg.boundary_threshold
+        try:
+            self.cfg.boundary_threshold = max(0.3, orig - 0.1)
+            boundaries = self.detect_boundaries(sentence_list, embs)
+            sub_chunks = self._pack_by_boundaries(sentence_list, boundaries)
+
+            final = []
+            for ch in sub_chunks:
+                topics, purity = await self.kg.classify(ch.text, topk=self.cfg.kg_topk)
+                ch.topics, ch.topic_purity = topics, purity
+                final.append(ch)
+            return final
+        finally:
+            self.cfg.boundary_threshold = orig
+
+    async def chunk(self, text: str, text_type: int) -> List[Chunk]:
+        sentence_list = self.jieba_sent_tokenize(text)
+        if not sentence_list:
+            return []
+
+        sentences_embeddings = await self._encode_batch(sentence_list)
+        boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
+        raw_chunks = self._pack_by_boundaries(sentence_list, boundaries, text_type)
+        return raw_chunks
+
+
+# async def main():
+#     cfg = ChunkerConfig()
+#     sample_text = """
+#         RAG(Retrieval-Augmented Generation)是一种增强生成的技术。
+#         在复杂知识问答中,RAG 通过检索相关文档片段来改善答案质量。
+#         然而,分块策略会显著影响检索召回与可引用性。
+#         因此,我们提出一种主题感知的分块方法,结合 Transformer 边界探测与知识图谱层次分类。
+#         然后,我们讲一个新的主题,篮球
+#         这个也就是罚球动作。一般原地动作分为两种。
+#         第一种原地投篮动作是先下蹲,做好投篮的发力前上举动作,然后竖直向上伸直身体,右臂顺势在身体向上的过程中竖直向上将球向上投出。这种原地投篮的好处是,发力轻松,可以借助身体向上竖直的这个力度的趋势,帮助投篮发力,会让投篮的力气减少很多。尤其是在比赛后半程体力不好的时候,依然可以做到很高的命中略。这种投篮的要领是:主动的竖直向上的意识。我们以前就经常强调竖直起跳和竖直的概念,但是,同样看起来是竖直,但是用出来的效果却很不同,这主要就是技巧的关系了。这个技巧的精髓就在于“主动意识”。在你练习这种投篮的时候,每一次,都要在下蹲以后,明确的在脑子里想着,要竖直向上发力。双腿要竖直向上用力,整个身体也是这样,而且,最为重要的是,你一定要在练习的时候每次都要主动的去想,然后刻意的去竖直向上。这样,长久下去,养成习惯,你的这种投篮才会稳定。这里我们要顺便强调之前的一篇文章,就是录像纠错法,我们这里之所以一再强调要主动意识的竖直上起,就是因为,在录像上,未必能看得出来这个问题。也就是说,你的录像虽然看起来你是竖直起跳的,但是你没有一个主动的也就是刻意的竖直起跳的意识的话,这个球也不是竖直起跳。另外,相反的,如果你在视频上看到自己不是竖直起跳,但是实际上这个球是你使用了竖直起跳的主动意识来发力的。那么,尽管看起来不是很竖直,却依然可以很稳定。也就是说,眼睛会欺骗你,一定要注重你的意识。
+#     """
+#     chunker = TopicAwareChunker(cfg)
+#     chunks = await chunker.chunk(sample_text)
+#
+#     for c in chunks:
+#         print(f"[{c.tokens} tokens] {c.topic} purity={c.topic_purity:.2f}")
+#         print(c.text)
+#
+#
+# if __name__ == "__main__":
+#     asyncio.run(main())

+ 5 - 0
applications/utils/milvus/__init__.py

@@ -0,0 +1,5 @@
+from .collection import milvus_collection
+from .functions import async_insert_chunk, async_search_chunk
+
+
+__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk"]

+ 28 - 0
applications/utils/milvus/collection.py

@@ -0,0 +1,28 @@
+from pymilvus import connections, CollectionSchema, Collection
+from applications.utils.milvus.field import fields
+from applications.config import MILVUS_CONFIG
+
+
+connections.connect("default", **MILVUS_CONFIG)
+
+schema = CollectionSchema(
+    fields, description="Chunk multi-vector embeddings with metadata"
+)
+milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
+
+# create index
+vector_index_params = {
+    "index_type": "IVF_FLAT",
+    "metric_type": "COSINE",
+    "params": {
+        "M": 16, "efConstruction": 200
+    }
+}
+
+milvus_collection.create_index("vector_text", vector_index_params)
+milvus_collection.create_index("vector_summary", vector_index_params)
+milvus_collection.create_index("vector_questions", vector_index_params)
+
+milvus_collection.load()
+
+__all__ = ["milvus_collection"]

+ 70 - 0
applications/utils/milvus/field.py

@@ -0,0 +1,70 @@
+from pymilvus import FieldSchema, DataType
+
+# milvus 向量数据库
+fields = [
+    FieldSchema(
+        name="id",
+        dtype=DataType.INT64,
+        is_primary=True,
+        auto_id=True,
+        description="自增id",
+    ),
+    FieldSchema(
+        name="doc_id", dtype=DataType.VARCHAR, max_length=64, description="文档id"
+    ),
+    FieldSchema(name="chunk_id", dtype=DataType.INT64, description="文档分块id"),
+    # 三种向量字段
+    FieldSchema(name="vector_text", dtype=DataType.FLOAT_VECTOR, dim=2560),
+    FieldSchema(name="vector_summary", dtype=DataType.FLOAT_VECTOR, dim=2560),
+    FieldSchema(name="vector_questions", dtype=DataType.FLOAT_VECTOR, dim=2560),
+    # metadata
+    FieldSchema(
+        name="topic", dtype=DataType.VARCHAR, max_length=255, description="主题"
+    ),
+    FieldSchema(
+        name="domain", dtype=DataType.VARCHAR, max_length=100, description="领域"
+    ),
+    FieldSchema(
+        name="task_type", dtype=DataType.VARCHAR, max_length=100, description="任务类型"
+    ),
+    FieldSchema(
+        name="summary", dtype=DataType.VARCHAR, max_length=512, description="总结"
+    ),
+    FieldSchema(
+        name="keywords",
+        dtype=DataType.ARRAY,
+        element_type=DataType.VARCHAR,
+        max_length=100,
+        max_capacity=5,
+        description="关键词",
+    ),
+    FieldSchema(
+        name="concepts",
+        dtype=DataType.ARRAY,
+        element_type=DataType.VARCHAR,
+        max_length=100,
+        max_capacity=5,
+        description="主要知识点",
+    ),
+    FieldSchema(
+        name="questions",
+        dtype=DataType.ARRAY,
+        element_type=DataType.VARCHAR,
+        max_length=200,
+        max_capacity=5,
+        description="隐含问题",
+    ),
+FieldSchema(
+        name="entities",
+        dtype=DataType.ARRAY,
+        element_type=DataType.VARCHAR,
+        max_length=200,
+        max_capacity=5,
+        description="命名实体",
+    ),
+    FieldSchema(name="topic_purity", dtype=DataType.FLOAT),
+    FieldSchema(name="tokens", dtype=DataType.INT64),
+]
+
+
+__all__ = ["fields"]

+ 34 - 0
applications/utils/milvus/functions.py

@@ -0,0 +1,34 @@
+import asyncio
+from typing import Dict
+
+import pymilvus
+
+
+async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
+    """
+    :param collection:
+    :param data: insert data
+    :return:
+    """
+    res = await asyncio.to_thread(collection.insert, [data])
+    print(res)
+
+
+async def async_search_chunk(
+    collection: pymilvus.Collection, query_vector, params: Dict
+):
+    """
+    :param query_vector: query 向量
+    :param collection:
+    :param params: search 参数
+    :return:
+    """
+    expr = None
+    return await asyncio.to_thread(
+        collection.search,
+        data=[query_vector],
+        param={"metric_type": "COSINE", "params": {"nprobe": 10}},
+        limit=params["limit"],
+        anns_field="vector_text",
+        expr=expr,
+    )

+ 11 - 0
applications/utils/mysql/__init__.py

@@ -0,0 +1,11 @@
+from .pool import DatabaseManager
+from .mapper import Contents, ContentChunks
+
+# 全局数据库管理器实例
+mysql_manager = DatabaseManager()
+
+__all__ = [
+    "mysql_manager",
+    "Contents",
+    "ContentChunks",
+]

+ 102 - 0
applications/utils/mysql/mapper.py

@@ -0,0 +1,102 @@
+import json
+from applications.config import Chunk
+
+
+class TaskConst:
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
+
+class BaseMySQLClient:
+
+    def __init__(self, pool):
+        self.pool = pool
+
+
+class Contents(BaseMySQLClient):
+
+    async def insert_content(self, doc_id, text, text_type):
+        query = """
+            INSERT IGNORE INTO contents
+                (doc_id, text, text_type)
+            VALUES (%s, %s, %s);
+        """
+        return await self.pool.async_save(query=query, params=(doc_id, text, text_type))
+
+    async def update_content_status(self, doc_id, ori_status, new_status):
+        query = """
+            UPDATE contents
+            SET status = %s
+            WHERE doc_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, ori_status)
+        )
+
+
+class ContentChunks(BaseMySQLClient):
+
+    async def insert_chunk(self, chunk: Chunk) -> int:
+        query = """
+            INSERT IGNORE INTO content_chunks
+                (chunk_id, doc_id, text, tokens, topic_purity, text_type) 
+                VALUES (%s, %s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.chunk_id,
+                chunk.doc_id,
+                chunk.text,
+                chunk.tokens,
+                chunk.topic_purity,
+                chunk.text_type
+            ),
+        )
+
+    async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET chunk_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET embedding_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, 
+                keywords = %s, questions = %s, chunk_status = %s, entities = %s
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.summary,
+                chunk.topic,
+                chunk.domain,
+                chunk.task_type,
+                json.dumps(chunk.concepts),
+                json.dumps(chunk.keywords),
+                json.dumps(chunk.questions),
+                new_status,
+                json.dumps(chunk.entities),
+                chunk.doc_id,
+                chunk.chunk_id,
+                ori_status
+            ),
+        )

+ 80 - 0
applications/utils/mysql/pool.py

@@ -0,0 +1,80 @@
+from aiomysql import create_pool
+from aiomysql.cursors import DictCursor
+from applications.config import RAG_MYSQL_CONFIG
+
+
+class DatabaseManager:
+    def __init__(self):
+        self.databases = None
+        self.pools = {}
+
+    async def init_pools(self):
+        # 从配置获取数据库配置,也可以直接在这里配置
+        self.databases = {"rag": RAG_MYSQL_CONFIG}
+
+        for db_name, config in self.databases.items():
+            try:
+                pool = await create_pool(
+                    host=config["host"],
+                    port=config["port"],
+                    user=config["user"],
+                    password=config["password"],
+                    db=config["db"],
+                    minsize=config["minsize"],
+                    maxsize=config["maxsize"],
+                    cursorclass=DictCursor,
+                    autocommit=True,
+                )
+                self.pools[db_name] = pool
+                print(f"Created connection pool for {db_name}")
+            except Exception as e:
+                print(f"Failed to create pool for {db_name}: {str(e)}")
+                self.pools[db_name] = None
+
+    async def close_pools(self):
+        for name, pool in self.pools.items():
+            if pool:
+                pool.close()
+                await pool.wait_closed()
+
+    async def async_fetch(
+        self, query, db_name="rag", params=None, cursor_type=DictCursor
+    ):
+        pool = self.pools[db_name]
+        if not pool:
+            await self.init_pools()
+        # fetch from db
+        try:
+            async with pool.acquire() as conn:
+                async with conn.cursor(cursor_type) as cursor:
+                    await cursor.execute(query, params)
+                    fetch_response = await cursor.fetchall()
+
+            return fetch_response
+        except Exception as e:
+            return None
+
+    async def async_save(self, query, params, db_name="rag", batch: bool = False):
+        pool = self.pools[db_name]
+        if not pool:
+            await self.init_pools()
+
+        async with pool.acquire() as connection:
+            async with connection.cursor() as cursor:
+                try:
+                    if batch:
+                        await cursor.executemany(query, params)
+                    else:
+                        await cursor.execute(query, params)
+                    affected_rows = cursor.rowcount
+                    await connection.commit()
+                    return affected_rows
+                except Exception as e:
+                    await connection.rollback()
+                    raise e
+
+    def get_pool(self, db_name):
+        return self.pools.get(db_name)
+
+    def list_databases(self):
+        return list(self.databases.keys())

+ 0 - 0
applications/vector_database/__init__.py


+ 0 - 17
applications/vector_database/field.py

@@ -1,17 +0,0 @@
-from applications.config import LOCAL_MODEL_CONFIG
-
-from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
-
-
-collections = {}
-for model_name, cfg in LOCAL_MODEL_CONFIG.items():
-    col_name = model_name.replace("/", "_").replace("-", "_").lower()
-    fields = [
-        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
-        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024),
-        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=cfg["dim"]),
-    ]
-    schema = CollectionSchema(fields, description=f"{model_name} embeddings")
-    collection = Collection(col_name, schema=schema)
-    collection.load()
-    collections[model_name] = collection

+ 21 - 7
requirements.txt

@@ -1,7 +1,21 @@
-hypercorn
-quart_cors
-quart
-aiohttp
-pymilvus
-sentence_transformers
-optuna
+aiodns==3.5.0
+aiomysql==0.2.0
+black==25.1.0
+bottleneck==1.4.2
+brotlicffi==1.0.9.2
+datasets==3.3.2
+gmpy2==2.2.1
+jieba==0.42.1
+langdetect==1.0.9
+nltk==3.9.1
+numexpr==2.11.0
+openai==1.107.1
+opentelemetry-api==1.30.0
+optuna==4.5.0
+pip-chill==1.0.3
+pymilvus==2.6.1
+pysocks==1.7.1
+quart-cors==0.8.0
+sentence-transformers==5.1.0
+tiktoken==0.11.0
+uvloop==0.21.0

+ 17 - 3
routes/buleprint.py

@@ -1,13 +1,15 @@
+import uuid
+
 from quart import Blueprint, jsonify, request
 
-from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG
+from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
 from applications.api import get_basic_embedding
-
+from applications.async_task import ChunkEmbeddingTask
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
 
-def server_routes(vector_db):
+def server_routes(mysql_db, vector_db):
 
     @server_bp.route("/embed", methods=["POST"])
     async def embed():
@@ -20,6 +22,18 @@ def server_routes(vector_db):
         embedding = await get_basic_embedding(text, model_name)
         return jsonify({"embedding": embedding})
 
+    @server_bp.route("/chunk", methods=["POST"])
+    async def chunk():
+        body = await request.get_json()
+        text = body.get("text", "")
+        text = text.strip()
+        if not text:
+            return jsonify({"error": "error  text"})
+        doc_id = f"doc-{uuid.uuid4()}"
+        chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
+        doc_id = await chunk_task.deal(body)
+        return jsonify({"doc_id": doc_id})
+
     @server_bp.route("/search", methods=["POST"])
     async def search():
         pass

+ 19 - 10
vector_app.py

@@ -1,21 +1,30 @@
+import jieba
 from quart import Quart
-from quart_cors import cors
-
-# from pymilvus import connections
 
 from applications.config import LOCAL_MODEL_CONFIG, DEFAULT_MODEL
+from applications.utils.milvus import milvus_collection
+from applications.utils.mysql import mysql_manager
 from routes import server_routes
 
 app = Quart(__name__)
 
 MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
 
-
-# 连接向量数据库
-# connections.connect("default", host="milvus", port="19530")
-# connections.connect("default", host="milvus", port="19530")
-connections = None
-
 # 注册路由
-app_route = server_routes(connections)
+app_route = server_routes(mysql_manager, milvus_collection)
 app.register_blueprint(app_route)
+
+@app.before_serving
+async def startup():
+    print("Starting application...")
+    await mysql_manager.init_pools()
+    print("Mysql pools init successfully")
+
+    print("Loading jieba dictionary...")
+    jieba.initialize()
+    print("Jieba dictionary loaded successfully")
+
+@app.after_serving
+async def shutdown():
+    print("Shutting down application...")
+    await mysql_manager.close_pools()