| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- """垂直领域分类树:查询分类/元素基础信息与效果得分。"""
- import re
- from sqlalchemy import text
- from app.core.config import settings
- from app.db.mysql import SessionLocal
# Table names are f-stringed into SQL text, so they must be plain ASCII
# identifiers. ``\Z`` instead of ``$``: ``$`` also matches just before a
# trailing newline, so "table\n" would otherwise be accepted.
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*\Z")
# Compact partition date (yyyymmdd). ``[0-9]`` instead of ``\d``: in ``str``
# patterns ``\d`` matches any Unicode decimal digit (e.g. Arabic-Indic digits),
# not only ASCII 0-9.
DATE_RE = re.compile(r"^[0-9]{8}\Z")
def _safe_identifier(name: str) -> str:
    """Return *name* if it is safe to splice into SQL text as a table name.

    Table names cannot be bound as query parameters, so the queries in this
    module interpolate them with f-strings. This check is what keeps a
    configured value from injecting arbitrary SQL.

    Args:
        name: Candidate identifier (the callers here pass ``settings`` values).

    Returns:
        *name*, unchanged.

    Raises:
        ValueError: If *name* does not fully match ``IDENTIFIER_RE``.
    """
    # fullmatch rather than match: when the pattern is anchored with ``$``,
    # match() still accepts a trailing newline ("table\n"), which would then
    # be spliced into the SQL text.
    if IDENTIFIER_RE.fullmatch(name) is None:
        raise ValueError(f"invalid sql identifier: {name}")
    return name
def _normalize_date(date_value: str | None) -> str | None:
    """Normalize a partition date to compact ``yyyymmdd`` form.

    All hyphens are removed and surrounding whitespace is stripped, so
    "2024-01-31" and "20240131" both become "20240131". Only the shape is
    checked (eight ASCII digits), not whether it is a real calendar date.

    Args:
        date_value: Raw date from the caller, or None.

    Returns:
        The eight-digit string, or None when *date_value* is None/empty or
        consists only of hyphens and whitespace.

    Raises:
        ValueError: If the normalized value is not exactly eight ASCII digits.
    """
    if not date_value:
        return None
    normalized = date_value.replace("-", "").strip()
    if not normalized:
        return None
    # isascii(): a ``\d`` digit class also accepts non-ASCII Unicode digits,
    # so require ASCII explicitly instead of relying on how DATE_RE spells it.
    # fullmatch(): nothing may surround the eight digits.
    if not normalized.isascii() or DATE_RE.fullmatch(normalized) is None:
        raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
    return normalized
def _resolve_partition_dt(dt: str | None) -> str:
    """Pick the partition date to query.

    An explicit *dt* is normalized with ``_normalize_date`` and returned.
    If it is None or blank, the newest ``dt`` (``MAX(dt)``) found in the
    configured category-effect table is used instead.

    Raises:
        ValueError: If *dt* is malformed, the configured table name is not a
            plain identifier, or the table has no partition yet.
    """
    requested = _normalize_date(dt)
    if requested:
        return requested

    effect_table = _safe_identifier(settings.vertical_category_effect_table)
    with SessionLocal() as session:
        newest_row = session.execute(
            text(f"SELECT MAX(dt) FROM {effect_table}")
        ).first()
        newest = "" if not newest_row else str(newest_row[0] or "").strip()

    if not newest:
        # Message: "no effect-data partition yet; sync the vertical-domain
        # category data first".
        raise ValueError("暂无效果数据分区,请先同步垂直领域分类数据")
    return newest
