milvus_deconstruct_insert.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. from pymilvus import Collection, connections, utility, FieldSchema, CollectionSchema, DataType
  2. import requests
  3. import json
  4. from typing import Dict, Any, List
  5. from pymongo import MongoClient
  6. from pydub import AudioSegment
  7. import io
  8. from scipy.io import wavfile
  9. ################################连接milvus数据库 A
  10. # 配置信息
  11. MILVUS_CONFIG = {
  12. "host": "c-981be0ee7225467b-internal.milvus.aliyuncs.com",
  13. "user": "root",
  14. "password": "Piaoquan@2025",
  15. "port": "19530",
  16. }
  17. print("正在连接 Milvus 数据库...")
  18. connections.connect("default", **MILVUS_CONFIG)
  19. print("连接成功!")
  20. ################################连接milvus数据库 B
  21. ##################################引入多模态模型#################
  22. import torch
  23. from PIL import Image
  24. from transformers.utils.import_utils import is_flash_attn_2_available
  25. from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
  26. model = ColQwen2_5Omni.from_pretrained(
  27. "vidore/colqwen-omni-v0.1",
  28. torch_dtype=torch.bfloat16,
  29. device_map="cuda", # or "mps" if on Apple Silicon
  30. attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
  31. ).eval()
  32. processor = ColQwen2_5OmniProcessor.from_pretrained("manu/colqwen-omni-v0.1")
  33. ##################################引入模型#################
