db_manager.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
import os
from typing import Iterable

from sqlalchemy import bindparam, create_engine, and_, or_, desc, text
from sqlalchemy.orm import sessionmaker, Session
  4. class DatabaseManager:
  5. """数据库管理类"""
  6. # mysql+pymysql://<用户名>:<密码>@<主机地址>:<端口>/<数据库名>?charset=utf8mb4
  7. def __init__(self):
  8. connection_string = (
  9. f"mysql+pymysql://content_rw:bC1aH4bA1lB0@rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com:3306/open_aigc_pattern?charset=utf8mb4"
  10. )
  11. self.engine = create_engine(connection_string, pool_pre_ping=True, pool_recycle=3600)
  12. self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
  13. def get_session(self) -> Session:
  14. """获取数据库会话"""
  15. return self.SessionLocal()
  16. def query_video_ids_by_names(execution_id: int, names: Iterable[str]) -> list[str]:
  17. """按 execution_id + 名称列表查询去重后的 post_id。"""
  18. clean_names = [str(n).strip() for n in names if n is not None and str(n).strip()]
  19. if not clean_names:
  20. return []
  21. manager = DatabaseManager()
  22. session = manager.get_session()
  23. video_ids: set[str] = set()
  24. try:
  25. for name in clean_names:
  26. categories = session.execute(
  27. text(
  28. """
  29. SELECT id
  30. FROM topic_pattern_category
  31. WHERE execution_id = :execution_id AND name = :name
  32. """
  33. ),
  34. {"execution_id": execution_id, "name": name},
  35. ).fetchall()
  36. category_ids = [row[0] for row in categories if row and row[0] is not None]
  37. if category_ids:
  38. elements = session.execute(
  39. text(
  40. """
  41. SELECT post_id
  42. FROM topic_pattern_element
  43. WHERE execution_id = :execution_id
  44. AND category_id IN :category_ids
  45. """
  46. ).bindparams(bindparam("category_ids", expanding=True)),
  47. {"execution_id": execution_id, "category_ids": category_ids},
  48. ).fetchall()
  49. else:
  50. elements = session.execute(
  51. text(
  52. """
  53. SELECT post_id
  54. FROM topic_pattern_element
  55. WHERE execution_id = :execution_id AND name = :name
  56. """
  57. ),
  58. {"execution_id": execution_id, "name": name},
  59. ).fetchall()
  60. for row in elements:
  61. post_id = row[0] if row else None
  62. if post_id is not None and str(post_id).strip():
  63. video_ids.add(str(post_id).strip())
  64. finally:
  65. session.close()
  66. return list(video_ids)