- def _is_zero_rov_score(rov_score: object) -> bool:
- if rov_score is None:
- return True
- return float(rov_score) == 0.0
- def _build_children_map(categories: list[dict[str, object]]) -> dict[str, list[dict[str, object]]]:
- children_map: dict[str, list[dict[str, object]]] = {}
- for category in categories:
- parent_id = category.get("parent_stable_id")
- if not parent_id:
- continue
- parent_key = str(parent_id)
- children_map.setdefault(parent_key, []).append(category)
- for siblings in children_map.values():
- siblings.sort(key=lambda item: str(item["category_id"]))
- return children_map
- def _is_missing_parent(category: dict[str, object], all_ids: set[str]) -> bool:
- parent_id = category.get("parent_stable_id")
- if not parent_id:
- return True
- return str(parent_id) not in all_ids
- def _find_root_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
- all_ids = {str(item["category_id"]) for item in categories}
- roots = [
- item
- for item in categories
- if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) <= 1
- ]
- roots.sort(key=lambda item: str(item["category_id"]))
- return roots
- def _find_detached_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
- all_ids = {str(item["category_id"]) for item in categories}
- detached = [
- item
- for item in categories
- if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) > 1
- ]
- detached.sort(
- key=lambda item: (
- int(item.get("category_level") or 1),
- str(item["category_id"]),
- )
- )
- return detached
def _build_category_depth_map(
    categories: list[dict[str, object]],
    children_map: dict[str, list[dict[str, object]]],
    roots: list[dict[str, object]],
) -> dict[str, int]:
    """Assign a level to every category id.

    Walks each tree depth-first starting from *roots*, then from the detached
    categories (missing parent, level > 1). A category's explicit
    ``category_level`` wins. When it is None, the level is one more than the
    parent's resolved level, or ``category_level or 1`` of the starting node.
    Categories not reachable from any starting point (e.g. members of a parent
    cycle and their descendants) fall back to ``category_level or 1``.

    The walk uses an explicit stack instead of recursion, so very deep or
    chain-shaped trees cannot hit Python's recursion limit. It visits nodes in
    the same pre-order a recursive walk would.

    Args:
        categories: All serialized categories.
        children_map: Parent id -> ordered children (see ``_build_children_map``).
        roots: Starting categories (see ``_find_root_categories``).

    Returns:
        Mapping of ``str(category_id)`` to its resolved level.
    """
    category_by_id = {str(item["category_id"]): item for item in categories}
    depth_map: dict[str, int] = {}

    def visit(start_id: str, start_fallback: int) -> None:
        pending: list[tuple[str, int]] = [(start_id, start_fallback)]
        while pending:
            category_id, fallback_level = pending.pop()
            if category_id in depth_map:
                continue  # already reached through another path
            category = category_by_id.get(category_id)
            level = category.get("category_level") if category else None
            resolved_level = int(level) if level is not None else fallback_level
            depth_map[category_id] = resolved_level
            # Push in reverse so children pop in children_map order, which
            # reproduces the recursive pre-order exactly.
            pending.extend(
                (str(child["category_id"]), resolved_level + 1)
                for child in reversed(children_map.get(category_id, []))
            )

    for root in roots:
        visit(str(root["category_id"]), int(root.get("category_level") or 1))
    for item in _find_detached_categories(categories):
        category_id = str(item["category_id"])
        if category_id not in depth_map:
            visit(category_id, int(item.get("category_level") or 1))
    for category in categories:
        category_id = str(category["category_id"])
        if category_id not in depth_map:
            depth_map[category_id] = int(category.get("category_level") or 1)
    return depth_map
- def _get_effective_category_level(
- category: dict[str, object],
- depth_map: dict[str, int],
- ) -> int:
- level = category.get("category_level")
- if level is not None:
- return int(level)
- return depth_map.get(str(category["category_id"]), 1)
def _build_category_rov_scales_by_level(
    categories: list[dict[str, object]],
) -> dict[str, dict[str, float]]:
    """Compute the min/max ROV score of categories at each level.

    Scores that are None or zero are ignored (``_is_zero_rov_score``). The
    level is the category's own ``category_level`` when set, otherwise its
    depth in the tree built from ``parent_stable_id``.

    Returns:
        ``{"<level>": {"min": float, "max": float}}`` with keys in ascending
        numeric level order. Levels without any usable score are omitted.
    """
    depth_map = _build_category_depth_map(
        categories,
        _build_children_map(categories),
        _find_root_categories(categories),
    )
    bounds_by_level: dict[int, dict[str, float]] = {}
    for category in categories:
        raw_score = category.get("rov_score")
        if _is_zero_rov_score(raw_score):
            continue
        level = _get_effective_category_level(category, depth_map)
        value = float(raw_score)
        bounds = bounds_by_level.setdefault(level, {"min": value, "max": value})
        bounds["min"] = min(bounds["min"], value)
        bounds["max"] = max(bounds["max"], value)
    return {str(level): bounds_by_level[level] for level in sorted(bounds_by_level)}
