| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
import os
from typing import Iterable, Optional

from sqlalchemy import bindparam, create_engine, and_, or_, desc, text
from sqlalchemy.orm import sessionmaker, Session
class DatabaseManager:
    """Own a SQLAlchemy engine and session factory for the ``open_aigc_pattern`` database.

    Connection URLs have the form
    ``mysql+pymysql://<user>:<password>@<host>:<port>/<database>?charset=utf8mb4``.
    """

    # Environment variable consulted when no explicit connection string is passed.
    CONNECTION_STRING_ENV = "OPEN_AIGC_PATTERN_DB_URL"

    # SECURITY(review): this default embeds a plaintext password in source control.
    # It is kept only so existing zero-argument ``DatabaseManager()`` callers keep
    # working. Rotate the credential, set OPEN_AIGC_PATTERN_DB_URL in every
    # environment, and then delete this default.
    _LEGACY_CONNECTION_STRING = (
        "mysql+pymysql://content_rw:bC1aH4bA1lB0@rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com:3306/open_aigc_pattern?charset=utf8mb4"
    )

    def __init__(self, connection_string: Optional[str] = None):
        """Create the engine and the session factory.

        Args:
            connection_string: SQLAlchemy database URL. If omitted (or empty), the
                ``OPEN_AIGC_PATTERN_DB_URL`` environment variable is used. If that
                is also unset, the legacy built-in URL is used.
        """
        url = (
            connection_string
            or os.environ.get(self.CONNECTION_STRING_ENV)
            or self._LEGACY_CONNECTION_STRING
        )
        # pool_pre_ping tests a pooled connection before handing it out.
        # pool_recycle=3600 replaces connections older than an hour, which
        # guards against the server closing idle connections underneath us.
        self.engine = create_engine(url, pool_pre_ping=True, pool_recycle=3600)
        self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)

    def get_session(self) -> Session:
        """Return a new session bound to this manager's engine; the caller must close it."""
        return self.SessionLocal()
class DatabaseManager2(DatabaseManager):
    """Separate database-manager type, used for the module-level ``db2`` instance.

    Its configuration is identical to :class:`DatabaseManager`: the same
    connection URL, pool options and session options. It therefore inherits
    everything instead of duplicating the code (and the embedded credentials).
    If it ever needs to target a different database, override ``__init__`` here.
    """
# Lazily created manager shared by every query_video_ids_by_names call.
_video_id_query_manager = None


def _get_video_id_query_manager():
    """Return the shared DatabaseManager, creating it on first use.

    Reusing one manager means one engine and one connection pool per process.
    Otherwise each call would build a new engine and never dispose of it.
    """
    global _video_id_query_manager
    if _video_id_query_manager is None:
        _video_id_query_manager = DatabaseManager()
    return _video_id_query_manager


def query_video_ids_by_names(execution_id: int, names: Iterable[str]) -> list[str]:
    """Return the distinct ``post_id`` values linked to any of *names* in one execution.

    For each name, the ``topic_pattern_category`` rows with the same
    ``execution_id`` and ``name`` are looked up first:

    - If any are found, posts are collected from the ``topic_pattern_element``
      rows whose ``category_id`` is one of those ids.
    - Otherwise, ``topic_pattern_element`` rows are matched on ``name``
      directly.

    Args:
        execution_id: Execution whose categories and elements are searched.
        names: Candidate names. ``None`` entries and blank strings are ignored.
            Other values are converted with ``str()`` and stripped.

    Returns:
        Distinct, stripped, non-empty post ids, sorted ascending. If no usable
        name is supplied, returns ``[]`` without touching the database.
    """
    stripped = (str(n).strip() for n in names if n is not None)
    # dict.fromkeys de-duplicates while keeping first-seen order, so a name that
    # appears several times in the input is queried only once.
    clean_names = list(dict.fromkeys(s for s in stripped if s))
    if not clean_names:
        return []

    # The statements are loop-invariant, so build them once.
    category_sql = text(
        """
        SELECT id
        FROM topic_pattern_category
        WHERE execution_id = :execution_id AND name = :name
        """
    )
    # expanding=True renders ":category_ids" as one placeholder per list item.
    elements_by_category_sql = text(
        """
        SELECT post_id
        FROM topic_pattern_element
        WHERE execution_id = :execution_id
          AND category_id IN :category_ids
        """
    ).bindparams(bindparam("category_ids", expanding=True))
    elements_by_name_sql = text(
        """
        SELECT post_id
        FROM topic_pattern_element
        WHERE execution_id = :execution_id AND name = :name
        """
    )

    session = _get_video_id_query_manager().get_session()
    video_ids: set[str] = set()
    try:
        for name in clean_names:
            categories = session.execute(
                category_sql,
                {"execution_id": execution_id, "name": name},
            ).fetchall()
            category_ids = [row[0] for row in categories if row and row[0] is not None]

            if category_ids:
                elements = session.execute(
                    elements_by_category_sql,
                    {"execution_id": execution_id, "category_ids": category_ids},
                ).fetchall()
            else:
                # No category with this name: fall back to matching elements by name.
                elements = session.execute(
                    elements_by_name_sql,
                    {"execution_id": execution_id, "name": name},
                ).fetchall()

            for row in elements:
                post_id = row[0] if row else None
                if post_id is None:
                    continue
                post_id = str(post_id).strip()
                if post_id:
                    video_ids.add(post_id)
    finally:
        session.close()
    return sorted(video_ids)
# Module-level manager used by exist_cluster_tree.
db2 = DatabaseManager2()


def exist_cluster_tree(merge_level2):
    """Return whether a ``cluster_execution`` row with status 2 has a name starting with *merge_level2*.

    The prefix is matched literally. ``%``, ``_`` and the escape character ``!``
    are escaped so that characters in *merge_level2* do not act as LIKE
    wildcards.

    Returns:
        ``True`` if at least one matching row exists, otherwise ``False``.
    """
    # Escape the escape character first, then the LIKE wildcards.
    prefix = (
        str(merge_level2)
        .replace("!", "!!")
        .replace("%", "!%")
        .replace("_", "!_")
    )
    session = db2.get_session()
    try:
        exec_row = session.execute(
            text("""
                SELECT id
                FROM cluster_execution
                WHERE name LIKE :name ESCAPE '!' AND status = 2
                LIMIT 1
            """),
            {"name": f"{prefix}%"},
        ).first()
    finally:
        # Always close the session so its connection is returned to the pool.
        session.close()
    return exec_row is not None
|