# db_manager.py
import os
from typing import Iterable

from sqlalchemy import and_, bindparam, create_engine, desc, or_, text
from sqlalchemy.orm import Session, sessionmaker
  4. class DatabaseManager:
  5. """数据库管理类"""
  6. # mysql+pymysql://<用户名>:<密码>@<主机地址>:<端口>/<数据库名>?charset=utf8mb4
  7. def __init__(self):
  8. connection_string = (
  9. f"mysql+pymysql://content_rw:bC1aH4bA1lB0@rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com:3306/open_aigc_pattern?charset=utf8mb4"
  10. )
  11. self.engine = create_engine(connection_string, pool_pre_ping=True, pool_recycle=3600)
  12. self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
  13. def get_session(self) -> Session:
  14. """获取数据库会话"""
  15. return self.SessionLocal()
  16. class DatabaseManager2:
  17. """数据库管理类"""
  18. # mysql+pymysql://<用户名>:<密码>@<主机地址>:<端口>/<数据库名>?charset=utf8mb4
  19. def __init__(self):
  20. connection_string = (
  21. f"mysql+pymysql://content_rw:bC1aH4bA1lB0@rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com:3306/open_aigc_pattern?charset=utf8mb4"
  22. )
  23. self.engine = create_engine(connection_string, pool_pre_ping=True, pool_recycle=3600)
  24. self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
  25. def get_session(self) -> Session:
  26. """获取数据库会话"""
  27. return self.SessionLocal()
  28. def query_video_ids_by_names(execution_id: int, names: Iterable[str]) -> list[str]:
  29. """按 execution_id + 名称列表查询去重后的 post_id。"""
  30. clean_names = [str(n).strip() for n in names if n is not None and str(n).strip()]
  31. if not clean_names:
  32. return []
  33. manager = DatabaseManager()
  34. session = manager.get_session()
  35. video_ids: set[str] = set()
  36. try:
  37. for name in clean_names:
  38. categories = session.execute(
  39. text(
  40. """
  41. SELECT id
  42. FROM topic_pattern_category
  43. WHERE execution_id = :execution_id AND name = :name
  44. """
  45. ),
  46. {"execution_id": execution_id, "name": name},
  47. ).fetchall()
  48. category_ids = [row[0] for row in categories if row and row[0] is not None]
  49. if category_ids:
  50. elements = session.execute(
  51. text(
  52. """
  53. SELECT post_id
  54. FROM topic_pattern_element
  55. WHERE execution_id = :execution_id
  56. AND category_id IN :category_ids
  57. """
  58. ).bindparams(bindparam("category_ids", expanding=True)),
  59. {"execution_id": execution_id, "category_ids": category_ids},
  60. ).fetchall()
  61. else:
  62. elements = session.execute(
  63. text(
  64. """
  65. SELECT post_id
  66. FROM topic_pattern_element
  67. WHERE execution_id = :execution_id AND name = :name
  68. """
  69. ),
  70. {"execution_id": execution_id, "name": name},
  71. ).fetchall()
  72. for row in elements:
  73. post_id = row[0] if row else None
  74. if post_id is not None and str(post_id).strip():
  75. video_ids.add(str(post_id).strip())
  76. finally:
  77. session.close()
  78. return list(video_ids)
# Module-level DatabaseManager2 instance used by exist_cluster_tree below;
# it is constructed when this module is imported.
db2 = DatabaseManager2()
  80. def exist_cluster_tree(merge_level2):
  81. session = db2.get_session()
  82. exec_row = session.execute(
  83. text("""
  84. SELECT id
  85. FROM cluster_execution
  86. WHERE name LIKE :name AND status = 2
  87. ORDER BY create_time DESC
  88. LIMIT 1
  89. """),
  90. {"name": f"{merge_level2}%"},
  91. ).mappings().first()
  92. if not exec_row:
  93. return False
  94. return True