resource_manager.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from pymilvus import connections, CollectionSchema, Collection
  2. from neo4j import AsyncGraphDatabase, AsyncDriver
  3. from applications.config import NEO4j_CONFIG
  4. from applications.utils.milvus.what_field import what_fields
  5. from applications.utils.mysql import DatabaseManager
  6. from applications.utils.milvus.field import fields
  7. from applications.utils.elastic_search import AsyncElasticSearchClient
  # Lifecycle owner for the service's Elasticsearch, MySQL, Milvus and Neo4j clients.
class ResourceManager:
    """Create, hold, and tear down the application's backing-store clients.

    Construction only records configuration; no connections are opened until
    ``startup`` is awaited. ``shutdown`` closes whichever clients were created
    and tolerates clients that were never initialized.
    """

    def __init__(self, es_index, es_hosts, es_password, milvus_config):
        """Store connection settings; every client attribute starts as ``None``.

        Args:
            es_index: Index name passed to ``AsyncElasticSearchClient``.
            es_hosts: Hosts passed to ``AsyncElasticSearchClient``.
            es_password: Password passed to ``AsyncElasticSearchClient``.
            milvus_config: Keyword arguments forwarded to
                ``connections.connect`` for the ``"default"`` Milvus alias.
        """
        self.es_index = es_index
        self.es_hosts = es_hosts
        self.es_password = es_password
        self.milvus_config = milvus_config
        self.es_client: AsyncElasticSearchClient | None = None
        self.milvus_client: Collection | None = None
        self.what_milvus_client: Collection | None = None
        self.mysql_client: DatabaseManager | None = None
        self.graph_client: AsyncDriver | None = None

    async def load_milvus(self):
        """Connect to Milvus and prepare the chunk and "what" collections.

        Opens the ``"default"`` connection alias with ``self.milvus_config``,
        obtains ``Collection`` handles for ``chunk_multi_embeddings_v2`` and
        ``what_multi_embeddings`` (schemas from ``fields`` / ``what_fields``),
        creates vector indexes on them, and loads the chunk collection.

        NOTE(review): every pymilvus call below is synchronous (none is
        awaited), so despite ``async def`` this blocks the event loop while it
        runs. Acceptable at startup, but worth knowing.
        """
        connections.connect("default", **self.milvus_config)

        # Chunk collection: three vector fields share one index configuration.
        schema = CollectionSchema(
            fields, description="Chunk multi-vector embeddings with metadata"
        )
        self.milvus_client = Collection(name="chunk_multi_embeddings_v2", schema=schema)

        # NOTE(review): "M" and "efConstruction" are HNSW build parameters;
        # IVF_FLAT is normally configured with "nlist". Confirm which index type
        # is intended before changing it -- indexes already built on these
        # fields with this config would conflict with a different one.
        vector_index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "COSINE",
            "params": {"M": 16, "efConstruction": 200},
        }
        self.milvus_client.create_index("vector_text", vector_index_params)
        self.milvus_client.create_index("vector_summary", vector_index_params)
        self.milvus_client.create_index("vector_questions", vector_index_params)
        self.milvus_client.load()

        # "What" collection: only its text vector field is indexed here.
        what_schema = CollectionSchema(
            what_fields, description="What multi-vector embeddings with metadata"
        )
        self.what_milvus_client = Collection(name="what_multi_embeddings", schema=what_schema)
        self.what_milvus_client.create_index("vector_text", vector_index_params)
        # NOTE(review): unlike the chunk collection, load() is not called on
        # this collection; confirm it is loaded elsewhere before it is searched.

    async def startup(self):
        """Open all backing-store clients in order: ES, MySQL, Milvus, Neo4j.

        A failed Elasticsearch ``ping()`` is only reported; startup continues
        with the remaining clients regardless.
        """
        # Initialize Elasticsearch
        self.es_client = AsyncElasticSearchClient(
            index_name=self.es_index, hosts=self.es_hosts, password=self.es_password
        )
        if await self.es_client.es.ping():
            print("✅ Elasticsearch connected")
        else:
            print("❌ Elasticsearch connection failed")

        # Initialize MySQL
        self.mysql_client = DatabaseManager()
        await self.mysql_client.init_pools()
        print("✅ MySQL connected")

        # Initialize Milvus
        await self.load_milvus()
        print("✅ Milvus loaded")

        # Initialize Neo4j from the shared config.
        uri: str = NEO4j_CONFIG["url"]
        auth: tuple = NEO4j_CONFIG["user"], NEO4j_CONFIG["password"]
        self.graph_client = AsyncGraphDatabase.driver(uri=uri, auth=auth)
        print("✅ NEO4j loaded")

    async def shutdown(self):
        """Close every client that ``startup`` managed to create.

        Each client is checked for ``None`` first, so this is safe to call
        after a partial or skipped ``startup``.
        """
        # Close Elasticsearch
        if self.es_client is not None:
            await self.es_client.close()
            print("Elasticsearch closed")

        # Close Milvus: drop the "default" alias opened in load_milvus.
        connections.disconnect("default")
        print("Milvus closed")

        # Close MySQL
        if self.mysql_client is not None:
            await self.mysql_client.close_pools()
            print("Mysql closed")

        # Close Neo4j. Guarded like the other clients: previously this was
        # unconditional and raised AttributeError when startup never created
        # the driver.
        if self.graph_client is not None:
            await self.graph_client.close()
            print("Graph closed")
  # Process-wide singleton holder for ResourceManager.
_resource_manager: ResourceManager | None = None


def init_resource_manager(es_index, es_hosts, es_password, milvus_config) -> ResourceManager:
    """Create the process-wide ResourceManager on first call and return it.

    Only the first call constructs the manager. Later calls return the
    existing instance and silently ignore their arguments. Construction does
    not open any connections; await ``startup()`` on the returned manager.
    """
    global _resource_manager
    if _resource_manager is None:
        _resource_manager = ResourceManager(
            es_index, es_hosts, es_password, milvus_config
        )
    return _resource_manager


def get_resource_manager() -> ResourceManager | None:
    """Return the manager created by ``init_resource_manager``.

    Returns ``None`` until ``init_resource_manager`` has been called, so
    callers that can run before initialization must handle that case.
    """
    return _resource_manager