|
|
@@ -232,18 +232,21 @@ class DatabaseRuntimeStore:
|
|
|
|
|
|
def append_jsonl(self, run_id: str, filename: str, rows: list[dict[str, Any]]) -> Path:
|
|
|
table = _table_for_runtime_file(filename)
|
|
|
+ # 整批共用一个连接、一次 commit:避免每行新建连接+commit 的 N 次网络往返。
|
|
|
+ statements: list[tuple[str, list[Any]]] = []
|
|
|
for row in rows:
|
|
|
if row.get("run_id") != run_id:
|
|
|
raise ValueError(f"{filename} row run_id does not match runtime run_id")
|
|
|
record = _record_for_jsonl(filename, row)
|
|
|
- if filename in JSONL_UPSERT_KEYS:
|
|
|
- self._upsert(
|
|
|
- table,
|
|
|
- record,
|
|
|
- key_columns=JSONL_UPSERT_KEYS[filename],
|
|
|
- )
|
|
|
- else:
|
|
|
- self._insert(table, record)
|
|
|
+ statements.append(
|
|
|
+ self._row_sql(table, record, key_columns=JSONL_UPSERT_KEYS.get(filename))
|
|
|
+ )
|
|
|
+ if statements:
|
|
|
+ with self._connection_factory() as conn:
|
|
|
+ with conn.cursor() as cur:
|
|
|
+ for sql, values in statements:
|
|
|
+ cur.execute(sql, values)
|
|
|
+ conn.commit()
|
|
|
return self.run_dir(run_id) / filename
|
|
|
|
|
|
def read_json(self, run_id: str, filename: str) -> dict[str, Any]:
|
|
|
@@ -376,7 +379,12 @@ class DatabaseRuntimeStore:
|
|
|
key_columns=("run_id", "policy_run_id", "clue_id"),
|
|
|
)
|
|
|
|
|
|
- def _insert(self, table: str, record: dict[str, Any]) -> None:
|
|
|
+ def _row_sql(
|
|
|
+ self,
|
|
|
+ table: str,
|
|
|
+ record: dict[str, Any],
|
|
|
+ key_columns: tuple[str, ...] | None = None,
|
|
|
+ ) -> tuple[str, list[Any]]:
|
|
|
sanitized = _sanitize_record(table, record)
|
|
|
columns = list(sanitized)
|
|
|
placeholders = ", ".join(["%s"] * len(columns))
|
|
|
@@ -385,12 +393,19 @@ class DatabaseRuntimeStore:
|
|
|
_db_value(table, column, sanitized[column])
|
|
|
for column in columns
|
|
|
]
|
|
|
+ sql = f"INSERT INTO `{table}` ({column_sql}) VALUES ({placeholders})"
|
|
|
+ if key_columns:
|
|
|
+ update_columns = [column for column in columns if column not in key_columns]
|
|
|
+ assignments = ", ".join(f"`{column}` = VALUES(`{column}`)" for column in update_columns)
|
|
|
+ if assignments:
|
|
|
+ sql += f" ON DUPLICATE KEY UPDATE {assignments}"
|
|
|
+ return sql, values
|
|
|
+
|
|
|
+ def _insert(self, table: str, record: dict[str, Any]) -> None:
|
|
|
+ sql, values = self._row_sql(table, record)
|
|
|
with self._connection_factory() as conn:
|
|
|
with conn.cursor() as cur:
|
|
|
- cur.execute(
|
|
|
- f"INSERT INTO `{table}` ({column_sql}) VALUES ({placeholders})",
|
|
|
- values,
|
|
|
- )
|
|
|
+ cur.execute(sql, values)
|
|
|
conn.commit()
|
|
|
|
|
|
def _upsert(
|
|
|
@@ -399,19 +414,7 @@ class DatabaseRuntimeStore:
|
|
|
record: dict[str, Any],
|
|
|
key_columns: tuple[str, ...],
|
|
|
) -> None:
|
|
|
- sanitized = _sanitize_record(table, record)
|
|
|
- columns = list(sanitized)
|
|
|
- placeholders = ", ".join(["%s"] * len(columns))
|
|
|
- column_sql = ", ".join(f"`{column}`" for column in columns)
|
|
|
- update_columns = [column for column in columns if column not in key_columns]
|
|
|
- assignments = ", ".join(f"`{column}` = VALUES(`{column}`)" for column in update_columns)
|
|
|
- values = [
|
|
|
- _db_value(table, column, sanitized[column])
|
|
|
- for column in columns
|
|
|
- ]
|
|
|
- sql = f"INSERT INTO `{table}` ({column_sql}) VALUES ({placeholders})"
|
|
|
- if assignments:
|
|
|
- sql += f" ON DUPLICATE KEY UPDATE {assignments}"
|
|
|
+ sql, values = self._row_sql(table, record, key_columns=key_columns)
|
|
|
with self._connection_factory() as conn:
|
|
|
with conn.cursor() as cur:
|
|
|
cur.execute(sql, values)
|