################################ Embedding service: begin
# Note: per an earlier discussion, the remote service has to be forwarded to
# the local machine through an SSH tunnel.
# Run on the local machine: ssh -R 8000:192.168.100.31:8000 username@server_ip
# Endpoint that get_basic_embedding() POSTs to.
VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
# Model identifier sent in the request body by default.
DEFAULT_MODEL = "/models/Qwen3-Embedding-4B"
  39. def get_basic_embedding(text: str, model=DEFAULT_MODEL):
  40. """通过HTTP调用在线embedding服务"""
  41. headers = {
  42. "Content-Type": "application/json"
  43. }
  44. data = {
  45. "model": model,
  46. "input": text
  47. }
  48. response = requests.post(
  49. VLLM_SERVER_URL,
  50. headers=headers,
  51. json=data,
  52. timeout=5 # 添加超时设置
  53. )
  54. response.raise_for_status() # 如果状态码不是200,抛出异常
  55. result = response.json()
  56. return result["data"][0]["embedding"]
  57. def get_media_embedding(query: str, type: str):
  58. '''
  59. query 是查询字符串或文件路径
  60. type 是查询类型,可选值为 "audio", "image", "video", "text"
  61. k 是返回的结果数量,默认值为 3
  62. audio image video 的query为路径
  63. text的query为问题本身
  64. '''
  65. if type =="audio":
  66. batch_queries = processor.process_audios([query]).to(model.device)
  67. elif type =="image":
  68. query_image = Image.open(query)
  69. batch_queries = processor.process_images([query_image]).to(model.device)
  70. elif type =="video":
  71. batch_queries = processor.process_videos([query]).to(model.device)
  72. elif type =="text":
  73. batch_queries = processor.process_queries([query]).to(model.device)
  74. # Forward pass
  75. with torch.no_grad():
  76. query_embeddings = model(**batch_queries)
  77. return query_embeddings
  78. # # scores = processor.score_multi_vector(query_embeddings, ds)
  79. # print("score is ", scores)
  80. # # get top-5 scores
  81. # return scores[0].topk(k).indices.tolist()
  82. # ################################连接Embedding service B
  83. def parse_deconstruct_res(json_data) -> Dict[str, Dict[str, str]]:
  84. """
  85. 解析 deconstruct_res.json 文件,提取两类信息:
  86. 1. 所有 "what" 字段的 path 与 value 映射
  87. 2. 所有类型为 "image" 或 "video" 的媒体引用 path 与 content 值映射
  88. 返回:
  89. {
  90. "what": {path: value, ...},
  91. "media": {path: value, ...}
  92. }
  93. """
  94. data = json_data
  95. what_dict: Dict[str, Any] = {}
  96. media_dict: Dict[str, Any] = {}
  97. def traverse(obj: Any, current_path: str = ""):
  98. """递归遍历 JSON 结构,记录目标字段"""
  99. if isinstance(obj, dict):
  100. for k, v in obj.items():
  101. # 构建新路径,避免在开头添加点号
  102. new_path = f"{current_path}.{k}" if current_path else k
  103. if k == "what":
  104. what_dict[new_path] = v
  105. # 处理媒体引用字段
  106. elif k == "媒体引用" and isinstance(v, list):
  107. # 遍历媒体引用数组
  108. for idx, media_item in enumerate(v):
  109. if isinstance(media_item, dict) and media_item.get("type") in ("image", "video", "audio"):
  110. # 记录content字段作为媒体路径
  111. content = media_item.get("content")
  112. type_nm = media_item.get("type")
  113. if content:
  114. # 生成正确格式的路径,如"图片元素[5].媒体引用[0].content"
  115. media_ref_path = f"{type_nm}-{new_path}[{idx}].content"
  116. media_dict[media_ref_path] = content
  117. # 继续递归遍历
  118. traverse(v, new_path)
  119. elif isinstance(obj, list):
  120. for idx, item in enumerate(obj):
  121. # 对于数组元素,使用方括号索引
  122. new_path = f"{current_path}[{idx}]"
  123. traverse(item, new_path)
  124. traverse(data)
  125. return {"what": what_dict, "media": media_dict}
  126. # 使用示例
  127. if __name__ == "__main__":
  128. # 连接 MongoDB 数据库
  129. ##################### 存储到mongoDB
  130. MONGO_URI = "mongodb://localhost:27017/"
  131. DB_NAME = "mydeconstruct"
  132. COLL_NAME = "deconstruct"
  133. client = MongoClient(MONGO_URI)
  134. db = client[DB_NAME]
  135. coll = db[COLL_NAME]
  136. # 读取并插入 JSON 文件
  137. json_path = "/home/ecs-user/project/colpali/src/deconstruct_res.json"
  138. with open(json_path, "r", encoding="utf-8") as f:
  139. doc = json.load(f)
  140. insert_result = coll.insert_one(doc)
  141. inserted_id = insert_result.inserted_id
  142. print("已插入 MongoDB,文档 _id:", inserted_id)
  143. result = parse_deconstruct_res(doc)
  144. print("what 字段映射:", result["what"])
  145. print("媒体引用映射:", result["media"])
  146. ##################### 存储到mongoDB
  147. ##################### 将 result["what"] 中的每个 value 转换为向量并插入 Milvus
  148. ########## 文本向量库存一份what
  149. # 创建 Milvus 集合(如不存在)
  150. collection_name = "deconstruct_what"
  151. if not utility.has_collection(collection_name):
  152. fields = [
  153. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  154. FieldSchema(name="mongo_id", dtype=DataType.VARCHAR, max_length=64),
  155. FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512),
  156. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560)
  157. ]
  158. schema = CollectionSchema(fields, description="Deconstruct what embeddings")
  159. collection = Collection(name=collection_name, schema=schema)
  160. # 创建 IVF_FLAT 索引
  161. index_params = {
  162. "metric_type": "IP",
  163. "index_type": "IVF_FLAT",
  164. "params": {"nlist": 128}
  165. }
  166. collection.create_index("embedding", index_params)
  167. else:
  168. collection = Collection(name=collection_name)
  169. # 遍历 result["what"],生成 embeddings 并插入 Milvus
  170. entities = []
  171. for key, value in result["what"].items():
  172. embedding = get_basic_embedding(value, model=DEFAULT_MODEL)
  173. path = key
  174. entities.append({
  175. "mongo_id": str(inserted_id),
  176. "path": path,
  177. "embedding": embedding
  178. })
  179. if entities:
  180. collection.insert(entities)
  181. collection.flush()
  182. print(f"已插入 {len(entities)} 条 what 字段向量到 Milvus")
  183. else:
  184. print("未找到 what 字段,未插入向量")
  185. ##################### 将 result["what"] 中的每个 value 转换为向量并插入 Milvus
  186. #####################将 result["media"] 中的每个 value 调用多模态编码模型计算embedding并插入Milvus
  187. # 创建 Milvus 集合(如不存在)
  188. collection_name = "deconstruct_media"
  189. if not utility.has_collection(collection_name):
  190. fields = [
  191. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  192. FieldSchema(name="mongo_id", dtype=DataType.VARCHAR, max_length=64),
  193. FieldSchema(name="type", dtype=DataType.VARCHAR, max_length=64),
  194. FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512),
  195. FieldSchema(name="no", dtype=DataType.INT32),
  196. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560)
  197. ]
  198. schema = CollectionSchema(fields, description="Deconstruct media embeddings")
  199. collection = Collection(name=collection_name, schema=schema)
  200. # 创建 IVF_FLAT 索引
  201. index_params = {
  202. "metric_type": "IP",
  203. "index_type": "IVF_FLAT",
  204. "params": {"nlist": 128}
  205. }
  206. collection.create_index("embedding", index_params)
  207. else:
  208. collection = Collection(name=collection_name)
  209. # 遍历 result["media"],生成 embeddings 并插入 Milvus
  210. #############存储一份media embedding到Milvus
  211. entities = []
  212. for key, value in result["media"].items():
  213. embedding = get_media_embedding(value, model=DEFAULT_MODEL)
  214. type = key[:key.index("-")]
  215. path = key[key.index("-"):]
  216. # 将 embedding 列表拆分为单条向量,并记录其在原列表中的位置 no
  217. if isinstance(embedding, list) and len(embedding) > 0:
  218. for idx, vec in enumerate(embedding):
  219. entities.append({
  220. "mongo_id": str(inserted_id),
  221. "type": type,
  222. "path": path,
  223. "no": idx,
  224. "embedding": vec
  225. })
  226. else:
  227. # 若 embedding 不是列表或长度为 0,则 no 记为 0
  228. entities.append({
  229. "mongo_id": str(inserted_id),
  230. "type": type,
  231. "path": path,
  232. "no": 0,
  233. "embedding": embedding
  234. })
  235. # 将插入操作移到循环外部,避免重复插入和数据累积
  236. if entities:
  237. collection.insert(entities)
  238. collection.flush()
  239. print(f"已插入 {len(entities)} 条 media 字段向量到 Milvus")
  240. else:
  241. print("未找到有效的 media 字段向量,未插入数据")
  242. #############存储一份what 多模态embedding 到Milvus
  243. entities = []
  244. for key, value in result["what"].items():
  245. embedding = get_media_embedding(value, model=DEFAULT_MODEL)
  246. # type = key[:key.index("-")]
  247. # path = key[key.index("-"):]
  248. path = key
  249. if isinstance(embedding, list) and len(embedding) > 0:
  250. for idx, vec in enumerate(embedding):
  251. entities.append({
  252. "mongo_id": str(inserted_id),
  253. "type": "text",
  254. "path": path,
  255. "no": idx,
  256. "embedding": vec
  257. })
  258. else:
  259. # 若 embedding 不是列表或长度为 0,则 no 记为 0
  260. entities.append({
  261. "mongo_id": str(inserted_id),
  262. "type": "text",
  263. "path": path,
  264. "no": 0,
  265. "embedding": embedding
  266. })
  267. # 将插入操作移到循环外部,避免重复插入和数据累积
  268. if entities:
  269. collection.insert(entities)
  270. collection.flush()
  271. print(f"已插入 {len(entities)} 条 what 多模态向量到 Milvus")
  272. else:
  273. print("未找到有效的 what 多模态向量,未插入数据")
  274. #############存储一份what 多模态embedding 到Milvus