فهرست منبع

新增subtask_decompose.py分解query;transcription 转写chunk

lookathis@163.com 2 ماه پیش
والد
کامیت
742b7a2d21
2فایلهای تغییر یافته به همراه536 افزوده شده و 0 حذف شده
  1. 145 0
      transcription/subtask_decompose.py
  2. 391 0
      transcription/transcription.py

+ 145 - 0
transcription/subtask_decompose.py

@@ -0,0 +1,145 @@
+# -*- codong: utf-8 -*-
+"""问题分解与分步处理工具
+
+此脚本用于将复杂问题分解为多个步骤,并对每个步骤进行单独处理,最终汇总成完整答案。
+主要通过调用DeepSeek API实现问题分解和分步解答功能。
+
+功能流程:
+1. 配置DeepSeek API参数
+2. 定义API调用函数
+3. 接收原始问题并构建系统提示
+4. 调用API分解问题为多个步骤
+5. 逐一处理每个步骤并获取回答
+6. 汇总所有信息生成完整结果
+"""
+
+from openai import OpenAI
+from typing import List, Dict, Any, Optional
+
+from cosine_similarity_example import cosine_similarity, cosine_similarity_numpy
+# 导入各种可能需要的库(部分未在当前版本使用,但为扩展预留)ilarity_numpy
+# 配置信息
+# DeepSeek 配置
+DEEPSEEK_API_KEY = "sk-cfd2df92c8864ab999d66a615ee812c5"
+DEEPSEEK_MODEL = {
+    "DeepSeek-R1": "deepseek-reasoner",
+    "DeepSeek-V3": "deepseek-chat",
+}
+def get_deepseek_completion(
+    model: str, 
+    prompt: str,
+    output_type: str = "text",
+    tool_calls: bool = False,
+    tools: List[Dict] = None
+) -> Optional[Dict | List | str]:
+    """调用DeepSeek API获取回答"""
+    messages = [{"role": "user", "content": prompt}]
+    kwargs = {
+        "model": DEEPSEEK_MODEL.get(model, "deepseek-chat"),
+        "messages": messages,
+    }
+
+    # 添加工具调用参数
+    if tool_calls and tools:
+        kwargs["tools"] = tools
+        kwargs["tool_choice"] = "auto"
+
+    # 创建OpenAI客户端连接DeepSeek API
+    client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+
+    # 设置JSON输出格式
+    if output_type == "json":
+        kwargs["response_format"] = {"type": "json_object"}
+
+    try:
+        response = client.chat.completions.create(**kwargs)
+        choice = response.choices[0]
+
+        if output_type == "text":
+            return choice.message.content  # 只返回文本
+        elif output_type == "json":
+            import json
+            return json.loads(choice.message.content)
+        else:
+            raise ValueError(f"Invalid output_type: {output_type}")
+
+    except Exception as e:
+        print(f"[ERROR] DeepSeek API调用失败: {e}")
+        return None
+
+def main():
+    bias_prompt ="请尽量简洁准确地回答:"
+    orig_query = "如何养一只金刚鹦鹉"
+    system_prompt = f"""你是一个智能助手,负责判断问题的复杂度并给出相应处理。
+    如果是简单问题,直接组织为json格式的回答返回该回答,例如:
+    {{"steps": ["简单问题的回答"]}}。
+    如果是复杂问题,将这个问题分解为多个步骤,步骤的组织为json格式,格式为{{
+        "steps": [
+            "步骤1",
+            "步骤2",
+            ...
+        ]
+    }}
+    例如:如何写一篇好的博文
+    1.确定博文主题和目标受众;
+    2.进金刚鹦鹉集和研究;
+    3.制定博文大纲;
+    ...
+    现在问题如下:{orig_query}
+    """
+
+    query = system_prompt
+    print(f"\n调用DeepSeek回答: {query}")
+
+    answer = get_deepseek_completion(
+        model="DeepSeek-V3",  # 使用DeepSeek V3模型
+        prompt=query,  # 使用query_text_1作为prompt
+        output_type="json"    # 返回文本格式
+    )
+
+    ############# ####解析步骤
+    if answer:
+        print(f"\nDeepSeek回答 分解子问题为:\n{answer}")
+    else:
+        print("未获取到DeepSeek的回答")
+
+    steps = answer.get('steps', [])
+    
+    print(f"\n解析出的步骤数量: {len(steps)}")
+    for i, step in enumerate(steps, 1):
+        print(f"步骤{i}: {step}")
+
+    ##################处理每个步骤
+    context = orig_query
+    for i, step in enumerate(steps, 1):
+        print(f"\n处理步骤{i}: {step}")
+        context += f"\n步骤{i}: {step}"
+        # 调用DeepSeek处理每个步骤
+        '''
+        这个步骤过程有选择:
+        1.调用使用RAG模型,调用知识库,获取步骤的回答
+        2.将context信息作为上下文一并投入prompt,获取步骤的回答
+        '''        
+        step_answer = get_deepseek_completion(
+            model="DeepSeek-V3",  # 使用DeepSeek V3模型
+            prompt=bias_prompt + step,  
+            output_type="text"    # 返回文本格式
+        )
+        '''
+        这个步骤过程有选择:
+        1.调用使用RAG模型,调用知识库,获取步骤的回答
+        2.将context信息作为上下文一并投入prompt,获取步骤的回答
+        '''        
+        if step_answer:
+            print(f"步骤{i}的回答: {step_answer}")
+            context += f"\n步骤{i}的回答: {step_answer}"
+        else:
+            print(f"步骤{i}未获取到回答")
+
+
+    print("结果:", context)
+
+  
+
+if __name__ == "__main__":
+    main()

