# milvus_pattern_insert.py
import io
import json
import os
from typing import Any, Callable, Dict, List

import numpy as np
import requests
from pydub import AudioSegment
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
from pymongo import MongoClient
from scipy.io import wavfile
  10. ################################连接milvus数据库 A
  11. # 配置信息
  12. MILVUS_CONFIG = {
  13. "host": "c-981be0ee7225467b-internal.milvus.aliyuncs.com",
  14. "user": "root",
  15. "password": "Piaoquan@2025",
  16. "port": "19530",
  17. }
  18. print("正在连接 Milvus 数据库...")
  19. connections.connect("default", **MILVUS_CONFIG)
  20. print("连接成功!")
  21. ################################连接milvus数据库 B
  22. ################################连接Embedding service A
  23. # 注意:根据之前的讨论,需要通过SSH隧道将远程服务转发到本地
  24. # 在本地机器上执行: ssh -R 8000:192.168.100.31:8000 username@server_ip
  25. VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
  26. DEFAULT_MODEL = "/models/Qwen3-Embedding-4B"
  27. def get_basic_embedding(text: str, model=DEFAULT_MODEL):
  28. """通过HTTP调用在线embedding服务"""
  29. headers = {
  30. "Content-Type": "application/json"
  31. }
  32. data = {
  33. "model": model,
  34. "input": text
  35. }
  36. response = requests.post(
  37. VLLM_SERVER_URL,
  38. headers=headers,
  39. json=data,
  40. timeout=5 # 添加超时设置
  41. )
  42. response.raise_for_status() # 如果状态码不是200,抛出异常
  43. result = response.json()
  44. return result["data"][0]["embedding"]
  45. def parse_pattern_res(json_data) -> Dict[str, Dict[str, str]]:
  46. """
  47. 解析 pattern_res.json 文件,提取两类信息:
  48. 1. 所有 "模式ID","模式命名","模式说明" 字段的 path 与 value 映射
  49. 返回:
  50. {
  51. "模式ID": {path: value, ...},
  52. "模式命名": {path: value, ...},
  53. "模式说明": {path: value, ...}
  54. }
  55. """
  56. data = json_data
  57. pattern_dict: Dict[str, Any] = {}
  58. def traverse(obj: Any, current_path: str = ""):
  59. """递归遍历 JSON 结构,记录目标字段"""
  60. if isinstance(obj, dict):
  61. for k, v in obj.items():
  62. # 构建新路径,避免在开头添加点号
  63. new_path = f"{current_path}.{k}" if current_path else k
  64. if k == "模式ID":
  65. # 当遇到“模式ID”时,同时获取同层的“模式命名”和“模式描述”
  66. temp_dict ={}
  67. temp_dict["模式ID"] = v
  68. temp_dict["模式命名"] = obj.get("模式命名", "")
  69. temp_dict["模式说明"] = obj.get("模式说明", "")
  70. pattern_dict[current_path] = temp_dict
  71. traverse(v, new_path)
  72. elif isinstance(obj, list):
  73. for idx, item in enumerate(obj):
  74. # 对于数组元素,使用方括号索引
  75. new_path = f"{current_path}[{idx}]"
  76. traverse(item, new_path)
  77. traverse(data)
  78. return {"pattern": pattern_dict}
  79. # 使用示例
  80. if __name__ == "__main__":
  81. # 连接 MongoDB 数据库
  82. ##################### 存储到mongoDB
  83. MONGO_URI = "mongodb://localhost:27017/"
  84. DB_NAME = "mydeconstruct"
  85. COLL_NAME = "deconstruct_how"
  86. client = MongoClient(MONGO_URI)
  87. db = client[DB_NAME]
  88. coll = db[COLL_NAME]
  89. # 读取并插入 JSON 文件
  90. json_path = "/home/ecs-user/project/colpali/src/pattern_res.json"
  91. with open(json_path, "r", encoding="utf-8") as f:
  92. doc = json.load(f)
  93. result = parse_pattern_res(doc)
  94. for key, value in result["pattern"].items():
  95. print(f"pattern 字段 {key} 的值为: {value}")
  96. # exit()
  97. insert_result = coll.insert_one(doc)
  98. inserted_id = insert_result.inserted_id
  99. ##################### 将 result["how"] 中的每个 value 转换为向量并插入 Milvus
  100. ########## 文本向量库存一份how
  101. # 创建 Milvus 集合(如不存在)
  102. collection_name = "deconstruct_pattern"
  103. if not utility.has_collection(collection_name):
  104. # utility.drop_collection(collection_name)
  105. fields = [
  106. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  107. FieldSchema(name="mongo_id", dtype=DataType.VARCHAR, max_length=64),
  108. FieldSchema(name="pattern_id", dtype=DataType.VARCHAR, max_length=64),
  109. FieldSchema(name="pattern_name", dtype=DataType.VARCHAR, max_length=128),
  110. FieldSchema(name="pattern_desc", dtype=DataType.VARCHAR, max_length=2048),
  111. FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512),
  112. FieldSchema(name="name_embedding", dtype=DataType.FLOAT_VECTOR, dim=2560),
  113. FieldSchema(name="desc_embedding", dtype=DataType.FLOAT_VECTOR, dim=2560)
  114. ]
  115. schema = CollectionSchema(fields, description="Deconstruct how embeddings")
  116. collection = Collection(name=collection_name, schema=schema)
  117. # 创建 IVF_FLAT 索引
  118. index_params = {
  119. "metric_type": "IP",
  120. "index_type": "IVF_FLAT",
  121. "params": {"nlist": 128}
  122. }
  123. # 为 pattern_id 字段创建字符串索引
  124. collection.create_index("pattern_id", {
  125. "index_type": "INVERTED" #"Trie"
  126. })
  127. collection.create_index("name_embedding", index_params)
  128. collection.create_index("desc_embedding", index_params)
  129. else:
  130. collection = Collection(name=collection_name)
  131. entities = []
  132. for key, value in result["pattern"].items():
  133. pattern_id = value["模式ID"]
  134. pattern_name = value["模式命名"]
  135. pattern_desc = value["模式说明"]
  136. ### 访问可达则替换
  137. # name_embedding = get_basic_embedding(pattern_name, model=DEFAULT_MODEL)
  138. # desc_embedding = get_basic_embedding(pattern_desc, model=DEFAULT_MODEL)
  139. ###
  140. name_embedding = np.random.rand(2560).tolist()
  141. desc_embedding = np.random.rand(2560).tolist()
  142. path = key
  143. entities.append({
  144. "mongo_id": str(inserted_id),
  145. "pattern_id": pattern_id,
  146. "pattern_name": pattern_name,
  147. "pattern_desc": pattern_desc,
  148. "path": path,
  149. "name_embedding": name_embedding,
  150. "desc_embedding": desc_embedding
  151. })
  152. # 遍历 result["pattern"],生成 embeddings 并插入 Milvus
  153. # print("entities is ", entities)
  154. if entities:
  155. collection.insert(entities)
  156. collection.flush()
  157. print(f"已插入 {len(entities)} 条 how 字段向量到 Milvus")
  158. else:
  159. print("未找到 how 字段,未插入向量")