# transcription.py

from pymilvus import Collection, connections, utility, FieldSchema, CollectionSchema, DataType
import pymysql
import pymysql.cursors
import time
import requests
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional
from cosine_similarity_example import cosine_similarity, cosine_similarity_numpy
from datetime import datetime, timedelta
################################ Connect to the Milvus database A
# Connection settings
MILVUS_CONFIG = {
    "host": "c-981be0ee7225467b.milvus.aliyuncs.com",
    "user": "root",
    "password": "Piaoquan@2025",
    "port": "19530",
}
print("Connecting to the Milvus database...")
connections.connect("default", **MILVUS_CONFIG)
################################ Connect to the Milvus database B
def iterate_milvus_collection(collection, batch_size=1000):
    """Iterate over a Milvus collection in batches."""
    # Total number of entities in the collection
    total_entities = collection.num_entities
    print(f"The collection contains {total_entities} records in total")
    # Paginated traversal
    offset = 0
    count = 0
    while offset < total_entities:
        # Paginate with query():
        #   expr=""  -> match every record
        #   limit    -> number of records per page
        #   offset   -> index of the first record of this page
        results = collection.query(
            expr="",
            limit=batch_size,
            offset=offset,
            output_fields=["id", "doc_id", "chunk_id", "vector_text", "vector_summary", "vector_questions"]  # fields to return
        )
        # Process the current batch
        for item in results:
            count += 1
            print(f"Record {count}: id={item['id']}, doc_id={item['doc_id']}, chunk_id={item['chunk_id']}")
            # print("item['vector_summary'] is ", item['vector_summary'])
            # print("item['vector_questions'] is ", item['vector_questions'])
            # Per-record processing goes here
            # exit()
        # Advance the offset
        offset += len(results)
        # An empty batch means the traversal is complete
        if len(results) == 0:
            break
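# A minimal alternative sketch, assuming pymilvus >= 2.3 where Collection.query_iterator() is
# available: it streams batches without the manual offset bookkeeping above. Illustrative only,
# it is not called anywhere in this script.
def iterate_milvus_collection_with_iterator(collection, batch_size=1000):
    """Stream every record of a collection using a query iterator instead of offset pagination."""
    iterator = collection.query_iterator(
        batch_size=batch_size,
        output_fields=["id", "doc_id", "chunk_id"],
    )
    count = 0
    while True:
        batch = iterator.next()  # next() returns a list of rows, empty when the data is exhausted
        if not batch:
            iterator.close()
            break
        count += len(batch)
    return count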
################################ Connect to the MySQL database A
MYSQL_CONFIG = {
    "host": "rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com",
    "user": "wqsd",
    "password": "wqsd@2025",
    "port": 3306,
    "db": "rag",
    "charset": "utf8mb4",
}
class DatabaseManager:
    def __init__(self):
        self.connection = None

    def connect(self):
        self.connection = pymysql.connect(
            **MYSQL_CONFIG,
            cursorclass=pymysql.cursors.DictCursor
        )
        print("✅ MySQL connected")

    def close(self):
        if self.connection:
            self.connection.close()
            print("MySQL closed")

    def fetch(self, query, params=None):
        with self.connection.cursor() as cursor:
            cursor.execute(query, params or ())
            return cursor.fetchall()

    def fetch_one(self, query, params=None):
        with self.connection.cursor() as cursor:
            # print("query is ", query)
            cursor.execute(query, params or ())  # pass params so parameterized queries work
            return cursor.fetchone()

    def execute_ddl(self, query=None):
        if query is None:
            return False
        with self.connection.cursor() as cursor:
            cursor.execute(query)
            self.connection.commit()
        return True
################################
###########
mysql_manager = DatabaseManager()
mysql_manager.connect()
###########
# Make sure every statement is committed immediately (the UPDATE below relies on this)
test_query = "set autocommit = 1"
res = mysql_manager.fetch_one(test_query)
################################ Connect to the MySQL database B
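# Illustrative only (not executed): since fetch/fetch_one forward params to cursor.execute,
# values can be passed through %s placeholders instead of being interpolated into the SQL
# string, which avoids breakage when a value contains quotes. The doc_id/chunk_id values
# below are placeholders, not real ids.
# row = mysql_manager.fetch_one(
#     "select text, updated_at from content_chunks_trsc where doc_id = %s and chunk_id = %s;",
#     ("some_doc_id", "some_chunk_id"),
# )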
################################ Connect to the embedding service A
VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
DEFAULT_MODEL = "/models/Qwen3-Embedding-4B"

def get_basic_embedding(text: str, model=DEFAULT_MODEL):
    """Call the online embedding service over HTTP."""
    headers = {
        "Content-Type": "application/json"
    }
    data = {
        "model": model,
        "input": text
    }
    response = requests.post(
        VLLM_SERVER_URL,
        headers=headers,
        json=data
    )
    if response.status_code == 200:
        result = response.json()
        return result["data"][0]["embedding"]
    else:
        raise Exception(f"Failed to get embedding: {response.status_code} {response.text}")
################################ Connect to the embedding service B
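# A minimal hardening sketch (illustrative, not used below): the same embedding call with an
# HTTP timeout and a simple retry loop, since the loop further down issues one request per chunk.
# Only the requests and time modules imported above are used.
def get_basic_embedding_with_retry(text: str, model=DEFAULT_MODEL, retries: int = 3, timeout: float = 30.0):
    """Like get_basic_embedding, but with a request timeout and basic exponential backoff."""
    last_error = Exception("embedding request failed")
    for attempt in range(retries):
        try:
            response = requests.post(
                VLLM_SERVER_URL,
                headers={"Content-Type": "application/json"},
                json={"model": model, "input": text},
                timeout=timeout,
            )
            if response.status_code == 200:
                return response.json()["data"][0]["embedding"]
            last_error = Exception(f"Failed to get embedding: {response.status_code} {response.text}")
        except requests.RequestException as exc:  # connection error or timeout
            last_error = exc
        time.sleep(2 ** attempt)  # back off before retrying
    raise last_error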
################################ Connect to DeepSeek V3 A
DEEPSEEK_API_KEY = "sk-cfd2df92c8864ab999d66a615ee812c5"
DEEPSEEK_MODEL = {
    "DeepSeek-R1": "deepseek-reasoner",
    "DeepSeek-V3": "deepseek-chat",
}

def get_deepseek_completion(
    model: str,
    prompt: str,
    output_type: str = "text",
    tool_calls: bool = False,
    tools: List[Dict] = None
) -> Optional[Dict | List | str]:
    """Call the DeepSeek API and return its answer."""
    messages = [{"role": "user", "content": prompt}]
    kwargs = {
        "model": DEEPSEEK_MODEL.get(model, "deepseek-chat"),
        "messages": messages,
    }
    # Tool-calling parameters
    if tool_calls and tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = "auto"
    # Use the OpenAI client against the DeepSeek endpoint
    client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
    # Request JSON output if asked for
    if output_type == "json":
        kwargs["response_format"] = {"type": "json_object"}
    try:
        response = client.chat.completions.create(**kwargs)
        choice = response.choices[0]
        if output_type == "text":
            return choice.message.content  # return the plain text only
        elif output_type == "json":
            import json
            return json.loads(choice.message.content)
        else:
            raise ValueError(f"Invalid output_type: {output_type}")
    except Exception as e:
        print(f"[ERROR] DeepSeek API call failed: {e}")
        return None
################################ Connect to DeepSeek V3 B
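# Illustrative call (not executed): output_type="json" asks DeepSeek for a JSON object and
# returns it already parsed; the prompt below is a placeholder, not part of the pipeline.
# parsed = get_deepseek_completion(
#     model="DeepSeek-V3",
#     prompt='Return a JSON object of the form {"summary": "..."} for the following text:\n' + "some text",
#     output_type="json",
# )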
#######################################################
####################################################### Development
#######################################################
# Method 1: use utility.describe_collection() to view the collection information
collection_name = "chunk_multi_embeddings_v2_trsc"
print(f"\n📋 Inspecting collection '{collection_name}'...")
# The correct way: create a Collection object
milvus_client = Collection(name=collection_name)
print(milvus_client)
milvus_client.load()
total = milvus_client.num_entities
print("total is ", total)
############################ Walk content_chunks_trsc in MySQL for doc_id/chunk_id whose updated_at is older than the previous day
# query = """
# select doc_id, chunk_id from content_chunks_trsc where updated_at<NOW()- INTERVAL 1 DAY;
# """
# results = mysql_manager.fetch(query=query)
# print(f"Fetched {len(results)} records to process")
# print(results[0])
############################# Walk the Milvus collection chunk_multi_embeddings_v2_trsc and find every record whose doc_id/chunk_id appears in results
res_id = milvus_client.query(
    expr="",
    output_fields=['id'],
    limit=10000,
    consistency_level="Strong"
)
print("res_id length is ", len(res_id))  # 6198
total = len(res_id)
# assert len(res_id) == total
# exit()
# batch_size=1000
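# Note: the query above returns at most `limit` ids (10000 here); if the collection ever grows
# past that, the ids should be fetched in batches instead, e.g. with the query_iterator sketch above.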
processed_count = 0
for i in tqdm(range(0, total)):
    try:
        #### Select the record to update
        # print("expr is ", f"id == {res_id[i]['id']}")
        # The id field is Int64, so compare it as a plain integer, not wrapped in quotes
        results = milvus_client.query(
            expr=f"id == {res_id[i]['id']}",  # double equals, raw integer value
            limit=1,
            output_fields=['*'],
            consistency_level="Strong"
        )
        if not results:
            print(f"11111111######## id {res_id[i]['id']} returned no result")
            continue
        #### Foreign keys into MySQL plus the stored vectors
        temp_doc_id = results[0]['doc_id']
        temp_chunk_id = results[0]['chunk_id']
        temp_vector_text = results[0]['vector_text']
        temp_vector_summary = results[0]['vector_summary']
        temp_vector_questions = results[0]['vector_questions']
        #### Read updated_at from MySQL; if updated_at < NOW() - INTERVAL 1 DAY is false the chunk was already processed, so skip it, otherwise process it
        query = f"""
        select text,updated_at from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
        """
        results = mysql_manager.fetch_one(query=query)
        if not results:
            print(f"222222222####### doc_id {temp_doc_id} chunk_id {temp_chunk_id} returned no result")
            continue
        temp_updated_at = results['updated_at']
        if temp_updated_at > datetime.now() - timedelta(days=1):
            print(f"doc_id {temp_doc_id} chunk_id {temp_chunk_id} already processed, updated_at {temp_updated_at}")
            continue
        #############
        # exit()
        insert_data = []
        #### Fetch the text column from MySQL
        query = f"""
        select text from content_chunks_trsc where doc_id = '{temp_doc_id}' and chunk_id = '{temp_chunk_id}';
        """
        results = mysql_manager.fetch_one(query=query)
        if not results:
            print(f"3333333******* doc_id {temp_doc_id} chunk_id {temp_chunk_id} returned no text")
            continue
        temp_text = results['text']
        # #### Call DeepSeek (earlier prompt drafts kept for reference)
        # """
        # 你是一个专业的文本转写助手,负责将用户输入的文本进行准确的转写。请确保转写结果与原始文本信息保持一致,避免添加任何额外的解释,注释。\
        # 请务必保持语言准确,精炼。以下是信息
        # """
        # #### grok
        # system_pre_prompt = f'''
        # 你是一个专业的文本精炼专家。请将以下原文改写成更简练、精确的版本。要求:
        # 1. 逐句检查原文,保留所有信息、事实、数据、关系和含义,不丢失任何细节。
        # 2. 去除冗余词语、重复表达和不必要的描述,使语言更紧凑和精准。
        # 3. 输出纯文本,不要添加任何格式化元素,如标题、列表、编号、粗体、换行符、额外解释或总结语句。只输出改写后的文本。
        # 4. 如果原文有专业术语,保持原样。
        # 原文:{temp_text}
        # '''
        # trans_answer_grok = get_deepseek_completion(
        #     model="DeepSeek-V3",   # use the DeepSeek V3 model
        #     prompt=system_pre_prompt,
        #     output_type="text"     # return plain text
        # )
        # print("trans_answer grok is ", trans_answer_grok)
        print("text is", temp_text)
        ### chatgpt-style rewriting prompt
        system_pre_prompt = f'''
        请将以下文本转写为更简练、精确的表达方式。转写时需要满足以下要求:
        不丢失信息:转写后的内容必须包含原文中的所有关键信息,任何核心的事实、细节和数据都不能丢失。
        简洁清晰:去除冗余的词汇和不必要的修饰,确保语言简洁明了,信息传达精准。
        无额外格式:转写后的文本应为纯文本,不添加任何额外的格式、列表、标点或其他符号,保持原有的结构。
        请根据这些要求,转写以下段落:{temp_text}
        '''
        ### chatgpt
        trans_answer_gpt = get_deepseek_completion(
            model="DeepSeek-V3",       # use the DeepSeek V3 model
            prompt=system_pre_prompt,  # the rewriting prompt assembled above
            output_type="text"         # return plain text
        )
        # print("text is", temp_text)
        print("trans_answer_chatgpt is ", trans_answer_gpt)
        # ### Compare the two rewrites
        # system_pre_prompt = f'''
        # 我有两个经过转写后的文本,请你根据以下标准来比较这两个文本,判断哪个文本更好:
        # 1. 信息完整性:哪个文本更好地保留了原文的关键信息?如果有任何重要信息缺失,请指出。
        # 2. 简洁性:哪个文本更加简洁明了,去除了冗余和不必要的部分?
        # 3. 清晰度:哪个文本更容易理解,表达更清晰?
        # 4. 精准度:哪个文本更准确地传达了原文的意思,没有歧义或误解?
        # 请基于这些标准分别对两个文本进行评估,并给出理由说明哪个文本更优。
        # 文本1: {trans_answer_grok}
        # 文本2: {trans_answer_gpt}
        # '''
        # ### compare
        # trans_answer_compare = get_deepseek_completion(
        #     model="DeepSeek-V3",       # use the DeepSeek V3 model
        #     prompt=system_pre_prompt,
        #     output_type="text"         # return plain text
        # )
        # print("trans_answer_compare is ", trans_answer_compare)
        # exit()
        # #### Embed the original text and the rewritten text
        temp_embedding = get_basic_embedding(text=temp_text, model=DEFAULT_MODEL)
        temp_embedding_trsc = get_basic_embedding(text=trans_answer_gpt, model=DEFAULT_MODEL)
        similarity_score = cosine_similarity(temp_embedding, temp_embedding_trsc)
        print(f"\nCosine similarity between the rewritten text and the original: {similarity_score}")
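        # For reference (illustration only): cosine similarity is the dot product of the two vectors
        # divided by the product of their norms, so identical vectors score 1.0 and orthogonal ones 0.0,
        # e.g. cosine_similarity([1.0, 0.0], [1.0, 0.0]) == 1.0.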
        # Update the MySQL row with the rewritten text, the similarity score and the current timestamp
        # test_query="select @@autocommit"
        # res = mysql_manager.fetch_one(test_query)
        # print("autocommit is ", res['@@autocommit'])
        # exit()
        # Parameterized so that quotes inside the rewritten text cannot break the statement
        update_query = """
        update content_chunks_trsc set simillarity = %s, text = %s, updated_at = NOW() where doc_id = %s and chunk_id = %s;
        """
        # Executed through fetch_one; persistence relies on the autocommit=1 set above
        mysql_manager.fetch_one(update_query, (similarity_score, trans_answer_gpt, temp_doc_id, temp_chunk_id))
        #### Update the Milvus collection: replace the stored embedding with the one computed from the rewritten text
        # Upsert into chunk_multi_embeddings_v2_trsc, writing temp_embedding_trsc into vector_text
        # Note: the row must contain every field that is non-nullable and has no default value
        upsert_data = {
            "id": res_id[i]['id'],               # primary key; upsert replaces the entity with this id
            "doc_id": temp_doc_id,               # required doc_id field
            "chunk_id": temp_chunk_id,           # required chunk_id field
            "vector_text": temp_embedding_trsc,  # embedding of the rewritten text
            "vector_summary": temp_vector_summary,
            "vector_questions": temp_vector_questions
        }
        # Row-based upsert: Collection.upsert takes the data only; the primary key decides which entity is replaced
        milvus_client.upsert([upsert_data])
        processed_count += 1
        ##################
        # Flush periodically so the upserts become durable in batches
        if processed_count % 200 == 0:
            milvus_client.flush()
    except Exception as e:
        print(f"Exception while processing record {i}: {str(e)}")
        # Skip this record and continue with the next one
        continue
milvus_client.flush()