+ 391 - 0
transcription/transcription.py

@@ -0,0 +1,391 @@
+from pymilvus import Collection, connections, utility, FieldSchema, CollectionSchema, DataType
+import pymysql.cursors
+import time
+import requests
+
+# import pymysql
+from tqdm import tqdm 
+from openai import OpenAI
+from typing import List, Dict, Any, Optional
+from cosine_similarity_example import cosine_similarity, cosine_similarity_numpy
+from datetime import datetime, timedelta
+################################连接milvus数据库 A
+# 配置信息
+MILVUS_CONFIG = {
+    "host": "c-981be0ee7225467b.milvus.aliyuncs.com",
+    "user": "root",
+    "password": "Piaoquan@2025",
+    "port": "19530",
+}
+print("正在连接 Milvus 数据库...")
+connections.connect("default", **MILVUS_CONFIG)
+################################连接milvus数据库 B
+
+def iterate_milvus_collection(collection, batch_size=1000):
+    """遍历Milvus集合的函数"""
+    # 获取集合中的实体总数
+    total_entities = collection.num_entities
+    print(f"集合中总共有 {total_entities} 条记录")
+    
+    # 分页遍历
+    offset = 0
+    count = 0
+    batch_size =1
+
+    while offset < total_entities:
+        # 使用query方法进行分页查询
+        # expr="" 表示查询所有记录
+        # limit控制每页返回的记录数
+        # offset控制从第几条记录开始查询
+        results = collection.query(
+            expr="",
+            limit=batch_size,
+            offset=offset,
+            output_fields=["id", "doc_id", "chunk_id", "vector_text","vector_summary","vector_questions"]  # 指定要返回的字段
+        )
+        
+        # 处理当前批次的结果
+        for item in results:
+            count += 1
+            print(f"记录 {count}: id={item['id']}, doc_id={item['doc_id']}, chunk_id={item['chunk_id']}")
+            # print("item['vector_summary'] is ", item['vector_summary'])
+            # print("item['vector_questions'] is ", item['vector_questions'])
+            # 这里可以根据需要处理每条记录
+        # exit()
+        # 更新偏移量
+        offset += len(results)
+        
+        # 如果本批次结果为空,说明已经遍历完毕
+        if len(results) == 0:
+            break
+
+################################连接mysql数据库 A
+MYSQL_CONFIG = {
+    "host": "rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com",
+    "user": "wqsd",
+    "password": "wqsd@2025",
+    "port": 3306,
+    "db": "rag",
+    "charset": "utf8mb4",
+}
+class DatabaseManager:
+    def __init__(self):
+        self.connection = None
+    
+    def connect(self):
+        self.connection = pymysql.connect(
+            **MYSQL_CONFIG,
+            cursorclass=pymysql.cursors.DictCursor
+        )
+        print("✅ MySQL connected")
+    
+    def close(self):
+        if self.connection:
+            self.connection.close()
+            print("MySQL closed")
+    
+    def fetch(self, query, params=None):
+        with self.connection.cursor() as cursor:
+            cursor.execute(query, params or ())
+            return cursor.fetchall()
+
+    def fetch_one(self, query, params=None):
+        with self.connection.cursor() as cursor:
+            # print("query is ", query)
+            cursor.execute(query)
+            return cursor.fetchone()
+    def execute_ddl(self, query=None):
+        if query is None:
+            return False
+        with self.connection.cursor() as cursor:
+            cursor.execute(query)
+            self.connection.commit()
+            return True
+################################
+
+###########
+mysql_manager = DatabaseManager()
+mysql_manager.connect()
+###########
+
+test_query="set autocommit = 1"
+res = mysql_manager.fetch_one(test_query)
+################################连接mysql数据库 B
+
+################################连接Embedding service A
+VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
+DEFAULT_MODEL = "/models/Qwen3-Embedding-4B"
+
+def get_basic_embedding(text: str, model=DEFAULT_MODEL):
+    """通过HTTP调用在线embedding服务"""
+    headers = {
+        "Content-Type": "application/json"
+    }
+    data = {
+        "model": model,
+        "input": text
+    }
+    
+    response = requests.post(
+        VLLM_SERVER_URL,
+        headers=headers,
+        json=data
+    )
+    
+    if response.status_code == 200:
+        result = response.json()
+        return result["data"][0]["embedding"]
+    else:
+        raise Exception(f"Failed to get embedding: {response.status_code} {response.text}")
+################################连接Embedding service B
+
+################################连接deepseek v3 A
+DEEPSEEK_API_KEY = "sk-cfd2df92c8864ab999d66a615ee812c5"
+DEEPSEEK_MODEL = {
+    "DeepSeek-R1": "deepseek-reasoner",
+    "DeepSeek-V3": "deepseek-chat",
+}
+
+def get_deepseek_completion(
+    model: str, 
+    prompt: str,
+    output_type: str = "text",
+    tool_calls: bool = False,
+    tools: List[Dict] = None
+) -> Optional[Dict | List | str]:
+    """调用DeepSeek API获取回答"""
+    messages = [{"role": "user", "content": prompt}]
+    kwargs = {
+        "model": DEEPSEEK_MODEL.get(model, "deepseek-chat"),
+        "messages": messages,
+    }
+    # 添加工具调用参数
+    if tool_calls and tools:
+        kwargs["tools"] = tools
+        kwargs["tool_choice"] = "auto"
+
+    # 创建OpenAI客户端连接DeepSeek API
+    client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
+
+    # 设置JSON输出格式
+    if output_type == "json":
+        kwargs["response_format"] = {"type": "json_object"}
+
+    try:
+        response = client.chat.completions.create(**kwargs)
+        choice = response.choices[0]
+
+        if output_type == "text":
+            return choice.message.content  # 只返回文本
+        elif output_type == "json":
+            import json
+            return json.loads(choice.message.content)
+        else:
+            raise ValueError(f"Invalid output_type: {output_type}")
+
+    except Exception as e:
+        print(f"[ERROR] DeepSeek API调用失败: {e}")
+        return None
+################################连接deepseek v3 B
+
+#######################################################
+#######################################################开发
+#######################################################
+
+# 方法1: 使用utility.describe_collection()查看集合信息
+collection_name = "chunk_multi_embeddings_v2_trsc"
+print(f"\n📋 正在查看集合 '{collection_name}' 的信息...")
+
+# 正确的方式:创建Collection对象
+milvus_client = Collection(name=collection_name)
+print(milvus_client)
+milvus_client.load()
+total = milvus_client.num_entities
+print("total is ", total)
+
+############################遍历mysql中的content_chunks_trsc,获取doc_id和chunk_id,条件为updated_at时间戳小于当天的前一天
+
+# query = """
+#     select doc_id, chunk_id from content_chunks_trsc where updated_at<NOW()- INTERVAL 1 DAY;
+# """
+# results = mysql_manager.fetch(query=query)
+# print(f"获取到{len(results)}条需要处理的记录")
+# print(results[0])
+
+#############################遍历milvus中的collection chunk_multi_embeddings_v2_trsc 找出所有doc_id,chunk_id 在results 中的记录
+res_id = milvus_client.query(
+        expr="",
+        output_fields=['id'],
+        limit=10000,
+        consistency_level="Strong"
+    )
+print("res_id length is ", len(res_id)) # 6198
+
+total = len(res_id)
+# assert len(res_id) == total
+
+# exit()
+# batch_size=1000
+processed_count = 0
+
+for i in tqdm(range(0,total)):
+    try:
+        ####update selelct 数据
+        # print("expr is ", f"id == {res_id[i]['id']}")
+        # 从错误信息得知id字段是Int64类型,需要作为整数进行比较,不能用引号包裹
+        results = milvus_client.query(
+            expr=f"id == {res_id[i]['id']}",  # 使用双等号,直接使用整数值
+            limit=1,
+            output_fields=['*'],
+            consistency_level="Strong"
+        )
+        if not results:
+            print(f"11111111########id {res_id[i]['id']} 没有查询到结果")
+            continue
+        ####找外健
+        temp_doc_id = results[0]['doc_id']
+        temp_chunk_id = results[0]['chunk_id']
+        temp_vector_text = results[0]['vector_text']
+        temp_vector_summary = results[0]['vector_summary'] 
+        temp_vector_questions = results[0]['vector_questions']  
+
+        ####访问mysql,获取updated_at, 如果 updated_at<NOW()- INTERVAL 1 DAY 返回false,则已经处理过,pass,否则正常处理
+        query = f"""
+            select text,updated_at from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
+        """
+        results = mysql_manager.fetch_one(query=query)
+        if not results:
+            print(f"222222222#######doc_id {temp_doc_id} chunk_id {temp_chunk_id} 没有查询到结果")
+            continue
+        temp_updated_at = results['updated_at']
+        if temp_updated_at > datetime.now() - timedelta(days=1):
+            print(f"doc_id {temp_doc_id} chunk_id {temp_chunk_id} 已经处理过,updated_at {temp_updated_at}")
+            continue
+        #############
+        # exit()
+        insert_data = []
+
+        ####找mysql表text
+        query = f"""
+            select text from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
+        """
+        results = mysql_manager.fetch_one(query=query)
+
+        if not results:
+            print(f"3333333*******doc_id {temp_doc_id} chunk_id {temp_chunk_id} 没有查询到text结果")
+            continue
+        temp_text = results['text']
+
+        # ####访问deepseek
+        # """
+        # 你是一个专业的文本转写助手,负责将用户输入的文本进行准确的转写。请确保转写结果与原始文本信息保持一致,避免添加任何额外的解释,注释。\
+        # 请务必保持语言准确,精炼。以下是信息
+        # """
+        # ####grok
+        # system_pre_prompt = f'''
+        # 你是一个专业的文本精炼专家。请将以下原文改写成更简练、精确的版本。要求:
+        # 1. 逐句检查原文,保留所有信息、事实、数据、关系和含义,不丢失任何细节。
+        # 2. 去除冗余词语、重复表达和不必要的描述,使语言更紧凑和精准。
+        # 3. 输出纯文本,不要添加任何格式化元素,如标题、列表、编号、粗体、换行符、额外解释或总结语句。只输出改写后的文本。
+        # 4. 如果原文有专业术语,保持原样。
+
+        # 原文:{temp_text}
+        # '''
+
+        # trans_answer_grok = get_deepseek_completion(
+        #     model="DeepSeek-V3",  # 使用DeepSeek V3模型
+        #     prompt=system_pre_prompt,  # 使用query_text_1作为prompt
+        #     output_type="text"    # 返回文本格式
+        # )
+        # print("trans_answer grok is ", trans_answer_grok)
+
+        print("text is", temp_text)
+
+        ###chatgpt
+        system_pre_prompt = f'''
+        请将以下文本转写为更简练、精确的表达方式。转写时需要满足以下要求:
+        不丢失信息:转写后的内容必须包含原文中的所有关键信息,任何核心的事实、细节和数据都不能丢失。
+        简洁清晰:去除冗余的词汇和不必要的修饰,确保语言简洁明了,信息传达精准。
+        无额外格式:转写后的文本应为纯文本,不添加任何额外的格式、列表、标点或其他符号,保持原有的结构。
+        请根据这些要求,转写以下段落:{temp_text}
+        '''
+        ###chatgpt
+
+        trans_answer_gpt = get_deepseek_completion(
+            model="DeepSeek-V3",  # 使用DeepSeek V3模型
+            prompt=system_pre_prompt,  # 使用query_text_1作为prompt
+            output_type="text"    # 返回文本格式
+        )
+        # print("text is", temp_text)
+        print("trans_answer_chatgpt is ", trans_answer_gpt)
+
+        # ###compare
+        # system_pre_prompt = f'''
+        # 我有两个经过转写后的文本,请你根据以下标准来比较这两个文本,判断哪个文本更好:
+        # 1. 信息完整性:哪个文本更好地保留了原文的关键信息?如果有任何重要信息缺失,请指出。
+        # 2. 简洁性:哪个文本更加简洁明了,去除了冗余和不必要的部分?
+        # 3. 清晰度:哪个文本更容易理解,表达更清晰?
+        # 4. 精准度:哪个文本更准确地传达了原文的意思,没有歧义或误解?
+        # 请基于这些标准分别对两个文本进行评估,并给出理由说明哪个文本更优。
+        # 文本1: {trans_answer_grok}
+        # 文本2: {trans_answer_gpt}
+        # '''
+        # ###compare
+        # trans_answer_compare = get_deepseek_completion(
+        #     model="DeepSeek-V3",  # 使用DeepSeek V3模型
+        #     prompt=system_pre_prompt,  # 使用query_text_1作为prompt
+        #     output_type="text"    # 返回文本格式
+        # )
+        # print("trans_answer_compare is ", trans_answer_compare)
+
+        # exit()
+        # ####访问embedding model 获取转写后的text
+     
+        temp_embedding = get_basic_embedding(text=temp_text, model=DEFAULT_MODEL)
+        temp_embedding_trsc = get_basic_embedding(text=trans_answer_gpt, model=DEFAULT_MODEL)
+
+        similarity_score = cosine_similarity(temp_embedding, temp_embedding_trsc)
+
+        print(f"\n 重写后问题与重写前问题的余弦相似度: {cosine_similarity(temp_embedding, temp_embedding_trsc)}")
+
+        # 更新mysql表,将转写后的文本、相似度分数和当前时间戳更新到对应记录中
+
+        # test_query="select @@autocommit"
+        # res = mysql_manager.fetch_one(test_query)
+        # print("autocommit is ", res['@@autocommit'])
+        # exit()
+        update_query = f"""
+            update content_chunks_trsc set simillarity = {similarity_score}, text = '{trans_answer_gpt}', updated_at = NOW() where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
+        """
+        mysql_manager.fetch_one(update_query)
+
+        ####更新milvus 的collection,将转写后的embedding替代原来的embedding
+        # 更新milvus的collection chunk_multi_embeddings_v2_trsc,使用转写过的text计算出的embedding temp_embedding_trsc 更新其vector_text
+        # Milvus的upsert方法要求数据格式为字段名到值列表的映射
+        # 注意:必须包含所有非空且无默认值的字段
+        upsert_expr = f"id == {res_id[i]['id']}"
+        # 正确格式:包含所有必需字段
+        upsert_data = {
+            "id": res_id[i]['id'],  # 注意:upsert操作需要将值放入列表中
+            "doc_id": temp_doc_id,  # 添加必需的doc_id字段
+            "chunk_id": temp_chunk_id,  # 添加必需的chunk_id字段
+            "vector_text": temp_embedding_trsc, # vector_text字段也需要放入列表中
+            "vector_summary":temp_vector_summary,
+            "vector_questions":temp_vector_questions
+        }
+        # 执行更新操作
+        milvus_client.upsert(upsert_data, expr=upsert_expr)  # 注意参数顺序:先数据后表达式
+
+        processed_count += 1
+        ##################
+        if processed_count % 200 == 0:
+            milvus_client.flush()
+    except Exception as e:
+        print(f"处理第{i}条数据时出现异常: {str(e)}")
+        # 跳过当前循环,继续处理下一条数据
+        continue
+milvus_client.flush()
+
+
+
+