| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- from pymilvus import Collection, connections, utility, FieldSchema, CollectionSchema, DataType
- import pymysql.cursors
- import time
- import requests
- # import pymysql
- from tqdm import tqdm
- from openai import OpenAI
- from typing import List, Dict, Any, Optional
- from cosine_similarity_example import cosine_similarity, cosine_similarity_numpy
- from datetime import datetime, timedelta
- ################################连接milvus数据库 A
- # 配置信息
- MILVUS_CONFIG = {
- "host": "c-981be0ee7225467b.milvus.aliyuncs.com",
- "user": "root",
- "password": "Piaoquan@2025",
- "port": "19530",
- }
- print("正在连接 Milvus 数据库...")
- connections.connect("default", **MILVUS_CONFIG)
- ################################连接milvus数据库 B
- def iterate_milvus_collection(collection, batch_size=1000):
- """遍历Milvus集合的函数"""
- # 获取集合中的实体总数
- total_entities = collection.num_entities
- print(f"集合中总共有 {total_entities} 条记录")
-
- # 分页遍历
- offset = 0
- count = 0
- batch_size =1
- while offset < total_entities:
- # 使用query方法进行分页查询
- # expr="" 表示查询所有记录
- # limit控制每页返回的记录数
- # offset控制从第几条记录开始查询
- results = collection.query(
- expr="",
- limit=batch_size,
- offset=offset,
- output_fields=["id", "doc_id", "chunk_id", "vector_text","vector_summary","vector_questions"] # 指定要返回的字段
- )
-
- # 处理当前批次的结果
- for item in results:
- count += 1
- print(f"记录 {count}: id={item['id']}, doc_id={item['doc_id']}, chunk_id={item['chunk_id']}")
- # print("item['vector_summary'] is ", item['vector_summary'])
- # print("item['vector_questions'] is ", item['vector_questions'])
- # 这里可以根据需要处理每条记录
- # exit()
- # 更新偏移量
- offset += len(results)
-
- # 如果本批次结果为空,说明已经遍历完毕
- if len(results) == 0:
- break
- ################################连接mysql数据库 A
- MYSQL_CONFIG = {
- "host": "rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com",
- "user": "wqsd",
- "password": "wqsd@2025",
- "port": 3306,
- "db": "rag",
- "charset": "utf8mb4",
- }
- class DatabaseManager:
- def __init__(self):
- self.connection = None
-
- def connect(self):
- self.connection = pymysql.connect(
- **MYSQL_CONFIG,
- cursorclass=pymysql.cursors.DictCursor
- )
- print("✅ MySQL connected")
-
- def close(self):
- if self.connection:
- self.connection.close()
- print("MySQL closed")
-
- def fetch(self, query, params=None):
- with self.connection.cursor() as cursor:
- cursor.execute(query, params or ())
- return cursor.fetchall()
- def fetch_one(self, query, params=None):
- with self.connection.cursor() as cursor:
- # print("query is ", query)
- cursor.execute(query)
- return cursor.fetchone()
- def execute_ddl(self, query=None):
- if query is None:
- return False
- with self.connection.cursor() as cursor:
- cursor.execute(query)
- self.connection.commit()
- return True
- ################################
- ###########
- mysql_manager = DatabaseManager()
- mysql_manager.connect()
- ###########
- test_query="set autocommit = 1"
- res = mysql_manager.fetch_one(test_query)
- ################################连接mysql数据库 B
- ################################连接Embedding service A
- VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
- DEFAULT_MODEL = "/models/Qwen3-Embedding-4B"
- def get_basic_embedding(text: str, model=DEFAULT_MODEL):
- """通过HTTP调用在线embedding服务"""
- headers = {
- "Content-Type": "application/json"
- }
- data = {
- "model": model,
- "input": text
- }
-
- response = requests.post(
- VLLM_SERVER_URL,
- headers=headers,
- json=data
- )
-
- if response.status_code == 200:
- result = response.json()
- return result["data"][0]["embedding"]
- else:
- raise Exception(f"Failed to get embedding: {response.status_code} {response.text}")
- ################################连接Embedding service B
- ################################连接deepseek v3 A
- DEEPSEEK_API_KEY = "sk-cfd2df92c8864ab999d66a615ee812c5"
- DEEPSEEK_MODEL = {
- "DeepSeek-R1": "deepseek-reasoner",
- "DeepSeek-V3": "deepseek-chat",
- }
- def get_deepseek_completion(
- model: str,
- prompt: str,
- output_type: str = "text",
- tool_calls: bool = False,
- tools: List[Dict] = None
- ) -> Optional[Dict | List | str]:
- """调用DeepSeek API获取回答"""
- messages = [{"role": "user", "content": prompt}]
- kwargs = {
- "model": DEEPSEEK_MODEL.get(model, "deepseek-chat"),
- "messages": messages,
- }
- # 添加工具调用参数
- if tool_calls and tools:
- kwargs["tools"] = tools
- kwargs["tool_choice"] = "auto"
- # 创建OpenAI客户端连接DeepSeek API
- client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
- # 设置JSON输出格式
- if output_type == "json":
- kwargs["response_format"] = {"type": "json_object"}
- try:
- response = client.chat.completions.create(**kwargs)
- choice = response.choices[0]
- if output_type == "text":
- return choice.message.content # 只返回文本
- elif output_type == "json":
- import json
- return json.loads(choice.message.content)
- else:
- raise ValueError(f"Invalid output_type: {output_type}")
- except Exception as e:
- print(f"[ERROR] DeepSeek API调用失败: {e}")
- return None
- ################################连接deepseek v3 B
- #######################################################
- #######################################################开发
- #######################################################
- # 方法1: 使用utility.describe_collection()查看集合信息
- collection_name = "chunk_multi_embeddings_v2_trsc"
- print(f"\n📋 正在查看集合 '{collection_name}' 的信息...")
- # 正确的方式:创建Collection对象
- milvus_client = Collection(name=collection_name)
- print(milvus_client)
- milvus_client.load()
- total = milvus_client.num_entities
- print("total is ", total)
- ############################遍历mysql中的content_chunks_trsc,获取doc_id和chunk_id,条件为updated_at时间戳小于当天的前一天
- # query = """
- # select doc_id, chunk_id from content_chunks_trsc where updated_at<NOW()- INTERVAL 1 DAY;
- # """
- # results = mysql_manager.fetch(query=query)
- # print(f"获取到{len(results)}条需要处理的记录")
- # print(results[0])
- #############################遍历milvus中的collection chunk_multi_embeddings_v2_trsc 找出所有doc_id,chunk_id 在results 中的记录
- res_id = milvus_client.query(
- expr="",
- output_fields=['id'],
- limit=10000,
- consistency_level="Strong"
- )
- print("res_id length is ", len(res_id)) # 6198
- total = len(res_id)
- # assert len(res_id) == total
- # exit()
- # batch_size=1000
- processed_count = 0
- for i in tqdm(range(0,total)):
- try:
- ####update selelct 数据
- # print("expr is ", f"id == {res_id[i]['id']}")
- # 从错误信息得知id字段是Int64类型,需要作为整数进行比较,不能用引号包裹
- results = milvus_client.query(
- expr=f"id == {res_id[i]['id']}", # 使用双等号,直接使用整数值
- limit=1,
- output_fields=['*'],
- consistency_level="Strong"
- )
- if not results:
- print(f"11111111########id {res_id[i]['id']} 没有查询到结果")
- continue
- ####找外健
- temp_doc_id = results[0]['doc_id']
- temp_chunk_id = results[0]['chunk_id']
- temp_vector_text = results[0]['vector_text']
- temp_vector_summary = results[0]['vector_summary']
- temp_vector_questions = results[0]['vector_questions']
- ####访问mysql,获取updated_at, 如果 updated_at<NOW()- INTERVAL 1 DAY 返回false,则已经处理过,pass,否则正常处理
- query = f"""
- select text,updated_at from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
- """
- results = mysql_manager.fetch_one(query=query)
- if not results:
- print(f"222222222#######doc_id {temp_doc_id} chunk_id {temp_chunk_id} 没有查询到结果")
- continue
- temp_updated_at = results['updated_at']
- if temp_updated_at > datetime.now() - timedelta(days=1):
- print(f"doc_id {temp_doc_id} chunk_id {temp_chunk_id} 已经处理过,updated_at {temp_updated_at}")
- continue
- #############
- # exit()
- insert_data = []
- ####找mysql表text
- query = f"""
- select text from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
- """
- results = mysql_manager.fetch_one(query=query)
- if not results:
- print(f"3333333*******doc_id {temp_doc_id} chunk_id {temp_chunk_id} 没有查询到text结果")
- continue
- temp_text = results['text']
- # ####访问deepseek
- # """
- # 你是一个专业的文本转写助手,负责将用户输入的文本进行准确的转写。请确保转写结果与原始文本信息保持一致,避免添加任何额外的解释,注释。\
- # 请务必保持语言准确,精炼。以下是信息
- # """
- # ####grok
- # system_pre_prompt = f'''
- # 你是一个专业的文本精炼专家。请将以下原文改写成更简练、精确的版本。要求:
- # 1. 逐句检查原文,保留所有信息、事实、数据、关系和含义,不丢失任何细节。
- # 2. 去除冗余词语、重复表达和不必要的描述,使语言更紧凑和精准。
- # 3. 输出纯文本,不要添加任何格式化元素,如标题、列表、编号、粗体、换行符、额外解释或总结语句。只输出改写后的文本。
- # 4. 如果原文有专业术语,保持原样。
- # 原文:{temp_text}
- # '''
- # trans_answer_grok = get_deepseek_completion(
- # model="DeepSeek-V3", # 使用DeepSeek V3模型
- # prompt=system_pre_prompt, # 使用query_text_1作为prompt
- # output_type="text" # 返回文本格式
- # )
- # print("trans_answer grok is ", trans_answer_grok)
- print("text is", temp_text)
- ###chatgpt
- system_pre_prompt = f'''
- 请将以下文本转写为更简练、精确的表达方式。转写时需要满足以下要求:
- 不丢失信息:转写后的内容必须包含原文中的所有关键信息,任何核心的事实、细节和数据都不能丢失。
- 简洁清晰:去除冗余的词汇和不必要的修饰,确保语言简洁明了,信息传达精准。
- 无额外格式:转写后的文本应为纯文本,不添加任何额外的格式、列表、标点或其他符号,保持原有的结构。
- 请根据这些要求,转写以下段落:{temp_text}
- '''
- ###chatgpt
- trans_answer_gpt = get_deepseek_completion(
- model="DeepSeek-V3", # 使用DeepSeek V3模型
- prompt=system_pre_prompt, # 使用query_text_1作为prompt
- output_type="text" # 返回文本格式
- )
- # print("text is", temp_text)
- print("trans_answer_chatgpt is ", trans_answer_gpt)
- # ###compare
- # system_pre_prompt = f'''
- # 我有两个经过转写后的文本,请你根据以下标准来比较这两个文本,判断哪个文本更好:
- # 1. 信息完整性:哪个文本更好地保留了原文的关键信息?如果有任何重要信息缺失,请指出。
- # 2. 简洁性:哪个文本更加简洁明了,去除了冗余和不必要的部分?
- # 3. 清晰度:哪个文本更容易理解,表达更清晰?
- # 4. 精准度:哪个文本更准确地传达了原文的意思,没有歧义或误解?
- # 请基于这些标准分别对两个文本进行评估,并给出理由说明哪个文本更优。
- # 文本1: {trans_answer_grok}
- # 文本2: {trans_answer_gpt}
- # '''
- # ###compare
- # trans_answer_compare = get_deepseek_completion(
- # model="DeepSeek-V3", # 使用DeepSeek V3模型
- # prompt=system_pre_prompt, # 使用query_text_1作为prompt
- # output_type="text" # 返回文本格式
- # )
- # print("trans_answer_compare is ", trans_answer_compare)
- # exit()
- # ####访问embedding model 获取转写后的text
-
- temp_embedding = get_basic_embedding(text=temp_text, model=DEFAULT_MODEL)
- temp_embedding_trsc = get_basic_embedding(text=trans_answer_gpt, model=DEFAULT_MODEL)
- similarity_score = cosine_similarity(temp_embedding, temp_embedding_trsc)
- print(f"\n 重写后问题与重写前问题的余弦相似度: {cosine_similarity(temp_embedding, temp_embedding_trsc)}")
- # 更新mysql表,将转写后的文本、相似度分数和当前时间戳更新到对应记录中
- # test_query="select @@autocommit"
- # res = mysql_manager.fetch_one(test_query)
- # print("autocommit is ", res['@@autocommit'])
- # exit()
- update_query = f"""
- update content_chunks_trsc set simillarity = {similarity_score}, text = '{trans_answer_gpt}', updated_at = NOW() where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
- """
- mysql_manager.fetch_one(update_query)
- ####更新milvus 的collection,将转写后的embedding替代原来的embedding
- # 更新milvus的collection chunk_multi_embeddings_v2_trsc,使用转写过的text计算出的embedding temp_embedding_trsc 更新其vector_text
- # Milvus的upsert方法要求数据格式为字段名到值列表的映射
- # 注意:必须包含所有非空且无默认值的字段
- upsert_expr = f"id == {res_id[i]['id']}"
- # 正确格式:包含所有必需字段
- upsert_data = {
- "id": res_id[i]['id'], # 注意:upsert操作需要将值放入列表中
- "doc_id": temp_doc_id, # 添加必需的doc_id字段
- "chunk_id": temp_chunk_id, # 添加必需的chunk_id字段
- "vector_text": temp_embedding_trsc, # vector_text字段也需要放入列表中
- "vector_summary":temp_vector_summary,
- "vector_questions":temp_vector_questions
- }
- # 执行更新操作
- milvus_client.upsert(upsert_data, expr=upsert_expr) # 注意参数顺序:先数据后表达式
- processed_count += 1
- ##################
- if processed_count % 200 == 0:
- milvus_client.flush()
- except Exception as e:
- print(f"处理第{i}条数据时出现异常: {str(e)}")
- # 跳过当前循环,继续处理下一条数据
- continue
- milvus_client.flush()
|