build_graph.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """
  2. use neo4j to build graph
  3. """
  4. from dataclasses import fields
  5. from applications.utils.neo4j import AsyncNeo4jRepository
  6. from applications.utils.neo4j import Document, GraphChunk, ChunkRelations
  7. from applications.utils.mysql import ContentChunks
  8. from applications.utils.async_utils import run_tasks_with_asyncio_task_group
  9. class BuildGraph(AsyncNeo4jRepository):
  10. INIT_STATUS = 0
  11. PROCESSING_STATUS = 1
  12. FINISHED_STATUS = 2
  13. FAILED_STATUS = 3
  14. def __init__(self, neo4j, es_client, mysql_client):
  15. super().__init__(neo4j)
  16. self.es_client = es_client
  17. self.chunk_manager = ContentChunks(mysql_client)
  18. @staticmethod
  19. def from_dict(cls, data: dict):
  20. field_names = {f.name for f in fields(cls)}
  21. return cls(**{k: v for k, v in data.items() if k in field_names})
  22. async def add_single_chunk(self, param):
  23. """async process single chunk"""
  24. chunk_id = param["chunk_id"]
  25. doc_id = param["doc_id"]
  26. acquire_lock = await self.chunk_manager.update_graph_status(
  27. doc_id, chunk_id, self.INIT_STATUS, self.PROCESSING_STATUS
  28. )
  29. if acquire_lock:
  30. print(f"while building graph, acquire lock for chunk {chunk_id}")
  31. return
  32. try:
  33. doc: Document = self.from_dict(Document, param)
  34. graph_chunk: GraphChunk = self.from_dict(GraphChunk, param)
  35. relations: ChunkRelations = self.from_dict(ChunkRelations, param)
  36. await self.add_document_with_chunk(doc, graph_chunk, relations)
  37. await self.chunk_manager.update_graph_status(
  38. doc_id, chunk_id, self.PROCESSING_STATUS, self.FINISHED_STATUS
  39. )
  40. except Exception as e:
  41. print(f"failed to build graph for chunk {chunk_id}: {e}")
  42. await self.chunk_manager.update_graph_status(
  43. doc_id, chunk_id, self.PROCESSING_STATUS, self.FAILED_STATUS
  44. )
  45. async def get_chunk_list_from_es(self, doc_id):
  46. """async get chunk list"""
  47. query = {
  48. "query": {"bool": {"must": [{"term": {"doc_id": doc_id}}]}},
  49. "_source": True,
  50. }
  51. try:
  52. resp = await self.es_client.async_search(query=query)
  53. return [hit["_source"] for hit in resp["hits"]["hits"]]
  54. except Exception as e:
  55. print(f"search failed: {e}")
  56. return []
  57. async def deal(self, doc_id):
  58. """async process single chunk"""
  59. chunk_list = await self.get_chunk_list_from_es(doc_id)
  60. await run_tasks_with_asyncio_task_group(
  61. task_list=chunk_list,
  62. handler=self.add_single_chunk,
  63. description="build graph",
  64. unit="chunk",
  65. max_concurrency=10,
  66. )