فهرست منبع

修改需求汇总加载方式

xueyiming 1 هفته پیش
والد
کامیت
8aa9251ac6
3فایلهای تغییر یافته به همراه502 افزوده شده و 156 حذف شده
  1. 18 1
      app/api/routes.py
  2. 263 61
      app/services/vertical_category_tree_service.py
  3. 221 94
      frontend/src/CategoryEffectTreeApp.tsx

+ 18 - 1
app/api/routes.py

@@ -30,7 +30,10 @@ from app.services.strategy_config_service import (
     set_strategy_config_active,
     update_strategy_config,
 )
-from app.services.vertical_category_tree_service import query_vertical_category_tree
+from app.services.vertical_category_tree_service import (
+    query_category_elements,
+    query_vertical_category_tree,
+)
 from app.sync.experiment_demand_pool_write import run_experiment_hourly_write
 from app.utils.excel_export import build_content_disposition, rows_to_excel_bytes
 
@@ -549,6 +552,20 @@ async def get_vertical_category_tree(
         raise HTTPException(status_code=400, detail=str(exc)) from exc
 
 
+@router.get("/vertical-category/categories/{category_id}/elements")
+async def get_vertical_category_elements(
+    category_id: str,
+    dt: str | None = Query(
+        default=None,
+        description="效果分区日期: yyyymmdd 或 yyyy-mm-dd;未传则取最新分区",
+    ),
+) -> dict[str, object]:
+    try:
+        return query_category_elements(category_id=category_id, dt=dt)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+
 @router.get("/demand-pool/strategies")
 async def get_demand_pool_strategies(
     start_dt: str | None = Query(default=None, description="开始日期: yyyymmdd 或 yyyy-mm-dd"),

+ 263 - 61
app/services/vertical_category_tree_service.py

@@ -44,6 +44,211 @@ def _resolve_partition_dt(dt: str | None) -> str:
     return latest
 
 
+def _is_zero_rov_score(rov_score: object) -> bool:
+    if rov_score is None:
+        return True
+    return float(rov_score) == 0.0
+
+
+def _build_children_map(categories: list[dict[str, object]]) -> dict[str, list[dict[str, object]]]:
+    children_map: dict[str, list[dict[str, object]]] = {}
+    for category in categories:
+        parent_id = category.get("parent_stable_id")
+        if not parent_id:
+            continue
+        parent_key = str(parent_id)
+        children_map.setdefault(parent_key, []).append(category)
+    for siblings in children_map.values():
+        siblings.sort(key=lambda item: str(item["category_id"]))
+    return children_map
+
+
+def _is_missing_parent(category: dict[str, object], all_ids: set[str]) -> bool:
+    parent_id = category.get("parent_stable_id")
+    if not parent_id:
+        return True
+    return str(parent_id) not in all_ids
+
+
+def _find_root_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
+    all_ids = {str(item["category_id"]) for item in categories}
+    roots = [
+        item
+        for item in categories
+        if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) <= 1
+    ]
+    roots.sort(key=lambda item: str(item["category_id"]))
+    return roots
+
+
+def _find_detached_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
+    all_ids = {str(item["category_id"]) for item in categories}
+    detached = [
+        item
+        for item in categories
+        if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) > 1
+    ]
+    detached.sort(
+        key=lambda item: (
+            int(item.get("category_level") or 1),
+            str(item["category_id"]),
+        )
+    )
+    return detached
+
+
+def _build_category_depth_map(
+    categories: list[dict[str, object]],
+    children_map: dict[str, list[dict[str, object]]],
+    roots: list[dict[str, object]],
+) -> dict[str, int]:
+    category_by_id = {str(item["category_id"]): item for item in categories}
+    depth_map: dict[str, int] = {}
+
+    def visit(category_id: str, fallback_level: int) -> None:
+        if category_id in depth_map:
+            return
+        category = category_by_id.get(category_id)
+        level = category.get("category_level") if category else None
+        resolved_level = int(level) if level is not None else fallback_level
+        depth_map[category_id] = resolved_level
+        for child in children_map.get(category_id, []):
+            visit(str(child["category_id"]), resolved_level + 1)
+
+    for root in roots:
+        visit(str(root["category_id"]), int(root.get("category_level") or 1))
+    for item in _find_detached_categories(categories):
+        category_id = str(item["category_id"])
+        if category_id not in depth_map:
+            visit(category_id, int(item.get("category_level") or 1))
+    for category in categories:
+        category_id = str(category["category_id"])
+        if category_id not in depth_map:
+            depth_map[category_id] = int(category.get("category_level") or 1)
+    return depth_map
+
+
+def _get_effective_category_level(
+    category: dict[str, object],
+    depth_map: dict[str, int],
+) -> int:
+    level = category.get("category_level")
+    if level is not None:
+        return int(level)
+    return depth_map.get(str(category["category_id"]), 1)
+
+
+def _build_category_rov_scales_by_level(
+    categories: list[dict[str, object]],
+) -> dict[str, dict[str, float]]:
+    children_map = _build_children_map(categories)
+    roots = _find_root_categories(categories)
+    depth_map = _build_category_depth_map(categories, children_map, roots)
+
+    scales: dict[int, dict[str, float]] = {}
+    for category in categories:
+        rov_score = category.get("rov_score")
+        if _is_zero_rov_score(rov_score):
+            continue
+        level = _get_effective_category_level(category, depth_map)
+        score = float(rov_score)
+        current = scales.get(level)
+        if current is None:
+            scales[level] = {"min": score, "max": score}
+            continue
+        current["min"] = min(current["min"], score)
+        current["max"] = max(current["max"], score)
+    return {str(level): value for level, value in sorted(scales.items())}
+
+
+def _serialize_category_row(row: object, child_category_ids: set[str]) -> dict[str, object]:
+    category_id = str(row["category_id"] or "").strip()
+    parent_id = str(row["parent_stable_id"] or "").strip() or None
+    if parent_id:
+        child_category_ids.add(parent_id)
+    rov_score = row["rov_score"]
+    return {
+        "category_id": category_id,
+        "parent_stable_id": parent_id,
+        "category_name": row["category_name"],
+        "category_level": row["category_level"],
+        "vid_count": row["vid_count"],
+        "rov_score": float(rov_score) if rov_score is not None else None,
+        "is_leaf": False,
+    }
+
+
+def _query_element_counts_by_category(
+    session: object,
+    element_base_table: str,
+    element_effect_table: str,
+    partition_dt: str,
+) -> tuple[dict[str, dict[str, int]], int]:
+    rows = session.execute(
+        text(
+            f"""
+            SELECT
+                b.stable_id,
+                COUNT(*) AS total,
+                SUM(
+                    CASE
+                        WHEN e.rov_score IS NULL OR e.rov_score = 0 THEN 1
+                        ELSE 0
+                    END
+                ) AS invalid
+            FROM {element_base_table} b
+            LEFT JOIN {element_effect_table} e
+                ON b.element_id = e.element_id
+                AND e.dt = :dt
+            WHERE b.stable_id IS NOT NULL
+                AND TRIM(b.stable_id) <> ''
+            GROUP BY b.stable_id
+            """
+        ),
+        {"dt": partition_dt},
+    ).mappings().all()
+
+    counts_by_category: dict[str, dict[str, int]] = {}
+    invalid_total = 0
+    for row in rows:
+        category_id = str(row["stable_id"] or "").strip()
+        if not category_id:
+            continue
+        total = int(row["total"] or 0)
+        invalid = int(row["invalid"] or 0)
+        valid = max(total - invalid, 0)
+        counts_by_category[category_id] = {"total": total, "valid": valid}
+        invalid_total += invalid
+    return counts_by_category, invalid_total
+
+
+def _query_global_element_rov_scale(
+    session: object,
+    element_effect_table: str,
+    partition_dt: str,
+) -> dict[str, float]:
+    row = session.execute(
+        text(
+            f"""
+            SELECT
+                MIN(rov_score) AS min_rov_score,
+                MAX(rov_score) AS max_rov_score
+            FROM {element_effect_table}
+            WHERE dt = :dt
+                AND rov_score IS NOT NULL
+                AND rov_score <> 0
+            """
+        ),
+        {"dt": partition_dt},
+    ).mappings().first()
+    if not row or row["min_rov_score"] is None or row["max_rov_score"] is None:
+        return {"min": 0.0, "max": 0.0}
+    return {
+        "min": float(row["min_rov_score"]),
+        "max": float(row["max_rov_score"]),
+    }
+
+
 def query_available_dates() -> list[str]:
     category_effect_table = _safe_identifier(settings.vertical_category_effect_table)
     with SessionLocal() as session:
