vertical_category_tree_service.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. """垂直领域分类树:查询分类/元素基础信息与效果得分。"""
  2. import re
  3. from sqlalchemy import text
  4. from app.core.config import settings
  5. from app.db.mysql import SessionLocal
  6. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
  7. DATE_RE = re.compile(r"^\d{8}$")
  8. def _safe_identifier(name: str) -> str:
  9. if not IDENTIFIER_RE.match(name):
  10. raise ValueError(f"invalid sql identifier: {name}")
  11. return name
  12. def _normalize_date(date_value: str | None) -> str | None:
  13. if not date_value:
  14. return None
  15. normalized = date_value.replace("-", "").strip()
  16. if not normalized:
  17. return None
  18. if not DATE_RE.match(normalized):
  19. raise ValueError("date must be yyyymmdd or yyyy-mm-dd")
  20. return normalized
  21. def _resolve_partition_dt(dt: str | None) -> str:
  22. normalized = _normalize_date(dt)
  23. if normalized:
  24. return normalized
  25. category_effect_table = _safe_identifier(settings.vertical_category_effect_table)
  26. with SessionLocal() as session:
  27. row = session.execute(
  28. text(f"SELECT MAX(dt) FROM {category_effect_table}")
  29. ).first()
  30. latest = str(row[0] or "").strip() if row else ""
  31. if not latest:
  32. raise ValueError("暂无效果数据分区,请先同步垂直领域分类数据")
  33. return latest
  34. def _is_zero_rov_score(rov_score: object) -> bool:
  35. if rov_score is None:
  36. return True
  37. return float(rov_score) == 0.0
  38. def _build_children_map(categories: list[dict[str, object]]) -> dict[str, list[dict[str, object]]]:
  39. children_map: dict[str, list[dict[str, object]]] = {}
  40. for category in categories:
  41. parent_id = category.get("parent_stable_id")
  42. if not parent_id:
  43. continue
  44. parent_key = str(parent_id)
  45. children_map.setdefault(parent_key, []).append(category)
  46. for siblings in children_map.values():
  47. siblings.sort(key=lambda item: str(item["category_id"]))
  48. return children_map
  49. def _is_missing_parent(category: dict[str, object], all_ids: set[str]) -> bool:
  50. parent_id = category.get("parent_stable_id")
  51. if not parent_id:
  52. return True
  53. return str(parent_id) not in all_ids
  54. def _find_root_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
  55. all_ids = {str(item["category_id"]) for item in categories}
  56. roots = [
  57. item
  58. for item in categories
  59. if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) <= 1
  60. ]
  61. roots.sort(key=lambda item: str(item["category_id"]))
  62. return roots
  63. def _find_detached_categories(categories: list[dict[str, object]]) -> list[dict[str, object]]:
  64. all_ids = {str(item["category_id"]) for item in categories}
  65. detached = [
  66. item
  67. for item in categories
  68. if _is_missing_parent(item, all_ids) and int(item.get("category_level") or 1) > 1
  69. ]
  70. detached.sort(
  71. key=lambda item: (
  72. int(item.get("category_level") or 1),
  73. str(item["category_id"]),
  74. )
  75. )
  76. return detached
  77. def _build_category_depth_map(
  78. categories: list[dict[str, object]],
  79. children_map: dict[str, list[dict[str, object]]],
  80. roots: list[dict[str, object]],
  81. ) -> dict[str, int]:
  82. category_by_id = {str(item["category_id"]): item for item in categories}
  83. depth_map: dict[str, int] = {}
  84. def visit(category_id: str, fallback_level: int) -> None:
  85. if category_id in depth_map:
  86. return
  87. category = category_by_id.get(category_id)
  88. level = category.get("category_level") if category else None
  89. resolved_level = int(level) if level is not None else fallback_level
  90. depth_map[category_id] = resolved_level
  91. for child in children_map.get(category_id, []):
  92. visit(str(child["category_id"]), resolved_level + 1)
  93. for root in roots:
  94. visit(str(root["category_id"]), int(root.get("category_level") or 1))
  95. for item in _find_detached_categories(categories):
  96. category_id = str(item["category_id"])
  97. if category_id not in depth_map:
  98. visit(category_id, int(item.get("category_level") or 1))
  99. for category in categories:
  100. category_id = str(category["category_id"])
  101. if category_id not in depth_map:
  102. depth_map[category_id] = int(category.get("category_level") or 1)
  103. return depth_map
  104. def _get_effective_category_level(
  105. category: dict[str, object],
  106. depth_map: dict[str, int],
  107. ) -> int:
  108. level = category.get("category_level")
  109. if level is not None:
  110. return int(level)
  111. return depth_map.get(str(category["category_id"]), 1)
  112. def _build_category_rov_scales_by_level(
  113. categories: list[dict[str, object]],
  114. ) -> dict[str, dict[str, float]]:
  115. children_map = _build_children_map(categories)
  116. roots = _find_root_categories(categories)
  117. depth_map = _build_category_depth_map(categories, children_map, roots)
  118. scales: dict[int, dict[str, float]] = {}
  119. for category in categories:
  120. rov_score = category.get("rov_score")
  121. if _is_zero_rov_score(rov_score):
  122. continue
  123. level = _get_effective_category_level(category, depth_map)
  124. score = float(rov_score)
  125. current = scales.get(level)
  126. if current is None:
  127. scales[level] = {"min": score, "max": score}
  128. continue
  129. current["min"] = min(current["min"], score)
  130. current["max"] = max(current["max"], score)
  131. return {str(level): value for level, value in sorted(scales.items())}
  132. def _serialize_category_row(row: object, child_category_ids: set[str]) -> dict[str, object]:
  133. category_id = str(row["category_id"] or "").strip()
  134. parent_id = str(row["parent_stable_id"] or "").strip() or None
  135. if parent_id:
  136. child_category_ids.add(parent_id)
  137. rov_score = row["rov_score"]
  138. return {
  139. "category_id": category_id,
  140. "parent_stable_id": parent_id,
  141. "category_name": row["category_name"],
  142. "category_level": row["category_level"],
  143. "vid_count": row["vid_count"],
  144. "rov_score": float(rov_score) if rov_score is not None else None,
  145. "is_leaf": False,
  146. }
  147. def _query_element_counts_by_category(
  148. session: object,
  149. element_base_table: str,
  150. element_effect_table: str,
  151. partition_dt: str,
  152. ) -> tuple[dict[str, dict[str, int]], int]:
  153. rows = session.execute(
  154. text(
  155. f"""
  156. SELECT
  157. b.stable_id,
  158. COUNT(*) AS total,
  159. SUM(
  160. CASE
  161. WHEN e.rov_score IS NULL OR e.rov_score = 0 THEN 1
  162. ELSE 0
  163. END
  164. ) AS invalid
  165. FROM {element_base_table} b
  166. LEFT JOIN {element_effect_table} e
  167. ON b.element_id = e.element_id
  168. AND e.dt = :dt
  169. WHERE b.stable_id IS NOT NULL
  170. AND TRIM(b.stable_id) <> ''
  171. GROUP BY b.stable_id
  172. """
  173. ),
  174. {"dt": partition_dt},
  175. ).mappings().all()
  176. counts_by_category: dict[str, dict[str, int]] = {}
  177. invalid_total = 0
  178. for row in rows:
  179. category_id = str(row["stable_id"] or "").strip()
  180. if not category_id:
  181. continue
  182. total = int(row["total"] or 0)
  183. invalid = int(row["invalid"] or 0)
  184. valid = max(total - invalid, 0)
  185. counts_by_category[category_id] = {"total": total, "valid": valid}
  186. invalid_total += invalid
  187. return counts_by_category, invalid_total
  188. def _query_global_element_rov_scale(
  189. session: object,
  190. element_effect_table: str,
  191. partition_dt: str,
  192. ) -> dict[str, float]:
  193. row = session.execute(
  194. text(
  195. f"""
  196. SELECT
  197. MIN(rov_score) AS min_rov_score,
  198. MAX(rov_score) AS max_rov_score
  199. FROM {element_effect_table}
  200. WHERE dt = :dt
  201. AND rov_score IS NOT NULL
  202. AND rov_score <> 0
  203. """
  204. ),
  205. {"dt": partition_dt},
  206. ).mappings().first()
  207. if not row or row["min_rov_score"] is None or row["max_rov_score"] is None:
  208. return {"min": 0.0, "max": 0.0}
  209. return {
  210. "min": float(row["min_rov_score"]),
  211. "max": float(row["max_rov_score"]),
  212. }
  213. def query_available_dates() -> list[str]:
  214. category_effect_table = _safe_identifier(settings.vertical_category_effect_table)
  215. with SessionLocal() as session:
  216. rows = session.execute(
  217. text(
  218. f"""
  219. SELECT DISTINCT dt
  220. FROM {category_effect_table}
  221. ORDER BY dt DESC
  222. LIMIT 366
  223. """
  224. )
  225. ).all()
  226. return [str(row[0]) for row in rows if row[0]]
  227. def query_vertical_category_tree(dt: str | None = None) -> dict[str, object]:
  228. partition_dt = _resolve_partition_dt(dt)
  229. category_base_table = _safe_identifier(settings.vertical_category_base_table)
  230. category_effect_table = _safe_identifier(settings.vertical_category_effect_table)
  231. element_base_table = _safe_identifier(settings.substance_element_base_table)
  232. element_effect_table = _safe_identifier(settings.substance_element_effect_table)
  233. category_sql = text(
  234. f"""
  235. SELECT
  236. b.category_id,
  237. b.parent_stable_id,
  238. b.category_name,
  239. b.category_level,
  240. e.vid_count,
  241. e.rov_score
  242. FROM {category_base_table} b
  243. LEFT JOIN {category_effect_table} e
  244. ON b.category_id = e.category_id
  245. AND e.dt = :dt
  246. ORDER BY b.category_level ASC, b.category_id ASC
  247. """
  248. )
  249. with SessionLocal() as session:
  250. category_rows = session.execute(category_sql, {"dt": partition_dt}).mappings().all()
  251. element_counts_by_category, element_invalid_count = _query_element_counts_by_category(
  252. session,
  253. element_base_table,
  254. element_effect_table,
  255. partition_dt,
  256. )
  257. element_rov_scale = _query_global_element_rov_scale(
  258. session,
  259. element_effect_table,
  260. partition_dt,
  261. )
  262. categories: list[dict[str, object]] = []
  263. child_category_ids: set[str] = set()
  264. for row in category_rows:
  265. categories.append(_serialize_category_row(row, child_category_ids))
  266. for item in categories:
  267. item["is_leaf"] = item["category_id"] not in child_category_ids
  268. return {
  269. "dt": partition_dt,
  270. "available_dates": query_available_dates(),
  271. "element_invalid_count": element_invalid_count,
  272. "element_counts_by_category": element_counts_by_category,
  273. "rov_scales": {
  274. "category_by_level": _build_category_rov_scales_by_level(categories),
  275. "element": element_rov_scale,
  276. },
  277. "categories": categories,
  278. }
  279. def query_category_elements(category_id: str, dt: str | None = None) -> dict[str, object]:
  280. normalized_category_id = category_id.strip()
  281. if not normalized_category_id:
  282. raise ValueError("category_id is required")
  283. partition_dt = _resolve_partition_dt(dt)
  284. element_base_table = _safe_identifier(settings.substance_element_base_table)
  285. element_effect_table = _safe_identifier(settings.substance_element_effect_table)
  286. element_sql = text(
  287. f"""
  288. SELECT
  289. b.element_id,
  290. b.stable_id,
  291. b.element_name,
  292. e.vid_count,
  293. e.rov_score
  294. FROM {element_base_table} b
  295. LEFT JOIN {element_effect_table} e
  296. ON b.element_id = e.element_id
  297. AND e.dt = :dt
  298. WHERE b.stable_id = :category_id
  299. ORDER BY b.element_id ASC
  300. """
  301. )
  302. with SessionLocal() as session:
  303. element_rows = session.execute(
  304. element_sql,
  305. {"dt": partition_dt, "category_id": normalized_category_id},
  306. ).mappings().all()
  307. rov_scale = _query_global_element_rov_scale(session, element_effect_table, partition_dt)
  308. elements: list[dict[str, object]] = []
  309. for row in element_rows:
  310. rov_score = row["rov_score"]
  311. elements.append(
  312. {
  313. "element_id": str(row["element_id"] or "").strip(),
  314. "stable_id": str(row["stable_id"] or "").strip() or None,
  315. "element_name": row["element_name"],
  316. "vid_count": row["vid_count"],
  317. "rov_score": float(rov_score) if rov_score is not None else None,
  318. }
  319. )
  320. return {
  321. "dt": partition_dt,
  322. "category_id": normalized_category_id,
  323. "elements": elements,
  324. "rov_scales": {
  325. "element": rov_scale,
  326. },
  327. }