db_manager.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. from typing import Iterable
  2. from sqlalchemy import bindparam, create_engine, and_, or_, desc, text
  3. from sqlalchemy.orm import sessionmaker, Session
  4. class DatabaseManager:
  5. """数据库管理类"""
  6. # mysql+pymysql://<用户名>:<密码>@<主机地址>:<端口>/<数据库名>?charset=utf8mb4
  7. def __init__(self):
  8. connection_string = (
  9. f"mysql+pymysql://content_rw:bC1aH4bA1lB0@rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com:3306/open_aigc_pattern?charset=utf8mb4"
  10. )
  11. self.engine = create_engine(connection_string, pool_pre_ping=True, pool_recycle=3600)
  12. self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
  13. def get_session(self) -> Session:
  14. """获取数据库会话"""
  15. return self.SessionLocal()
  16. def query_video_ids_by_names(execution_id: int, names: Iterable[str]) -> list[str]:
  17. """按 execution_id + 名称列表查询去重后的 post_id。"""
  18. clean_names = [str(n).strip() for n in names if n is not None and str(n).strip()]
  19. if not clean_names:
  20. return []
  21. manager = DatabaseManager()
  22. session = manager.get_session()
  23. video_ids: set[str] = set()
  24. try:
  25. for name in clean_names:
  26. categories = session.execute(
  27. text(
  28. """
  29. SELECT id
  30. FROM topic_pattern_category
  31. WHERE execution_id = :execution_id AND name = :name
  32. """
  33. ),
  34. {"execution_id": execution_id, "name": name},
  35. ).fetchall()
  36. category_ids = [row[0] for row in categories if row and row[0] is not None]
  37. if category_ids:
  38. elements = session.execute(
  39. text(
  40. """
  41. SELECT post_id
  42. FROM topic_pattern_element
  43. WHERE execution_id = :execution_id
  44. AND category_id IN :category_ids
  45. """
  46. ).bindparams(bindparam("category_ids", expanding=True)),
  47. {"execution_id": execution_id, "category_ids": category_ids},
  48. ).fetchall()
  49. else:
  50. elements = session.execute(
  51. text(
  52. """
  53. SELECT post_id
  54. FROM topic_pattern_element
  55. WHERE execution_id = :execution_id AND name = :name
  56. """
  57. ),
  58. {"execution_id": execution_id, "name": name},
  59. ).fetchall()
  60. for row in elements:
  61. post_id = row[0] if row else None
  62. if post_id is not None and str(post_id).strip():
  63. video_ids.add(str(post_id).strip())
  64. finally:
  65. session.close()
  66. return list(video_ids)
  67. def query_category_level(execution_id: int, name: str) -> int | None:
  68. """按 execution_id + 分类名称查询 topic_pattern_category.level。"""
  69. clean_name = str(name).strip() if name is not None else ""
  70. if not clean_name:
  71. return None
  72. manager = DatabaseManager()
  73. session = manager.get_session()
  74. try:
  75. row = session.execute(
  76. text(
  77. """
  78. SELECT level
  79. FROM topic_pattern_category
  80. WHERE execution_id = :execution_id AND name = :name
  81. ORDER BY id DESC
  82. LIMIT 1
  83. """
  84. ),
  85. {"execution_id": execution_id, "name": clean_name},
  86. ).first()
  87. if not row:
  88. return None
  89. level = row[0]
  90. return int(level) if level is not None else None
  91. finally:
  92. session.close()
  93. db = DatabaseManager()
  94. def exist_cluster_tree(merge_level2):
  95. session = db.get_session()
  96. exec_row = session.execute(
  97. text("""
  98. SELECT id
  99. FROM cluster_execution
  100. WHERE name LIKE :name AND status = 2
  101. ORDER BY create_time DESC
  102. LIMIT 1
  103. """),
  104. {"name": f"{merge_level2}%"},
  105. ).mappings().first()
  106. if not exec_row:
  107. return False
  108. return True