@@ -75,12 +280,8 @@ def query_vertical_category_tree(dt: str | None = None) -> dict[str, object]:
             b.parent_stable_id,
             b.category_name,
             b.category_level,
-            b.dimension,
-            b.classified_as,
             e.vid_count,
-            e.rov_score,
-            e.str_score,
-            e.ros_score
+            e.rov_score
         FROM {category_base_table} b
         LEFT JOIN {category_effect_table} e
             ON b.category_id = e.category_id
@@ -89,92 +290,93 @@ def query_vertical_category_tree(dt: str | None = None) -> dict[str, object]:
         """
     )
 
+    with SessionLocal() as session:
+        category_rows = session.execute(category_sql, {"dt": partition_dt}).mappings().all()
+        element_counts_by_category, element_invalid_count = _query_element_counts_by_category(
+            session,
+            element_base_table,
+            element_effect_table,
+            partition_dt,
+        )
+        element_rov_scale = _query_global_element_rov_scale(
+            session,
+            element_effect_table,
+            partition_dt,
+        )
+
+    categories: list[dict[str, object]] = []
+    child_category_ids: set[str] = set()
+    for row in category_rows:
+        categories.append(_serialize_category_row(row, child_category_ids))
+
+    for item in categories:
+        item["is_leaf"] = item["category_id"] not in child_category_ids
+
+    return {
+        "dt": partition_dt,
+        "available_dates": query_available_dates(),
+        "element_invalid_count": element_invalid_count,
+        "element_counts_by_category": element_counts_by_category,
+        "rov_scales": {
+            "category_by_level": _build_category_rov_scales_by_level(categories),
+            "element": element_rov_scale,
+        },
+        "categories": categories,
+    }
+
+
+def query_category_elements(category_id: str, dt: str | None = None) -> dict[str, object]:
+    normalized_category_id = category_id.strip()
+    if not normalized_category_id:
+        raise ValueError("category_id is required")
+
+    partition_dt = _resolve_partition_dt(dt)
+
+    element_base_table = _safe_identifier(settings.substance_element_base_table)
+    element_effect_table = _safe_identifier(settings.substance_element_effect_table)
+
     element_sql = text(
         f"""
         SELECT
             b.element_id,
             b.stable_id,
             b.element_name,
