build_graph.py

  1. """
  2. use neo4j to build graph
  3. """
  4. from dataclasses import fields
  5. from applications.utils.neo4j import AsyncNeo4jRepository
  6. from applications.utils.neo4j import Document, GraphChunk, ChunkRelations
  7. from applications.utils.mysql import ContentChunks


class BuildGraph(AsyncNeo4jRepository):
    # graph-build status values tracked by the chunk manager
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 3

    def __init__(self, neo4j, es_client, mysql_client):
        super().__init__(neo4j)
        self.es_client = es_client
        self.chunk_manager = ContentChunks(mysql_client)

    @staticmethod
    def from_dict(cls, data: dict):
        """Build a dataclass instance from a dict, ignoring unknown keys."""
        field_names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in field_names})

    async def add_single_chunk(self, param):
        """Asynchronously process a single chunk."""
        chunk_id = param["chunk_id"]
        doc_id = param["doc_id"]
        # optimistic lock: only proceed if the status flips INIT -> PROCESSING
        acquire_lock = await self.chunk_manager.update_graph_status(
            doc_id, chunk_id, self.INIT_STATUS, self.PROCESSING_STATUS
        )
        if not acquire_lock:
            print(f"while building graph, failed to acquire lock for chunk {chunk_id}")
            return
        try:
            doc: Document = self.from_dict(Document, param)
            graph_chunk: GraphChunk = self.from_dict(GraphChunk, param)
            relations: ChunkRelations = self.from_dict(ChunkRelations, param)
            await self.add_document_with_chunk(doc, graph_chunk, relations)
            await self.chunk_manager.update_graph_status(
                doc_id, chunk_id, self.PROCESSING_STATUS, self.FINISHED_STATUS
            )
        except Exception as e:
            print(f"failed to build graph for chunk {chunk_id}: {e}")
            await self.chunk_manager.update_graph_status(
                doc_id, chunk_id, self.PROCESSING_STATUS, self.FAILED_STATUS
            )

    async def get_chunk_list_from_es(self, doc_id):
        """Asynchronously fetch the chunk list for a document from Elasticsearch."""
        query = {
            "query": {"bool": {"must": [{"term": {"doc_id": doc_id}}]}},
            "_source": True,
        }
        try:
            resp = await self.es_client.async_search(query=query)
            return [hit["_source"] for hit in resp["hits"]["hits"]]
        except Exception as e:
            print(f"search failed: {e}")
            return []

    async def deal(self, doc_id):
        """Asynchronously process every chunk of a single document."""
        chunk_list = await self.get_chunk_list_from_es(doc_id)
        for chunk in chunk_list:
            await self.add_single_chunk(chunk)

    async def deal_batch(self, dataset_id):
        """Asynchronously build graphs for every un-graphed document in a dataset."""
        doc_ids = await self.chunk_manager.get_ungraphed_docs(dataset_id)
        for doc_id in doc_ids:
            try:
                await self.deal(doc_id)
            except Exception as e:
                print(f"failed to build graph for doc {doc_id}: {e}")
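

# Usage sketch (not part of the original module): one plausible way to drive
# BuildGraph end to end. The three client objects are assumed to be constructed
# elsewhere in the application; `make_neo4j_driver`, `make_es_client`, and
# `make_mysql_client` below are hypothetical placeholders for that wiring, and
# the dataset id is an illustrative value.
#
#     import asyncio
#
#     async def main():
#         builder = BuildGraph(
#             neo4j=make_neo4j_driver(),
#             es_client=make_es_client(),
#             mysql_client=make_mysql_client(),
#         )
#         # rebuild graphs for every un-graphed document in one dataset
#         await builder.deal_batch(dataset_id=1)
#
#     asyncio.run(main())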