Просмотр исходного кода

how 账号常量导入pattern数据库

liuzhiheng 1 месяц назад
Родитель
Сommit
5bf9ddc8d2

+ 14 - 0
examples_how/overall_derivation/data_export_from_db/database_readme.md

@@ -316,3 +316,17 @@ CREATE TABLE `topic_pattern_element` (
   KEY `ix_topic_pattern_element_post_id` (`post_id`),
   KEY `ix_topic_pattern_element_post_id` (`post_id`),
   KEY `ix_topic_pattern_element_execution_id` (`execution_id`)
   KEY `ix_topic_pattern_element_execution_id` (`execution_id`)
 ) ENGINE=InnoDB AUTO_INCREMENT=106913 DEFAULT CHARSET=utf8mb4;
 ) ENGINE=InnoDB AUTO_INCREMENT=106913 DEFAULT CHARSET=utf8mb4;
+
+#### 账号常量表
+CREATE TABLE `account_constant` (
+  `id` bigint(20) NOT NULL AUTO_INCREMENT,
+  `account_name` varchar(255) NOT NULL COMMENT '账号名称',
+  `constant_node_name` varchar(64) NOT NULL COMMENT '常量节点名称',
+  `constant_type` varchar(32) NOT NULL COMMENT '常量类型(全局常量/局部常量)',
+  `node_type` varchar(32) NOT NULL COMMENT '节点类型(class/ID)',
+  `source_type` varchar(32) DEFAULT NULL COMMENT '节点元素类型(实质/形式/意图)',
+  `tree_level` int(11) DEFAULT NULL COMMENT '树层级',
+  `create_time` datetime DEFAULT NULL ON UPDATE CURRENT_TIMESTAMP COMMENT '创建时间',
+  PRIMARY KEY (`id`),
+  KEY `idx_aname` (`account_name`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

+ 1 - 1
examples_how/overall_derivation/data_process/how_tree_data_process.py

@@ -695,4 +695,4 @@ def main(account_name) -> None:
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    main(account_name="家有大志")
+    main(account_name="空间点阵设计研究室")

+ 171 - 0
examples_how/overall_derivation/data_process/import_account_constant_to_db.py

@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+from typing import Any, Optional
+
+from dotenv import load_dotenv
+
+_EXAMPLES_HOW_DIR = Path(__file__).resolve().parents[2]
+if str(_EXAMPLES_HOW_DIR) not in sys.path:
+    sys.path.insert(0, str(_EXAMPLES_HOW_DIR))
+
+from db_utils.mysql_db import get_mysql_db  # noqa: E402
+
+_DB_SOURCE = "pattern"
+_TABLE = "account_constant"
+
+_SOURCE_TYPE_PREFIXES = ("实质", "形式", "意图")
+
+
+def _tree_dir(account_name: str) -> Path:
+    """人设树目录:../input/{account_name}/处理后数据/tree/"""
+    base = Path(__file__).resolve().parents[1]
+    return base / "input" / account_name / "处理后数据" / "tree"
+
+
+def _source_type_from_filename(stem: str) -> Optional[str]:
+    """从文件名解析节点元素类型,如 实质_point_tree_how -> 实质。"""
+    if not stem:
+        return None
+    for p in _SOURCE_TYPE_PREFIXES:
+        if stem == p or stem.startswith(p + "_"):
+            return p
+    return None
+
+
+def _node_type_raw(node: dict[str, Any]) -> str:
+    t = node.get("_type")
+    if t is None:
+        return ""
+    t_str = str(t).strip()
+    # DB field `node_type` maps from tree `_type`.
+    # - class -> category
+    # - ID -> element
+    if t_str == "class":
+        return "category"
+    if t_str == "ID":
+        return "element"
+    return t_str
+
+
+def _constant_kind(node: dict[str, Any]) -> Optional[str]:
+    """全局常量 / 局部常量;不满足则 None。"""
+    is_const = node.get("_is_constant") is True
+    is_local = node.get("_is_local_constant") is True
+    if is_const:
+        return "全局常量"
+    if is_local and not is_const:
+        return "局部常量"
+    return None
+
+
+def _collect_constant_rows(
+    account_name: str,
+) -> list[dict[str, Any]]:
+    """
+    扫描 tree 目录下所有人设树 JSON,收集待写入行(不含 id / create_time)。
+    """
+    td = _tree_dir(account_name)
+    if not td.is_dir():
+        raise FileNotFoundError(f"人设树目录不存在: {td}")
+
+    rows: list[dict[str, Any]] = []
+    seen: set[tuple[Any, ...]] = set()
+
+    def visit(node_name: str, node: dict[str, Any], level: int, source_type: Optional[str]) -> None:
+        kind = _constant_kind(node)
+        if kind is None:
+            pass
+        else:
+            nt = _node_type_raw(node)
+            name = (node_name or "").strip()[:64]
+            key = (name, kind, nt, source_type, level)
+            if key not in seen:
+                seen.add(key)
+                rows.append(
+                    {
+                        "account_name": account_name.strip(),
+                        "constant_node_name": name,
+                        "constant_type": kind,
+                        "node_type": nt or "",
+                        "source_type": source_type,
+                        "tree_level": level,
+                    }
+                )
+        for cname, cnode in (node.get("children") or {}).items():
+            if not isinstance(cnode, dict):
+                continue
+            visit(str(cname).strip(), cnode, level + 1, source_type)
+
+    for path in sorted(td.glob("*.json")):
+        st = _source_type_from_filename(path.stem)
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except (OSError, json.JSONDecodeError) as e:
+            raise RuntimeError(f"无法读取人设树 JSON: {path}") from e
+        if not isinstance(data, dict):
+            continue
+        for dim_name, root in data.items():
+            if not isinstance(root, dict):
+                continue
+            visit(str(dim_name).strip(), root, 0, st)
+
+    return rows
+
+
+def import_account_constant_to_db(
+    account_name: str,
+    *,
+    replace: bool = True,
+) -> int:
+    """
+    将账号人设树中的全局/局部常量节点写入 open_aigc_pattern.account_constant。
+
+    :param replace: 为 True 时先删除该 account_name 已有记录再插入。
+    :return: 插入行数。
+    """
+    db = get_mysql_db(_DB_SOURCE)
+    rows = _collect_constant_rows(account_name)
+    if not rows:
+        if replace:
+            db.delete(table=_TABLE, where="account_name=%s", where_params=(account_name.strip(),))
+        return 0
+
+    conn = db._client().open_connection()
+    try:
+        conn.begin()
+        if replace:
+            db.delete(
+                table=_TABLE,
+                where="account_name=%s",
+                where_params=(account_name.strip(),),
+                connection=conn,
+            )
+        db.insert_many(_TABLE, rows, connection=conn)
+        conn.commit()
+    except Exception:
+        try:
+            conn.rollback()
+        except Exception:
+            pass
+        raise
+    finally:
+        try:
+            conn.close()
+        except Exception:
+            pass
+
+    return len(rows)
+
+
+def main(account_name) -> None:
+    load_dotenv()
+    n = import_account_constant_to_db(account_name, replace=True)
+    print(f"已写入 {n} 条到 {_TABLE}(account_name={account_name!r})")
+
+
+if __name__ == "__main__":
+    main(account_name="空间点阵设计研究室")