- def _serialize_category_row(row: object, child_category_ids: set[str]) -> dict[str, object]:
- category_id = str(row["category_id"] or "").strip()
- parent_id = str(row["parent_stable_id"] or "").strip() or None
- if parent_id:
- child_category_ids.add(parent_id)
- rov_score = row["rov_score"]
- return {
- "category_id": category_id,
- "parent_stable_id": parent_id,
- "category_name": row["category_name"],
- "category_level": row["category_level"],
- "vid_count": row["vid_count"],
- "rov_score": float(rov_score) if rov_score is not None else None,
- "is_leaf": False,
- }
def _query_element_counts_by_category(
    session: object,
    element_base_table: str,
    element_effect_table: str,
    partition_dt: str,
) -> tuple[dict[str, dict[str, int]], int]:
    """Count elements per category and how many of them lack a usable score.

    An element counts as invalid for the partition when it has no row in
    *element_effect_table* with ``dt = partition_dt`` (the LEFT JOIN yields
    NULL), or when its ``rov_score`` is NULL or 0. Elements whose ``stable_id``
    is NULL or blank are ignored.

    Args:
        session: Open database session providing ``execute``.
        element_base_table: Element base table name. It is interpolated
            directly into the SQL text; the caller validates it with
            ``_safe_identifier``.
        element_effect_table: Element effect table name (same caveat).
        partition_dt: Partition date, bound as ``:dt``.

    Returns:
        ``(counts_by_category, invalid_total)``. ``counts_by_category`` maps the
        stripped ``stable_id`` to ``{"total": int, "valid": int}``.
        ``invalid_total`` sums invalid elements over all categories kept there.
    """
    rows = session.execute(
        text(
            f"""
            SELECT
                b.stable_id,
                COUNT(*) AS total,
                SUM(
                    CASE
                        WHEN e.rov_score IS NULL OR e.rov_score = 0 THEN 1
                        ELSE 0
                    END
                ) AS invalid
            FROM {element_base_table} b
            LEFT JOIN {element_effect_table} e
                ON b.element_id = e.element_id
                AND e.dt = :dt
            WHERE b.stable_id IS NOT NULL
                AND TRIM(b.stable_id) <> ''
            GROUP BY b.stable_id
            """
        ),
        {"dt": partition_dt},
    ).mappings().all()

    counts_by_category: dict[str, dict[str, int]] = {}
    invalid_total = 0
    for row in rows:
        category_id = str(row["stable_id"] or "").strip()
        if not category_id:
            continue
        total = int(row["total"] or 0)
        invalid = int(row["invalid"] or 0)
        valid = max(total - invalid, 0)
        # SQL groups by the raw stable_id. Values that differ only in
        # surrounding whitespace can form separate groups but share one
        # stripped key here, so accumulate rather than overwrite; otherwise
        # earlier groups are lost while invalid_total still counts them.
        bucket = counts_by_category.setdefault(category_id, {"total": 0, "valid": 0})
        bucket["total"] += total
        bucket["valid"] += valid
        invalid_total += invalid
    return counts_by_category, invalid_total
def _query_global_element_rov_scale(
    session: object,
    element_effect_table: str,
    partition_dt: str,
) -> dict[str, float]:
    """Return the min/max element ROV score across one partition.

    NULL and zero scores are excluded in SQL, mirroring how category scales
    skip them.

    Args:
        session: Open database session providing ``execute``.
        element_effect_table: Validated element effect table name, which is
            interpolated into the SQL text.
        partition_dt: Partition date, bound as ``:dt``.

    Returns:
        ``{"min": float, "max": float}``. Both are 0.0 when the partition has
        no non-zero score.
    """
    scale_sql = text(
        f"""
        SELECT
            MIN(rov_score) AS min_rov_score,
            MAX(rov_score) AS max_rov_score
        FROM {element_effect_table}
        WHERE dt = :dt
            AND rov_score IS NOT NULL
            AND rov_score <> 0
        """
    )
    stats = session.execute(scale_sql, {"dt": partition_dt}).mappings().first()
    lowest = stats["min_rov_score"] if stats else None
    highest = stats["max_rov_score"] if stats else None
    if lowest is None or highest is None:
        return {"min": 0.0, "max": 0.0}
    return {"min": float(lowest), "max": float(highest)}
def query_available_dates() -> list[str]:
    """Return up to 366 distinct partition dates, newest first.

    Dates come from the configured category-effect table. Falsy values
    (NULL, empty string) are dropped and the rest are converted to ``str``.

    Raises:
        ValueError: If the configured table name is not a plain identifier.
    """
    effect_table = _safe_identifier(settings.vertical_category_effect_table)
    dates_sql = text(
        f"""
        SELECT DISTINCT dt
        FROM {effect_table}
        ORDER BY dt DESC
        LIMIT 366
        """
    )
    with SessionLocal() as session:
        result_rows = session.execute(dates_sql).all()
    return [str(partition) for (partition,) in result_rows if partition]
