
v2 test version (to be optimized)

xueyiming 2 months ago
commit 4eda9f92a0

+ 51 - 42
api/search.py

@@ -1,15 +1,11 @@
-import json
-from cgitb import reset
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
 
-from fastapi import APIRouter, BackgroundTasks
+from fastapi import APIRouter
 
 from schemas import ResponseWrapper
-from schemas.schemas import Query, ContentData
-from tools_v1 import query_keyword_summary_results, query_keyword_content_results
-from utils.data_utils import add_data
-from utils.deepseek_utils import get_keywords
+from schemas.schemas import ContentParam, DatasetParam
+from service.content_service import get_contents, add_contents, get_content
+from service.dataset_service import get_datasets, add_datasets
 from utils.embedding_utils import get_embedding_content_data
 
 router = APIRouter()
@@ -18,15 +14,12 @@ router = APIRouter()
 executor = ThreadPoolExecutor(max_workers=10)
 
 
-@router.post("/query", response_model=ResponseWrapper)
-async def query_keyword(query: Query):
-    print(query.text)
-    keywords = get_keywords(query.text)['keywords']
-    print(keywords)
-    summary_res = query_keyword_summary_results(keywords)
-    content_res = query_keyword_content_results(keywords)
-    embedding_res = get_embedding_content_data(query.text)
-    res = {'summary_results': summary_res, 'content_results': content_res, 'embedding_results': embedding_res}
+@router.get("/query", response_model=ResponseWrapper)
+async def query_keyword(query: str, datasetIds: str):
+    print(query)
+    print(datasetIds)
+    embedding_res = get_embedding_content_data(query, datasetIds.split(','))
+    res = {'results': embedding_res}
     return ResponseWrapper(
         status_code=200,
         detail="success",
@@ -34,34 +27,50 @@ async def query_keyword(query: Query):
     )
 
 
-@router.post("/add/data", response_model=ResponseWrapper)
-async def query_keyword(content_list: List[ContentData]):
-    res_list = []
-    for content in content_list:
-        if content.body_text:
-            print(content.body_text)
-            res = add_data(content.body_text)
-            res_list.append(res)
+@router.get("/content/list", response_model=ResponseWrapper)
+async def content_list(page: int = 1, pageSize: int = 10, datasetId: int = None):
+    data = get_contents(page, pageSize, datasetId)
     return ResponseWrapper(
         status_code=200,
         detail="success",
-        data=res_list
+        data=data
     )
 
-# @router.post("/query/keyword/content", response_model=ResponseWrapper)
-# async def query_keyword(query: Query):
-#     res = query_keyword_content_results(query.text)
-#     return ResponseWrapper(
-#         status_code=200,
-#         detail="success",
-#         data=res
-#     )
 
-# @router.post("/query/embedding", response_model=ResponseWrapper)
-# async def query_keyword(query: Query):
-#     res = query_embedding_results(query.text)
-#     return ResponseWrapper(
-#         status_code=200,
-#         detail="success",
-#         data=res
-#     )
+@router.post("/content/add", response_model=ResponseWrapper)
+async def add_content(content_param: ContentParam):
+    res = add_contents(content_param)
+    return ResponseWrapper(
+        status_code=200,
+        detail="success",
+        data=res
+    )
+
+
+@router.get("/dataset/list", response_model=ResponseWrapper)
+async def dataset_list():
+    data = get_datasets()
+    return ResponseWrapper(
+        status_code=200,
+        detail="success",
+        data=data
+    )
+
+@router.post("/dataset/add", response_model=ResponseWrapper)
+async def add_dataset(dataset_param: DatasetParam):
+    res = add_datasets(dataset_param)
+    return ResponseWrapper(
+        status_code=200,
+        detail="success",
+        data=res
+    )
+
+
+@router.get("/content/get", response_model=ResponseWrapper)
+async def content_get(docId: str):
+    data = get_content(docId)
+    return ResponseWrapper(
+        status_code=200,
+        detail="success",
+        data=data
+    )
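
The routes above replace the old POST /query and /add/data keyword endpoints with a GET-based hybrid search plus simple dataset/content management. A minimal client sketch of the new surface, assuming the app listens on localhost:8000 and the router is mounted without a path prefix (both assumptions, not shown in this commit):

import requests

BASE = "http://127.0.0.1:8000"  # assumed host/port, not part of this commit

# Hybrid search inside datasets 1 and 2; datasetIds is a comma-separated string.
r = requests.get(f"{BASE}/query", params={"query": "basketball", "datasetIds": "1,2"})
print(r.json()["data"]["results"])

# Create a dataset, then push a document into it (datasetId=1 is a placeholder).
requests.post(f"{BASE}/dataset/add", json={"name": "demo"})
requests.post(f"{BASE}/content/add", json={"datasetId": 1, "title": "sample", "text": "body text ..."})

# Page through a dataset's documents and fetch a single one by doc_id.
page = requests.get(f"{BASE}/content/list", params={"page": 1, "pageSize": 10, "datasetId": 1}).json()
doc_id = page["data"]["entities"][0]["doc_id"]
doc = requests.get(f"{BASE}/content/get", params={"docId": doc_id}).json()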

+ 89 - 2
core/database.py

@@ -54,7 +54,6 @@ class DBHelper:
         self.session.rollback()
         logger.error(f"{operation}失败: {error}")
 
-
     def add(self, entity):
         """Insert an entity object."""
         try:
@@ -72,7 +71,7 @@ class DBHelper:
         except SQLAlchemyError as e:
             self._handle_error(e, "查询")
 
-    def get_all(self, model, limit=None, **filters):
+    def get_all(self, model, limit=None, order_by=None, **filters):
         """Fetch all entities matching the filters; supports more complex query conditions."""
         try:
             query = self.session.query(model)
@@ -92,6 +91,16 @@ class DBHelper:
             if actual_filters:
                 query = query.filter_by(**actual_filters)
 
+            # Apply ordering
+            if order_by:
+                # order_by is a dict of the form {'field_name': 'asc' or 'desc'}
+                for field_name, direction in order_by.items():
+                    field = getattr(model, field_name)
+                    if direction == 'desc':
+                        query = query.order_by(field.desc())
+                    else:
+                        query = query.order_by(field.asc())
+
             # If a limit was passed, cap the maximum number of rows returned
             if limit is not None:
                 query = query.limit(limit)
@@ -102,6 +111,84 @@ class DBHelper:
         except SQLAlchemyError as e:
             self._handle_error(e, "查询")
 
+    def get_paginated(self, model, page=1, page_size=10, order_by=None, **filters):
+        """Paginated query of entities matching the filters, with optional ordering."""
+        try:
+            query = self.session.query(model)
+
+            # Handle special conditions such as __in
+            actual_filters = {}
+            for key, value in filters.items():
+                if key.endswith('__in'):
+                    # Handle IN queries
+                    field_name = key[:-4]
+                    field = getattr(model, field_name)
+                    query = query.filter(field.in_(value))
+                else:
+                    actual_filters[key] = value
+
+            # Apply the remaining filter conditions
+            if actual_filters:
+                query = query.filter_by(**actual_filters)
+
+            # Apply ordering
+            if order_by:
+                # order_by is a dict of the form {'field_name': 'asc' or 'desc'}
+                for field_name, direction in order_by.items():
+                    field = getattr(model, field_name)
+                    if direction == 'desc':
+                        query = query.order_by(field.desc())
+                    else:
+                        query = query.order_by(field.asc())
+
+            # Count the total number of records
+            total_count = query.count()
+
+            # Compute the offset for pagination
+            offset = (page - 1) * page_size
+            query = query.offset(offset).limit(page_size)
+
+            # Execute the query
+            entities = query.all()
+
+            # Return the page of results together with the total count
+            return {
+                "entities": entities,
+                "total_count": total_count,
+                "page": page,
+                "page_size": page_size,
+                "total_pages": (total_count + page_size - 1) // page_size  # ceiling division for the total page count
+            }
+
+        except SQLAlchemyError as e:
+            self._handle_error(e, "查询")
+
+    def count(self, model, **filters):
+        """Count the records matching the filters."""
+        try:
+            query = self.session.query(model)
+
+            # Handle special conditions such as __in
+            actual_filters = {}
+            for key, value in filters.items():
+                    # Handle IN queries
+                    # 处理 IN 查询
+                    field_name = key[:-4]
+                    field = getattr(model, field_name)
+                    query = query.filter(field.in_(value))
+                else:
+                    actual_filters[key] = value
+
+            # Apply the remaining filter conditions
+            if actual_filters:
+                query = query.filter_by(**actual_filters)
+
+            # Execute the query and get the total count
+            count = query.count()
+            return count
+        except SQLAlchemyError as e:
+            self._handle_error(e, "查询条数")
+
     def update(self, model, filters, updates):
         """Update entity objects."""
         try:
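
get_paginated and count reuse the filter handling already in get_all (plain equality filters plus the __in suffix for IN queries) and add ordering, offset/limit paging, and a total count. A short usage sketch against the Contents model, assuming DBHelper() can be constructed with no arguments, as the new service code does:

from core.database import DBHelper
from data_models.contents import Contents

db = DBHelper()

# Second page of 10 rows, newest first, restricted to dataset 1 (example id).
page = db.get_paginated(Contents, page=2, page_size=10,
                        order_by={"id": "desc"}, dataset_id=1)
print(page["total_count"], page["total_pages"], len(page["entities"]))

# The __in suffix maps to an IN filter, same convention as in get_all.
recent = db.get_all(Contents, limit=5, order_by={"updated_at": "desc"},
                    dataset_id__in=[1, 2, 3])

# Bare row count for one dataset.
print(db.count(Contents, dataset_id=1))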

+ 0 - 5
data_handle.py

@@ -1,5 +0,0 @@
-from utils.keywords_utils import KeywordSummaryTask
-
-if __name__ == '__main__':
-    keyword_summary_task = KeywordSummaryTask()
-    keyword_summary_task.process_texts_concurrently()

+ 7 - 6
data_models/content_chunks.py

@@ -23,10 +23,11 @@ class ContentChunks(Base):
     questions = Column(Text)
     created_at = Column(TIMESTAMP)
     updated_at = Column(TIMESTAMP)
-    chunk_status = Column(Integer)
-    keywords_status = Column(Integer)
-    embedding_status = Column(Integer)
+    chunk_status = Column(Integer, default=0)
+    es_status = Column(Integer, default=0)
+    embedding_status = Column(Integer, default=0)
     entities = Column(Text)
-    version = Column(Integer)
-
-
+    version = Column(Integer, default=1)
+    text_type = Column(Integer, default=1)
+    dataset_id = Column(Integer)
+    status = Column(Integer, default=1)

+ 21 - 0
data_models/contents.py

@@ -0,0 +1,21 @@
+from sqlalchemy import Column, Text, BigInteger, TIMESTAMP, Integer, Float
+from sqlalchemy.dialects.mysql import VARCHAR, MEDIUMTEXT
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
+
+
+class Contents(Base):
+    __tablename__ = "contents"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    doc_id = Column(VARCHAR(64))
+    title = Column(VARCHAR(255))
+    text = Column(MEDIUMTEXT)
+    author = Column(VARCHAR(100))
+    created_at = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP")
+    updated_at = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP")
+    status = Column(Integer, default=0, nullable=False)
+    doc_status = Column(Integer)
+    text_type = Column(Integer, default=1, nullable=False)
+    dataset_id = Column(Integer)

+ 15 - 0
data_models/dataset.py

@@ -0,0 +1,15 @@
+from sqlalchemy import Column, Text, BigInteger, TIMESTAMP, Integer, Float
+from sqlalchemy.dialects.mysql import VARCHAR, MEDIUMTEXT
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
+
+
+class Dataset(Base):
+    __tablename__ = "dataset"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    name = Column(VARCHAR(64))
+    created_at = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP")
+    updated_at = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP")
+    status = Column(Integer, default=1, nullable=False)
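
Contents and Dataset each call declarative_base() in their own module, so the two tables share no ORM relationship; the dataset_id column is resolved with a second lookup in the service layer. A sketch of that hand join through DBHelper, using a hypothetical doc_id:

from core.database import DBHelper
from data_models.contents import Contents
from data_models.dataset import Dataset

db = DBHelper()

# Fetch a document, then resolve its dataset name with a second query,
# mirroring what embedding_utils.get_embedding_content_data does below.
doc = db.get(Contents, doc_id="some-doc-id")  # hypothetical doc_id
dataset = db.get(Dataset, id=doc.dataset_id) if doc else None
print(doc.title if doc else None, dataset.name if dataset else None)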

+ 0 - 15
data_models/keyword_clustering.py

@@ -1,15 +0,0 @@
-from sqlalchemy import Column, Text, BigInteger, TIMESTAMP, VARCHAR, ForeignKey
-from sqlalchemy.orm import declarative_base
-
-Base = declarative_base()
-
-
-class KeywordClustering(Base):
-    __tablename__ = "keyword_clustering"
-
-    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
-    keyword_id = Column(BigInteger, nullable=False, comment="关键词id")
-    keyword_summary = Column(Text, nullable=True, comment="关键词知识")
-    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
-                         comment="更新时间")

+ 0 - 12
data_models/keyword_data.py

@@ -1,12 +0,0 @@
-from sqlalchemy import Column, Text, BigInteger, TIMESTAMP, VARCHAR
-from sqlalchemy.orm import declarative_base
-
-Base = declarative_base()
-
-
-class KeywordData(Base):
-    __tablename__ = "keyword_data"
-
-    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
-    keyword = Column(VARCHAR(128), nullable=False, comment="关键词")
-    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")

+ 0 - 14
data_models/keyword_with_content_chunk.py

@@ -1,14 +0,0 @@
-from sqlalchemy import Column, Text, BigInteger, TIMESTAMP, VARCHAR, Integer
-from sqlalchemy.orm import declarative_base
-
-Base = declarative_base()
-
-
-class KeywordWithContentChunk(Base):
-    __tablename__ = "keyword_with_content_chunk"
-
-    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
-    keyword_id = Column(BigInteger, nullable=False, comment="关键词id")
-    content_chunk_id = Column(BigInteger, nullable=False, comment="内容id")
-    keyword_clustering_status = Column(Integer, nullable=False, default=0, comment="总结状态")
-    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")

+ 0 - 1
main.py

@@ -12,7 +12,6 @@ from core.config import logger
 # Import the API routers
 from api.search import router as search_router
 from api.health import router as health_router
-from utils.keywords_utils import KeywordSummaryTask
 
 # Create the FastAPI application
 app = FastAPI(

+ 8 - 2
schemas/schemas.py

@@ -13,5 +13,11 @@ class Query(BaseModel):
     text: str
 
 
-class ContentData(BaseModel):
-    body_text: str
+class ContentParam(BaseModel):
+    datasetId: int
+    title: str
+    text: str
+
+class DatasetParam(BaseModel):
+    name: str
+
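
ContentParam and DatasetParam are the request bodies for /content/add and /dataset/add; the field names are camelCase so they match the JSON payload directly, and the service layer maps datasetId to dataset_id when forwarding to the chunk API. A tiny sketch (model_dump assumes pydantic v2; use .dict() on v1):

from schemas.schemas import ContentParam, DatasetParam

content = ContentParam(datasetId=1, title="sample", text="body text ...")
dataset = DatasetParam(name="demo")
print(content.model_dump(), dataset.model_dump())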

+ 0 - 0
service/__init__.py


+ 45 - 0
service/content_service.py

@@ -0,0 +1,45 @@
+import json
+
+import requests
+
+from core.database import DBHelper
+from data_models.contents import Contents
+
+
+def get_contents(page_num, page_size, dataset_id):
+    db_helper = DBHelper()
+    res = db_helper.get_paginated(Contents, page_num, page_size, order_by={'id': 'desc'}, dataset_id=dataset_id)
+    data = []
+    for entity in res["entities"]:
+        data.append({'text': entity.text, 'title': entity.title, 'doc_id': entity.doc_id})
+    res['entities'] = data
+    return res
+
+
+def get_content(doc_id):
+    db_helper = DBHelper()
+    content = db_helper.get(Contents, doc_id=doc_id)
+    data = {'title': content.title, 'text': content.text, 'doc_id': content.doc_id}
+    return data
+
+
+def add_contents(content_param):
+    try:
+        response = requests.post(
+            url='http://61.48.133.26:8001/api/chunk',
+            json={
+                "text": content_param.text,
+                "title": content_param.title,
+                "dataset_id": content_param.datasetId},
+            headers={"Content-Type": "application/json"},
+        )
+        doc_id = response.json()['doc_id']
+        if doc_id:
+            return True
+    except Exception as e:
+        print(e)
+    return False
+
+
+if __name__ == '__main__':
+    print(json.dumps(get_contents(1, 10, dataset_id=1), ensure_ascii=False))  # dataset_id=1 is a placeholder

+ 23 - 0
service/dataset_service.py

@@ -0,0 +1,23 @@
+from datetime import datetime
+
+from core.database import DBHelper
+from data_models.contents import Contents
+from data_models.dataset import Dataset
+
+
+def get_datasets():
+    db_helper = DBHelper()
+    entities = db_helper.get_all(Dataset, status=1)
+    data = []
+    for entity in entities:
+        count = db_helper.count(Contents, dataset_id=entity.id)
+        data.append({'dataset_id': entity.id, 'name': entity.name, 'count': count,
+                     'created_at': entity.created_at.strftime('%Y-%m-%d')})
+    return data
+
+
+def add_datasets(dataset_param):
+    db_helper = DBHelper()
+    dataset = Dataset(name=dataset_param.name, created_at=datetime.now(), updated_at=datetime.now(), status=1)
+    db_helper.add(dataset)
+    return True
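
get_datasets returns one dict per active dataset and issues a separate Contents count per row, which is acceptable while the dataset list stays small. A usage sketch of the two service functions:

from schemas.schemas import DatasetParam
from service.dataset_service import add_datasets, get_datasets

add_datasets(DatasetParam(name="demo"))
for ds in get_datasets():
    # Each entry: {'dataset_id', 'name', 'count', 'created_at' (YYYY-MM-DD)}
    print(ds["dataset_id"], ds["name"], ds["count"], ds["created_at"])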

+ 51 - 3
utils/data_utils.py

@@ -1,24 +1,57 @@
 import json
 
 import requests
+from openpyxl.styles.builtins import title
 
 from core.config import logger
 from core.database_data import DatabaseHelper
 
 
-def add_data(text):
+def add_data(text, dataset_id, title=None):
     try:
         response = requests.post(
             url='http://61.48.133.26:8001/api/chunk',
             json={
                 "text": text,
-                "text_type": 1},
+                "dataset_id": dataset_id},
             headers={"Content-Type": "application/json"},
         )
         return response.json()['doc_id']
     except Exception as e:
         logger.error(e)
-        return e
+
+
+def is_empty(value):
+    """Helper: return True if the value is empty (None or an empty string)."""
+    return value is None or value == ""
+
+
+def parse_json(file_path):
+    text_list = []
+    try:
+        # Read the file contents
+        with open(file_path, 'r', encoding='utf-8') as file:
+            try:
+                # Parse the JSON content
+                json_data = json.load(file)
+                # Check that the top level is a JSON array
+                if isinstance(json_data, list):
+                    # Iterate over each JSON object
+                    for index, json_obj in enumerate(json_data, 1):
+                        body_text = json_obj.get("body_text", "")
+                        title = json_obj.get("title", "")
+                        if not is_empty(body_text):
+                            text_list.append({'body_text': body_text, 'title': title})
+                else:
+                    print("错误: 文件内容不是一个JSON数组")
+
+            except json.JSONDecodeError as e:
+                print(f"JSON解析错误: {e}")
+    except FileNotFoundError:
+        print(f"错误: 找不到文件 '{file_path}'")
+    except Exception as e:
+        print(f"发生错误: {e}")
+    return text_list
 
 
 def select_data():
@@ -38,3 +71,18 @@ def select_data():
     for row in result:
         add_data(json.loads(row['json_text'])['body_text'])
 
+
+if __name__ == '__main__':
+    json_path = '../data/test_data1.json'
+    text_list = parse_json(json_path)
+    re = []
+    for text in text_list:
+        res = add_data(text['body_text'])
+        if res is None:
+            re.append(text)
+    re1 = []
+    for text in re:
+        res = add_data(text['body_text'])
+        if res is None:
+            re1.append(text)
+    print(json.dumps(re1, ensure_ascii=False))

+ 16 - 11
utils/embedding_utils.py

@@ -4,36 +4,41 @@ import requests
 from core.config import logger
 from core.database import DBHelper
 from data_models.content_chunks import ContentChunks
+from data_models.dataset import Dataset
 
 
-def get_embedding_data(query):
+def get_embedding_data(query, dataset_ids, limit=10):
     try:
         response = requests.post(
             url='http://61.48.133.26:8001/api/search',
             json={
-                "query": query,
-                "search_type": "by_vector",
-                "limit": 5},
+                "query_text": query,
+                "search_type": "hybrid",
+                "filters": {
+                    "dataset_id": dataset_ids
+                },
+                "limit": limit},
             headers={"Content-Type": "application/json"},
         )
         return response.json()['results']
     except Exception as e:
         logger.error(e)
-    return []
 
 
-def get_embedding_content_data(query):
+def get_embedding_content_data(query, dataset_ids):
     res = []
     db_helper = DBHelper()
-    results = get_embedding_data(query)
+    results = get_embedding_data(query, dataset_ids)
     if results:
         for result in results:
             content_chunk = db_helper.get(ContentChunks, doc_id=result['doc_id'], chunk_id=result['chunk_id'])
+            dataset = db_helper.get(Dataset, id=content_chunk.dataset_id)
+            dataset_name = None
+            if dataset:
+                dataset_name = dataset.name
             res.append(
-                {'content': content_chunk.text, 'content_summary': content_chunk.summary, 'score': result['score']})
+                {'docId': content_chunk.doc_id, 'content': content_chunk.text,
+                 'contentSummary': content_chunk.summary, 'score': result['score'], 'datasetName': dataset_name})
     return res
 
 
-if __name__ == '__main__':
-    results = get_embedding_content_data("帮我查询一些篮球相关的知识")
-    print(results)

+ 0 - 105
utils/keywords_utils.py

@@ -1,105 +0,0 @@
-import concurrent
-import json
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from time import sleep
-from venv import logger
-
-from core.database import DBHelper
-from data_models.content_chunks import ContentChunks
-from data_models.keyword_clustering import KeywordClustering
-from data_models.keyword_data import KeywordData
-from data_models.keyword_with_content_chunk import KeywordWithContentChunk
-from utils.deepseek_utils import get_keyword_summary, update_keyword_summary
-
-
-class KeywordSummaryTask:
-    lock_dict = {}  # class-level attribute, so it is not reset on every instantiation
-
-    def __init__(self):
-        self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='KeywordSummaryTask')
-
-    def get_lock_for_keyword(self, keyword_id):
-        if keyword_id not in self.lock_dict:
-            self.lock_dict[keyword_id] = threading.Lock()
-        return self.lock_dict[keyword_id]
-
-    def _generate_keywords(self, content_chunk):
-        db_helper = DBHelper()
-        keywords = json.loads(content_chunk.keywords)
-        for keyword in keywords:
-            keyword_data = db_helper.get(KeywordData, keyword=keyword)
-            if keyword_data is None:
-                try:
-                    new_keyword_data = KeywordData(keyword=keyword)
-                    keyword_data = db_helper.add(new_keyword_data)
-                except Exception as e:
-                    return
-            keyword_with_content_chunk = db_helper.get(KeywordWithContentChunk, keyword_id=keyword_data.id,
-                                                       content_chunk_id=content_chunk.id)
-            if keyword_with_content_chunk is None:
-                try:
-                    keyword_with_content_chunk = KeywordWithContentChunk(keyword_id=keyword_data.id,
-                                                                         content_chunk_id=content_chunk.id)
-                    db_helper.add(keyword_with_content_chunk)
-                except Exception as e:
-                    return
-            # Acquire the lock for this keyword_id
-            lock = self.get_lock_for_keyword(keyword_data.id)
-            with lock:
-                if keyword_with_content_chunk.keyword_clustering_status == 0:
-                    try:
-                        keyword_clustering = db_helper.get(KeywordClustering, keyword_id=keyword_data.id)
-                        if keyword_clustering is None:
-                            keyword_summary = get_keyword_summary(content_chunk.text, keyword_data.keyword)
-                            new_keyword_clustering = KeywordClustering(keyword_id=keyword_data.id,
-                                                                       keyword_summary=keyword_summary[
-                                                                           'keyword_summary'])
-                            db_helper.add(new_keyword_clustering)
-                        else:
-                            new_keyword_summary = update_keyword_summary(keyword_clustering.keyword_summary,
-                                                                         keyword,
-                                                                         content_chunk.text)
-                            db_helper.update(KeywordClustering, filters={"id": keyword_clustering.id},
-                                             updates={"keyword_summary": new_keyword_summary})
-                        db_helper.update(KeywordWithContentChunk, filters={"id": keyword_with_content_chunk.id},
-                                         updates={"keyword_clustering_status": 1})
-                    except Exception as e:
-                        db_helper.update(KeywordWithContentChunk, filters={"id": keyword_with_content_chunk.id},
-                                         updates={"keyword_clustering_status": 2})
-
-        db_helper.update(ContentChunks, filters={"id": content_chunk.id},
-                         updates={"keywords_status": 1})
-
-    # Process the text list with a thread pool
-    def process_texts_concurrently(self):
-        print('process_texts_concurrently start')
-        db_helper = DBHelper()
-        while True:
-            content_chunks = db_helper.get_all(ContentChunks, limit=200, chunk_status=2, keywords_status=0)
-            if len(content_chunks) == 0:
-                logger.info('sleep')
-                print('sleep')
-                sleep(1800)
-            else:
-                future_to_chunk = {self.executor.submit(self._generate_keywords, content_chunk): content_chunk for
-                                   content_chunk
-                                   in
-                                   content_chunks}
-
-                # Wait for all tasks to finish
-                concurrent.futures.wait(future_to_chunk.keys())
-
-                # Map each content chunk to its result (note: assumes no exceptions; result() re-raises otherwise)
-                results = {}
-                for future, chunk in future_to_chunk.items():
-                    try:
-                        results[chunk] = future.result()
-                    except Exception as exc:
-                        results[chunk] = exc  # 或者你可以选择其他异常处理方式
-                print("success")
-
-
-if __name__ == '__main__':
-    db_helper = DBHelper()
-    print(db_helper.get(KeywordData, keyword='短视频'))