build_graph.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. """
  2. Build the document/chunk graph in Neo4j from chunk rows stored in MySQL.
  3. """
  4. import json
  5. import random
  6. from dataclasses import fields
  7. from applications.utils.neo4j import AsyncNeo4jRepository
  8. from applications.utils.neo4j import Document, GraphChunk, ChunkRelations
  9. from applications.utils.async_utils import run_tasks_with_asyncio_task_group
  10. class BuildGraph(AsyncNeo4jRepository):
  11. def __init__(self, neo4j, mysql_client):
  12. super().__init__(neo4j)
  13. self.mysql_client = mysql_client
  14. @staticmethod
  15. def from_dict(cls, data: dict):
  16. field_names = {f.name for f in fields(cls)}
  17. return cls(**{k: v for k, v in data.items() if k in field_names})
  18. async def add_single_chunk(self, param):
  19. """async process single chunk"""
  20. param["milvus_id"] = random.randint(100000, 999999)
  21. doc: Document = self.from_dict(Document, param)
  22. graph_chunk: GraphChunk = self.from_dict(GraphChunk, param)
  23. relations: ChunkRelations = self.from_dict(ChunkRelations, param)
  24. await self.add_document_with_chunk(doc, graph_chunk, relations)
  25. async def get_chunk_list(self, doc_id):
  26. """async get chunk list"""
  27. query = """
  28. SELECT chunk_id, doc_id, topic, domain, task_type, keywords, concepts, entities, text_type, dataset_id
  29. FROM content_chunks
  30. WHERE embedding_status = %s AND status = %s and doc_id = %s;
  31. """
  32. response = await self.mysql_client.async_fetch(
  33. query=query,
  34. params=(2, 1, doc_id)
  35. )
  36. L = []
  37. for i in response:
  38. i["keywords"] = json.loads(i["keywords"])
  39. i["entities"] = json.loads(i["entities"])
  40. i["concepts"] = json.loads(i["concepts"])
  41. L.append(i)
  42. return L
  43. async def deal(self, doc_id):
  44. for task in await self.get_chunk_list(doc_id):
  45. await self.add_single_chunk(task)