resource_manager.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from pymilvus import connections, CollectionSchema, Collection
  2. from neo4j import AsyncGraphDatabase, AsyncDriver
  3. from applications.config import NEO4j_CONFIG
  4. from applications.utils.mysql import DatabaseManager
  5. from applications.utils.milvus.field import fields, mode_fields
  6. from applications.utils.elastic_search import AsyncElasticSearchClient
  class ResourceManager:
      """Hold and manage the lifecycle of the application's external clients.

      Connections are opened in ``startup()`` and released in ``shutdown()``.
      The Elasticsearch and Neo4j setup in ``startup()`` is currently
      commented out, so ``es_client`` and ``graph_client`` stay ``None``
      unless something else assigns them.
      """

      def __init__(self, es_index, es_hosts, es_password, milvus_config):
          """Store connection settings; no connection is opened here.

          Args:
              es_index: Elasticsearch index name (within this class it is only
                  read by the commented-out Elasticsearch setup).
              es_hosts: Elasticsearch host(s) (same note as ``es_index``).
              es_password: Elasticsearch password (same note as ``es_index``).
              milvus_config: Keyword arguments forwarded to
                  ``connections.connect("default", ...)`` in ``load_milvus()``.
          """
          self.es_index = es_index
          self.es_hosts = es_hosts
          self.es_password = es_password
          self.milvus_config = milvus_config
          # Clients are created by startup(); None means "not open".
          self.es_client: AsyncElasticSearchClient | None = None
          self.milvus_client: Collection | None = None
          self.mysql_client: DatabaseManager | None = None
          self.graph_client: AsyncDriver | None = None

      async def load_milvus(self):
          """Connect to Milvus and load the ``standard_mode_embeddings`` collection.

          Opens the "default" connection alias, binds the collection with the
          ``mode_fields`` schema, issues the vector index request on the
          ``mode_vector`` field and calls ``load()`` on the collection.

          NOTE(review): every pymilvus call here is synchronous, so this
          coroutine blocks the event loop until they finish.
          """
          connections.connect("default", **self.milvus_config)
          schema = CollectionSchema(mode_fields, description="标准模式向量空间")
          self.milvus_client = Collection(name="standard_mode_embeddings", schema=schema)
          # Create the vector index (this request is issued on every startup).
          # NOTE(review): "M" and "efConstruction" are HNSW build parameters,
          # whereas IVF_FLAT is tuned through "nlist". Left unchanged on
          # purpose: Milvus rejects create_index when the params differ from
          # an index that already exists on the field. Confirm the intended
          # index type before changing it.
          vector_index_params = {
              "index_type": "IVF_FLAT",
              "metric_type": "COSINE",
              "params": {"M": 16, "efConstruction": 200},
          }
          self.milvus_client.create_index("mode_vector", vector_index_params)
          self.milvus_client.load()

      async def startup(self):
          """Open the configured clients: MySQL pools first, then Milvus.

          The Elasticsearch and Neo4j initialisation is kept below as
          commented-out code and is not executed.
          """
          # Initialise Elasticsearch (disabled)
          # self.es_client = AsyncElasticSearchClient(
          #     index_name=self.es_index, hosts=self.es_hosts, password=self.es_password
          # )
          # if await self.es_client.es.ping():
          #     print("✅ Elasticsearch connected")
          # else:
          #     print("❌ Elasticsearch connection failed")

          # Initialise MySQL
          self.mysql_client = DatabaseManager()
          await self.mysql_client.init_pools()
          print("✅ MySQL connected")

          # Initialise Milvus
          await self.load_milvus()
          print("✅ Milvus loaded")

          # Initialise Neo4j (disabled); shutdown() closes graph_client if set.
          # uri: str = NEO4j_CONFIG["url"]
          # auth: tuple = NEO4j_CONFIG["user"], NEO4j_CONFIG["password"]
          # self.graph_client = AsyncGraphDatabase.driver(uri=uri, auth=auth)
          # print("✅ NEO4j loaded")

      async def shutdown(self):
          """Close every open client, attempting all of them even if one fails.

          Steps run in the order Elasticsearch, Milvus, MySQL, Neo4j. A failure
          in one step does not stop the remaining steps. Each failure is
          printed, and the first exception is re-raised after every step has
          been attempted. All client attributes are reset to ``None`` so a
          repeated call does not close the same client twice.
          """

          async def _disconnect_milvus():
              # pymilvus disconnect is synchronous; wrap it so every step is awaitable.
              connections.disconnect("default")

          # (label, async close callable); the label builds the status messages.
          steps = []
          if self.es_client is not None:
              steps.append(("Elasticsearch", self.es_client.close))
          steps.append(("Milvus", _disconnect_milvus))
          if self.mysql_client is not None:
              steps.append(("Mysql", self.mysql_client.close_pools))
          if self.graph_client is not None:
              steps.append(("Graph", self.graph_client.close))

          first_error: Exception | None = None
          for label, close in steps:
              try:
                  await close()
              except Exception as exc:
                  print(f"{label} close failed: {exc!r}")
                  if first_error is None:
                      first_error = exc
              else:
                  print(f"{label} closed")

          self.es_client = None
          self.milvus_client = None
          self.mysql_client = None
          self.graph_client = None

          if first_error is not None:
              raise first_error
  # Module-level singleton slot; filled in by init_resource_manager().
  _resource_manager: ResourceManager | None = None


  def init_resource_manager(es_index, es_hosts, es_password, milvus_config):
      """Return the module-level ResourceManager, creating it on the first call.

      Only the first call's arguments are used: once the instance exists,
      later calls return it unchanged and their arguments are ignored.
      """
      global _resource_manager
      if _resource_manager is not None:
          return _resource_manager
      _resource_manager = ResourceManager(
          es_index=es_index,
          es_hosts=es_hosts,
          es_password=es_password,
          milvus_config=milvus_config,
      )
      return _resource_manager
  def get_resource_manager() -> ResourceManager | None:
      """Return the shared ResourceManager created by init_resource_manager().

      Returns ``None`` until init_resource_manager() has been called, which is
      why the return type is Optional; callers must handle that case.
      """
      return _resource_manager