def query_vertical_category_tree(dt: str | None = None) -> dict[str, object]:
    """Load the vertical category tree for one partition, plus element statistics.

    Args:
        dt: Partition date as ``yyyymmdd`` or ``yyyy-mm-dd``. When omitted, the
            newest ``dt`` in the category-effect table is used.

    Returns:
        A dict with ``dt`` (the resolved partition), ``available_dates``,
        ``element_invalid_count``, ``element_counts_by_category``,
        ``rov_scales`` (``category_by_level`` and ``element``) and
        ``categories`` (serialized rows, each with ``is_leaf`` filled in).

    Raises:
        ValueError: For a malformed *dt*, when no partition exists, or when a
            configured table name is not a plain identifier.
    """
    partition_dt = _resolve_partition_dt(dt)
    category_base = _safe_identifier(settings.vertical_category_base_table)
    category_effect = _safe_identifier(settings.vertical_category_effect_table)
    element_base = _safe_identifier(settings.substance_element_base_table)
    element_effect = _safe_identifier(settings.substance_element_effect_table)

    # LEFT JOIN: every base category is returned; vid_count / rov_score are
    # NULL when it has no effect row for this partition.
    category_sql = text(
        f"""
        SELECT
            b.category_id,
            b.parent_stable_id,
            b.category_name,
            b.category_level,
            e.vid_count,
            e.rov_score
        FROM {category_base} b
        LEFT JOIN {category_effect} e
            ON b.category_id = e.category_id
            AND e.dt = :dt
        ORDER BY b.category_level ASC, b.category_id ASC
        """
    )
    with SessionLocal() as session:
        category_rows = session.execute(category_sql, {"dt": partition_dt}).mappings().all()
        element_counts, invalid_elements = _query_element_counts_by_category(
            session,
            element_base,
            element_effect,
            partition_dt,
        )
        element_scale = _query_global_element_rov_scale(
            session,
            element_effect,
            partition_dt,
        )

    # Serializing collects every referenced parent id; an id that is someone's
    # parent is not a leaf.
    parent_ids: set[str] = set()
    categories = [_serialize_category_row(row, parent_ids) for row in category_rows]
    for category in categories:
        category["is_leaf"] = category["category_id"] not in parent_ids

    available_dates = query_available_dates()
    category_scales = _build_category_rov_scales_by_level(categories)
    return {
        "dt": partition_dt,
        "available_dates": available_dates,
        "element_invalid_count": invalid_elements,
        "element_counts_by_category": element_counts,
        "rov_scales": {
            "category_by_level": category_scales,
            "element": element_scale,
        },
        "categories": categories,
    }
def query_category_elements(category_id: str, dt: str | None = None) -> dict[str, object]:
    """List the elements attached to one category, with scores for a partition.

    Args:
        category_id: Category id. Surrounding whitespace is stripped, and the
            result is matched exactly against ``stable_id`` in SQL.
        dt: Partition date as ``yyyymmdd`` or ``yyyy-mm-dd``. When omitted, the
            newest partition of the category-effect table is used.

    Returns:
        A dict with ``dt``, the stripped ``category_id``, ``elements`` (ordered
        by ``element_id``) and ``rov_scales.element``, the partition-wide
        element score range.

    Raises:
        ValueError: For a blank *category_id*, a malformed *dt*, a missing
            partition, or an invalid configured table name.
    """
    wanted_id = category_id.strip()
    if not wanted_id:
        raise ValueError("category_id is required")

    partition_dt = _resolve_partition_dt(dt)
    base_table = _safe_identifier(settings.substance_element_base_table)
    effect_table = _safe_identifier(settings.substance_element_effect_table)
    element_sql = text(
        f"""
        SELECT
            b.element_id,
            b.stable_id,
            b.element_name,
            e.vid_count,
            e.rov_score
        FROM {base_table} b
        LEFT JOIN {effect_table} e
            ON b.element_id = e.element_id
            AND e.dt = :dt
        WHERE b.stable_id = :category_id
        ORDER BY b.element_id ASC
        """
    )
    with SessionLocal() as session:
        element_rows = session.execute(
            element_sql,
            {"dt": partition_dt, "category_id": wanted_id},
        ).mappings().all()
        element_scale = _query_global_element_rov_scale(session, effect_table, partition_dt)

    def to_payload(row) -> dict[str, object]:
        score = row["rov_score"]
        return {
            "element_id": str(row["element_id"] or "").strip(),
            "stable_id": str(row["stable_id"] or "").strip() or None,
            "element_name": row["element_name"],
            "vid_count": row["vid_count"],
            "rov_score": None if score is None else float(score),
        }

    return {
        "dt": partition_dt,
        "category_id": wanted_id,
        "elements": [to_payload(row) for row in element_rows],
        "rov_scales": {
            "element": element_scale,
        },
    }
|