-            b.dimension,
-            b.classified_as,
             e.vid_count,
-            e.rov_score,
-            e.str_score,
-            e.ros_score
+            e.rov_score
         FROM {element_base_table} b
         LEFT JOIN {element_effect_table} e
             ON b.element_id = e.element_id
             AND e.dt = :dt
-        ORDER BY b.stable_id ASC, b.element_id ASC
+        WHERE b.stable_id = :category_id
+        ORDER BY b.element_id ASC
         """
     )
 
     with SessionLocal() as session:
-        category_rows = session.execute(category_sql, {"dt": partition_dt}).mappings().all()
-        element_rows = session.execute(element_sql, {"dt": partition_dt}).mappings().all()
-
-    categories: list[dict[str, object]] = []
-    child_category_ids: set[str] = set()
-    category_rov_values: list[float] = []
-
-    for row in category_rows:
-        category_id = str(row["category_id"] or "").strip()
-        parent_id = str(row["parent_stable_id"] or "").strip() or None
-        if parent_id:
-            child_category_ids.add(parent_id)
-        rov_score = row["rov_score"]
-        if rov_score is not None and float(rov_score) != 0.0:
-            category_rov_values.append(float(rov_score))
-        categories.append(
-            {
-                "category_id": category_id,
-                "parent_stable_id": parent_id,
-                "category_name": row["category_name"],
-                "category_level": row["category_level"],
-                "dimension": row["dimension"],
-                "classified_as": row["classified_as"],
-                "vid_count": row["vid_count"],
-                "rov_score": float(rov_score) if rov_score is not None else None,
-                "str_score": float(row["str_score"]) if row["str_score"] is not None else None,
-                "ros_score": float(row["ros_score"]) if row["ros_score"] is not None else None,
-            }
-        )
-
-    for item in categories:
-        item["is_leaf"] = item["category_id"] not in child_category_ids
+        element_rows = session.execute(
+            element_sql,
+            {"dt": partition_dt, "category_id": normalized_category_id},
+        ).mappings().all()
+        rov_scale = _query_global_element_rov_scale(session, element_effect_table, partition_dt)
 
     elements: list[dict[str, object]] = []
-    element_rov_values: list[float] = []
     for row in element_rows:
         rov_score = row["rov_score"]
-        if rov_score is not None and float(rov_score) != 0.0:
-            element_rov_values.append(float(rov_score))
         elements.append(
             {
                 "element_id": str(row["element_id"] or "").strip(),
                 "stable_id": str(row["stable_id"] or "").strip() or None,
                 "element_name": row["element_name"],
-                "dimension": row["dimension"],
-                "classified_as": row["classified_as"],
                 "vid_count": row["vid_count"],
                 "rov_score": float(rov_score) if rov_score is not None else None,
-                "str_score": float(row["str_score"]) if row["str_score"] is not None else None,
-                "ros_score": float(row["ros_score"]) if row["ros_score"] is not None else None,
             }
         )
 
-    category_min_rov = min(category_rov_values) if category_rov_values else 0.0
-    category_max_rov = max(category_rov_values) if category_rov_values else 0.0
-    element_min_rov = min(element_rov_values) if element_rov_values else 0.0
-    element_max_rov = max(element_rov_values) if element_rov_values else 0.0
-
     return {
         "dt": partition_dt,
-        "available_dates": query_available_dates(),
-        "category_min_rov_score": category_min_rov,
-        "category_max_rov_score": category_max_rov,
-        "element_min_rov_score": element_min_rov,
-        "element_max_rov_score": element_max_rov,
-        "categories": categories,
+        "category_id": normalized_category_id,
         "elements": elements,
+        "rov_scales": {
+            "element": rov_scale,
+        },
     }

+ 221 - 94
frontend/src/CategoryEffectTreeApp.tsx

@@ -53,15 +53,34 @@ type ElementNode = {
   rov_score: number | null;
 };
 
+type RovScale = { min: number; max: number };
+
+type RovScales = {
+  category_by_level: Record<string, RovScale>;
+  element: RovScale;
+};
+
+type ElementCountInfo = {
+  total: number;
+  valid: number;
+};
+
 type TreeResponse = {
   dt: string;
   available_dates: string[];
-  category_min_rov_score: number;
-  category_max_rov_score: number;
-  element_min_rov_score: number;
-  element_max_rov_score: number;
+  element_invalid_count: number;
+  element_counts_by_category: Record<string, ElementCountInfo>;
+  rov_scales: RovScales;
   categories: CategoryNode[];
+};
+
+type CategoryElementsResponse = {
+  dt: string;
+  category_id: string;
   elements: ElementNode[];
+  rov_scales: {
+    element: RovScale;
+  };
 };
 
 type LayoutKind = "category" | "element";
@@ -183,25 +202,6 @@ function buildChildrenMap(categories: CategoryNode[]): Map<string, CategoryNode[
   return map;
 }
 
-function buildElementsByCategory(elements: ElementNode[]): Map<string, ElementNode[]> {
-  const map = new Map<string, ElementNode[]>();
-  for (const element of elements) {
-    const categoryId = element.stable_id;
-    if (!categoryId) {
-      continue;
-    }
-    const siblings = map.get(categoryId) ?? [];
-    siblings.push(element);
-    map.set(categoryId, siblings);
-  }
-  for (const siblings of map.values()) {
-    siblings.sort((a, b) =>
-      a.element_id.localeCompare(b.element_id, undefined, { numeric: true }),
-    );
-  }
-  return map;
-}
-
 function isMissingParent(category: CategoryNode, allIds: Set<string>): boolean {
   return !category.parent_stable_id || !allIds.has(category.parent_stable_id);
 }
@@ -302,49 +302,53 @@ function isCategoryVisibleByExpandFilter(
   return getEffectiveCategoryLevel(category, depthMap) <= levelFilter;
 }
 
-type RovScale = { min: number; max: number };
-
-function buildCategoryRovScalesByLevel(
-  categories: CategoryNode[],
-  depthMap: Map<string, number>,
+function parseCategoryRovScalesByLevel(
+  scales: Record<string, RovScale> | undefined,
 ): Map<number, RovScale> {
-  const scales = new Map<number, RovScale>();
-  for (const category of categories) {
-    if (isZeroRovScore(category.rov_score) || category.rov_score === null) {
-      continue;
-    }
-    const level = getEffectiveCategoryLevel(category, depthMap);
-    const score = category.rov_score;
-    const current = scales.get(level);
-    if (!current) {
-      scales.set(level, { min: score, max: score });
-      continue;
-    }
-    current.min = Math.min(current.min, score);
-    current.max = Math.max(current.max, score);
+  const map = new Map<number, RovScale>();
+  if (!scales) {
+    return map;
   }
-  return scales;
+  for (const [level, scale] of Object.entries(scales)) {
+    map.set(Number(level), scale);
+  }
+  return map;
 }
 
-function buildElementRovScale(elements: ElementNode[]): RovScale {
-  let min = 0;
-  let max = 0;
-  let initialized = false;
-  for (const element of elements) {
-    if (isZeroRovScore(element.rov_score) || element.rov_score === null) {
-      continue;
-    }
-    const score = element.rov_score;
-    if (!initialized) {
-      min = score;
-      max = score;
-      initialized = true;
-      continue;
-    }
-    min = Math.min(min, score);
-    max = Math.max(max, score);
+function getElementCountInfo(
+  counts: Record<string, ElementCountInfo> | undefined,
+  categoryId: string,
+): ElementCountInfo {
+  return counts?.[categoryId] ?? { total: 0, valid: 0 };
+}
+
+function getElementBadgeCount(
+  countInfo: ElementCountInfo,
+  loadedElements: ElementNode[] | undefined,
+  showInvalidNodes: boolean,
+): number {
+  if (loadedElements !== undefined) {
+    return showInvalidNodes
+      ? loadedElements.length
+      : loadedElements.filter((item) => !isZeroRovScore(item.rov_score)).length;
+  }
+  return showInvalidNodes ? countInfo.total : countInfo.valid;
+}
+
+function categoryHasVisibleElements(
+  countInfo: ElementCountInfo,
+  showInvalidNodes: boolean,
+): boolean {
+  return showInvalidNodes ? countInfo.total > 0 : countInfo.valid > 0;
+}
+
+function getTotalElementCount(
+  counts: Record<string, ElementCountInfo> | undefined,
+): number {
+  if (!counts) {
+    return 0;
   }
-  return { min, max };
+  return Object.values(counts).reduce((sum, item) => sum + item.total, 0);
 }
 
 function resolveNodeRovScale(
@@ -413,6 +417,7 @@ type TreeNodeCardProps = {
   elementRovScale: RovScale;
   categoryDepthMap: Map<string, number>;
   active: boolean;
+  elementsLoading: boolean;
   onToggleCategory: () => void;
   onToggleElements: () => void;
 };
@@ -423,6 +428,7 @@ function TreeNodeCard({
   elementRovScale,
   categoryDepthMap,
   active,
+  elementsLoading,
   onToggleCategory,
   onToggleElements,
 }: TreeNodeCardProps) {
@@ -533,13 +539,14 @@ function TreeNodeCard({
             <button
               type="button"
               className={`cet-node-badge cet-node-badge--element${node.elementsExpanded ? " cet-node-badge--expanded" : ""}`}
+              disabled={elementsLoading}
               onClick={(event) => {
                 event.stopPropagation();
                 onToggleElements();
               }}
             >
-              {node.elementBadgeCount} 元素
-              {node.elementsExpanded ? <CaretDownFilled /> : <CaretRightFilled />}
+              {elementsLoading ? "加载中" : `${node.elementBadgeCount} 元素`}
+              {!elementsLoading && (node.elementsExpanded ? <CaretDownFilled /> : <CaretRightFilled />)}
             </button>
           ) : null}
         </div>
@@ -554,6 +561,10 @@ export default function CategoryEffectTreeApp() {
   const [loading, setLoading] = useState(false);
   const [error, setError] = useState("");
   const [data, setData] = useState<TreeResponse | null>(null);
+  const [loadedElementsByCategory, setLoadedElementsByCategory] = useState<
+    Map<string, ElementNode[]>
+  >(new Map());
+  const [loadingElementCategoryId, setLoadingElementCategoryId] = useState<string | null>(null);
   const [expandedCategoryIds, setExpandedCategoryIds] = useState<Set<string>>(new Set());
   const [expandedElementParents, setExpandedElementParents] = useState<Set<string>>(new Set());
   const [activeKey, setActiveKey] = useState<string | null>(null);
@@ -585,7 +596,18 @@ export default function CategoryEffectTreeApp() {
         throw new Error(detail || `HTTP ${response.status}`);
       }
       const payload = (await response.json()) as TreeResponse;
-      setData(payload);
+      setData({
+        ...payload,
+        categories: payload.categories ?? [],
+        element_counts_by_category: payload.element_counts_by_category ?? {},
+        element_invalid_count: payload.element_invalid_count ?? 0,
+        rov_scales: payload.rov_scales ?? {
+          category_by_level: {},
+          element: { min: 0, max: 0 },
+        },
+      });
+      setLoadedElementsByCategory(new Map());
+      setLoadingElementCategoryId(null);
       setAppliedDt(payload.dt);
       setSelectedDate(dayjs(payload.dt, "YYYYMMDD"));
       setExpandedCategoryIds(new Set());
@@ -598,11 +620,52 @@ export default function CategoryEffectTreeApp() {
         queryError instanceof Error ? queryError.message : "查询失败,请重试",
       );
       setData(null);
+      setLoadedElementsByCategory(new Map());
+      setExpandedCategoryIds(new Set());
+      setExpandedElementParents(new Set());
+      setLoadingElementCategoryId(null);
     } finally {
       setLoading(false);
     }
   }, []);
 
+  const fetchCategoryElements = useCallback(
+    async (categoryId: string, dt: string) => {
+      setLoadingElementCategoryId(categoryId);
+      try {
+        const resolvedBase = getResolvedApiBaseUrl();
+        const baseWithSlash = resolvedBase.endsWith("/")
+          ? resolvedBase
+          : `${resolvedBase}/`;
+        const url = new URL(
+          `vertical-category/categories/${encodeURIComponent(categoryId)}/elements`,
+          baseWithSlash,
+        );
+        url.searchParams.set("dt", dt);
+        const response = await fetch(url.toString(), {
+          method: "GET",
+          headers: { Accept: "application/json" },
+        });
+        if (!response.ok) {
+          const detail = await response.text();
+          throw new Error(detail || `HTTP ${response.status}`);
+        }
+        const payload = (await response.json()) as CategoryElementsResponse;
+        setLoadedElementsByCategory((prev) => {
+          const next = new Map(prev);
+          next.set(categoryId, payload.elements ?? []);
+          return next;
+        });
+        return payload.elements ?? [];
+      } finally {
+        setLoadingElementCategoryId((current) =>
+          current === categoryId ? null : current,
+        );
+      }
+    },
+    [],
+  );
+
   useEffect(() => {
     void fetchTree();
   }, [fetchTree]);
@@ -611,10 +674,29 @@ export default function CategoryEffectTreeApp() {
     () => buildChildrenMap(data?.categories ?? []),
     [data?.categories],
   );
-  const elementsByCategory = useMemo(
-    () => buildElementsByCategory(data?.elements ?? []),
-    [data?.elements],
-  );
+  const elementsByCategory = useMemo(() => {
+    const map = new Map<string, ElementNode[]>();
+    for (const [categoryId, elements] of loadedElementsByCategory.entries()) {
+      map.set(categoryId, elements);
+    }
+    return map;
+  }, [loadedElementsByCategory]);
+  const elementLookup = useMemo(() => {
+    const map = new Map<string, ElementNode>();
+    for (const elements of loadedElementsByCategory.values()) {
+      for (const element of elements) {
+        map.set(element.element_id, element);
+      }
+    }
+    return map;
+  }, [loadedElementsByCategory]);
+  const categoryLookup = useMemo(() => {
+    const map = new Map<string, CategoryNode>();
+    for (const category of data?.categories ?? []) {
+      map.set(category.category_id, category);
+    }
+    return map;
+  }, [data?.categories]);
   const roots = useMemo(
     () => findRootCategories(data?.categories ?? []),
     [data?.categories],
@@ -756,7 +838,12 @@ export default function CategoryEffectTreeApp() {
     ): void => {
       const layoutDepth = getLayoutDepth(category, categoryDepthMap);
       const childCategories = childrenMap.get(category.category_id) ?? [];
-      const allChildElements = elementsByCategory.get(category.category_id) ?? [];
+      const elementCountInfo = getElementCountInfo(
+        data.element_counts_by_category,
+        category.category_id,
+      );
+      const loadedChildElements = elementsByCategory.get(category.category_id);
+      const allChildElements = loadedChildElements ?? [];
       const childElements = showInvalidNodes
         ? allChildElements
         : allChildElements.filter((item) => !isZeroRovScore(item.rov_score));
@@ -765,19 +852,25 @@ export default function CategoryEffectTreeApp() {
         : childCategories.filter((item) => !isZeroRovScore(item.rov_score));
 
       const hasCategoryChildren = childCategories.length > 0;
-      const hasElements = childElements.length > 0;
+      const hasElements = categoryHasVisibleElements(elementCountInfo, showInvalidNodes);
       const categoryExpanded = expandedCategoryIds.has(category.category_id);
       const elementsExpanded = expandedElementParents.has(category.category_id);
       const showSelf = showInvalidNodes || !isZeroRovScore(category.rov_score);
       const passThrough = !showSelf;
       const traverseCategories = passThrough || categoryExpanded;
+      const elementBadgeCount = getElementBadgeCount(
+        elementCountInfo,
+        loadedChildElements,
+        showInvalidNodes,
+      );
 
       const categoryChildKeys =
         traverseCategories && hasCategoryChildren
           ? childCategories.map((child) => `category:${child.category_id}`)
           : [];
       const shouldShowElements =
-        hasElements && (elementsExpanded || (passThrough && childElements.length > 0));
+        childElements.length > 0 &&
+        (elementsExpanded || (passThrough && childElements.length > 0));
       const elementChildKeys = shouldShowElements
         ? childElements.map((element) => `element:${element.element_id}`)
         : [];
@@ -793,7 +886,7 @@ export default function CategoryEffectTreeApp() {
           categoryExpanded,
           elementsExpanded,
           categoryBadgeCount: visibleChildCategories.length,
-          elementBadgeCount: childElements.length,
+          elementBadgeCount,
         });
         linkParentKey = node.key;
       }
@@ -801,7 +894,7 @@ export default function CategoryEffectTreeApp() {
       if (traverseCategories) {
         for (const childKey of categoryChildKeys) {
           const childId = childKey.slice("category:".length);
-          const childCategory = data.categories.find((item) => item.category_id === childId);
+          const childCategory = categoryLookup.get(childId);
           if (childCategory) {
             buildVisibleTree(childCategory, linkParentKey);
           }
@@ -815,9 +908,7 @@ export default function CategoryEffectTreeApp() {
             ? Math.min(
                 ...categoryChildKeys.map((childKey) => {
                   const childId = childKey.slice("category:".length);
-                  const childCategory = data.categories.find(
-                    (item) => item.category_id === childId,
-                  );
+                  const childCategory = categoryLookup.get(childId);
                   return childCategory
                     ? getLayoutDepth(childCategory, categoryDepthMap)
                     : parentLayoutDepth + 1;
@@ -826,7 +917,7 @@ export default function CategoryEffectTreeApp() {
             : parentLayoutDepth + 1;
         for (const childKey of elementChildKeys) {
           const childId = childKey.slice("element:".length);
-          const childElement = data.elements.find((item) => item.element_id === childId);
+          const childElement = elementLookup.get(childId);
           if (childElement) {
             buildElementNode(childElement, elementLayoutDepth, linkParentKey);
           }
@@ -988,7 +1079,7 @@ export default function CategoryEffectTreeApp() {
       canvasWidth: (maxDepth + 1) * (NODE_WIDTH + COLUMN_GAP) + CANVAS_PADDING * 2,
       canvasHeight,
     };
-  }, [data, roots, detachedCategories, childrenMap, elementsByCategory, expandedCategoryIds, expandedElementParents, showInvalidNodes, expandLevelFilter, categoryDepthMap]);
+  }, [data, roots, detachedCategories, childrenMap, elementsByCategory, categoryLookup, elementLookup, expandedCategoryIds, expandedElementParents, showInvalidNodes, expandLevelFilter, categoryDepthMap]);
 
   const fitToView = useCallback(() => {
     const viewport = viewportRef.current;
@@ -1102,15 +1193,47 @@ export default function CategoryEffectTreeApp() {
       return;
     }
 
-    setExpandedElementParents((prev) => {
-      const next = new Set(prev);
-      if (next.has(category.category_id)) {
-        next.delete(category.category_id);
-      } else {
-        next.add(category.category_id);
-      }
-      return next;
-    });
+    const categoryId = category.category_id;
+    const willCollapse = expandedElementParents.has(categoryId);
+    if (willCollapse) {
+      setExpandedElementParents((prev) => {
+        const next = new Set(prev);
+        next.delete(categoryId);
+        return next;
+      });
+      return;
+    }
+
+    if (loadingElementCategoryId === categoryId) {
+      return;
+    }
+
+    const expandCategory = () => {
+      setExpandedElementParents((prev) => {
+        const next = new Set(prev);
+        next.add(categoryId);
+        return next;
+      });
+    };
+
+    if (loadedElementsByCategory.has(categoryId)) {
+      expandCategory();
+      return;
+    }
+
+    if (!appliedDt) {
+      return;
+    }
+
+    void fetchCategoryElements(categoryId, appliedDt)
+      .then(() => {
+        expandCategory();
+      })
+      .catch((queryError) => {
+        setError(
+          queryError instanceof Error ? queryError.message : "元素加载失败,请重试",
+        );
+      });
   };
 
   const handleCollapseAll = () => {
@@ -1132,12 +1255,12 @@ export default function CategoryEffectTreeApp() {
   };
 
   const categoryRovScalesByLevel = useMemo(
-    () => buildCategoryRovScalesByLevel(data?.categories ?? [], categoryDepthMap),
-    [data?.categories, categoryDepthMap],
+    () => parseCategoryRovScalesByLevel(data?.rov_scales?.category_by_level),
+    [data?.rov_scales?.category_by_level],
   );
   const elementRovScale = useMemo(
-    () => buildElementRovScale(data?.elements ?? []),
-    [data?.elements],
+    () => data?.rov_scales?.element ?? { min: 0, max: 0 },
+    [data?.rov_scales?.element],
   );
 
   const invalidNodeCount = useMemo(() => {
@@ -1145,8 +1268,7 @@ export default function CategoryEffectTreeApp() {
       return 0;
     }
     const invalidCategories = data.categories.filter((item) => isZeroRovScore(item.rov_score)).length;
-    const invalidElements = data.elements.filter((item) => isZeroRovScore(item.rov_score)).length;
-    return invalidCategories + invalidElements;
+    return invalidCategories + (data.element_invalid_count ?? 0);
   }, [data]);
 
   return (
@@ -1227,7 +1349,7 @@ export default function CategoryEffectTreeApp() {
           <div className="cet-stats">
             <span>日期 {formatDtLabel(appliedDt)}</span>
             <span>分类 {data?.categories.length ?? 0}</span>
-            <span>元素 {data?.elements.length ?? 0}</span>
+            <span>元素 {getTotalElementCount(data?.element_counts_by_category)}</span>
             <span>节点 {layoutNodes.length}</span>
           </div>
         ) : null}
@@ -1343,6 +1465,11 @@ export default function CategoryEffectTreeApp() {
                           elementRovScale={elementRovScale}
                           categoryDepthMap={categoryDepthMap}
                           active={activeKey === node.key}
+                          elementsLoading={
+                            node.kind === "category" &&
+                            node.category !== undefined &&
+                            loadingElementCategoryId === node.category.category_id
+                          }
                           onToggleCategory={() => handleToggleCategory(node)}
                           onToggleElements={() => handleToggleElements(node)}
                         />