| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600 |
- #!/usr/bin/env python
- # coding=utf-8
- import os
- import time
- import uuid
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from odps import ODPS, options
- from odps.tunnel import TableTunnel
- from tqdm import tqdm
- import pyarrow as pa
- from pyarrow import csv as pa_csv
# DataWorks SDK (optional dependency; only DataWorksClient uses it).
try:
    from alibabacloud_dataworks_public20240518.client import Client as _DWClient
    from alibabacloud_tea_openapi import models as _open_api_models
    from alibabacloud_dataworks_public20240518 import models as _dw_models
    _DW_AVAILABLE = True
except ImportError:
    # SDK not installed: DataWorksClient.__init__ raises a helpful ImportError,
    # while the ODPS-only functionality in this module keeps working.
    _DW_AVAILABLE = False
# Enable the Instance Tunnel so SQL result reads are not capped at 10,000 rows.
options.tunnel.use_instance_tunnel = True
options.tunnel.limit_instance_tunnel = False
# ODPS credential configurations.
# SECURITY NOTE(review): these access keys/secrets were hard-coded in source.
# They should be rotated and removed from version control. Environment
# variables now take precedence so deployments can avoid the baked-in
# literals entirely; the literals remain only as backward-compatible defaults.
ODPS_CONFIGS = {
    "default": {
        "access_id": os.environ.get("ODPS_DEFAULT_ACCESS_ID", "LTAIWYUujJAm7CbH"),
        "access_secret": os.environ.get("ODPS_DEFAULT_ACCESS_SECRET", "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P"),
        "project": "loghubods",
    },
    "piaoquan_api": {
        "access_id": os.environ.get("ODPS_PIAOQUAN_ACCESS_ID", "LTAI5tKyXxh7C6349c1wbwUX"),
        "access_secret": os.environ.get("ODPS_PIAOQUAN_ACCESS_SECRET", "H8doQDC20KugToRA3giERgRyRD1KR9"),
        "project": "piaoquan_api",
    },
}
class ODPSClient(object):
    """Client for running MaxCompute (ODPS) SQL and downloading results to CSV."""

    def __init__(self, project="loghubods", config="default"):
        """
        Initialize the ODPS client.

        Args:
            project: project name (overrides the config's default project)
            config: config name, one of "default" or "piaoquan_api"
        """
        cfg = ODPS_CONFIGS.get(config, ODPS_CONFIGS["default"])
        self.accessId = cfg["access_id"]
        self.accessSecret = cfg["access_secret"]
        self.endpoint = "http://service.odps.aliyun.com/api"
        self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
        # Honour an explicitly passed (non-default) project; otherwise use
        # the project declared in the selected config entry.
        actual_project = project if project != "loghubods" else cfg["project"]
        self.odps = ODPS(
            self.accessId,
            self.accessSecret,
            actual_project,
            self.endpoint
        )

    def execute_sql(self, sql: str, print_logview: bool = True):
        """Execute SQL and return the full result as a pandas DataFrame."""
        hints = {'odps.sql.submit.mode': 'script'}
        instance = self.odps.execute_sql(sql, hints=hints)
        if print_logview:
            print(f"LogView: {instance.get_logview_address()}")
        with instance.open_reader(tunnel=True, limit=False) as reader:
            pd_df = reader.to_pandas()
        return pd_df

    def execute_sql_result_save_file(self, sql: str, output_file: str):
        """Execute SQL and stream the result to a CSV file.

        Uses the Arrow reader and pyarrow's CSV writer (much faster than
        going through pandas), writing batch by batch while downloading.
        """
        hints = {'odps.sql.submit.mode': 'script'}
        start_time = time.time()
        instance = self.odps.execute_sql(sql, hints=hints)
        sql_time = time.time() - start_time
        print(f"LogView: {instance.get_logview_address()}")
        print(f"SQL 执行耗时: {sql_time:.1f}s")
        with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
            total = reader.count
            # Stream batches straight to disk; only the first batch carries
            # the CSV header.
            with open(output_file, 'wb') as f:
                first = True
                with tqdm(total=total, unit='行', desc='下载中') as pbar:
                    for batch in reader:
                        # FIX: this local used to be named `options`, shadowing
                        # the module-level `odps.options` import.
                        write_options = pa_csv.WriteOptions(include_header=first)
                        pa_csv.write_csv(pa.Table.from_batches([batch]), f, write_options=write_options)
                        first = False
                        pbar.update(batch.num_rows)
        total_time = time.time() - start_time
        print(f"总耗时: {total_time:.1f}s")
        print(f"完成: {output_file}")

    def execute_sql_result_save_file_parallel(self, sql: str, output_file: str, workers: int = 4):
        """Execute SQL and save to CSV using multi-threaded tunnel downloads.

        Materializes the query into a short-lived temp table, splits the rows
        into one contiguous slice per worker, downloads the slices concurrently
        through the table tunnel, then stitches the slice files together in
        order (keeping a single CSV header).

        Args:
            sql: the SELECT statement to run
            output_file: destination CSV path
            workers: number of download threads / row slices
        """
        hints = {'odps.sql.submit.mode': 'script'}
        # Uniquely named temp table with a 1-day lifecycle, so ODPS reclaims
        # it even if the explicit deletion below fails.
        tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
        create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
        start_time = time.time()
        # 1. Create the temp table.
        print(f"创建临时表: {tmp_table}")
        instance = self.odps.execute_sql(create_sql, hints=hints)
        print(f"LogView: {instance.get_logview_address()}")
        instance.wait_for_success()
        sql_time = time.time() - start_time
        print(f"SQL 执行耗时: {sql_time:.1f}s")
        try:
            # 2. Open a download session to learn the total row count.
            table = self.odps.get_table(tmp_table)
            tunnel = TableTunnel(self.odps)
            download_session = tunnel.create_download_session(table.name)
            total = download_session.count
            print(f"总行数: {total}")
            if total == 0:
                # Empty result: write a header-only CSV and stop.
                with open(output_file, 'w', encoding='utf-8') as f:
                    columns = [col.name for col in table.table_schema.columns]
                    f.write(','.join(columns) + '\n')
                print(f"完成: {output_file} (空表)")
                return
            # 3. Partition the row range into one slice per worker.
            chunk_size = (total + workers - 1) // workers
            chunks = []
            for i in range(workers):
                start = i * chunk_size
                end = min((i + 1) * chunk_size, total)
                if start < end:
                    chunks.append((i, start, end - start))  # (index, start, count)
            print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")
            # 4. Download each slice to a hidden temp file in the output dir.
            output_dir = os.path.dirname(output_file)
            tmp_prefix = os.path.join(output_dir, f".tmp_{os.path.basename(output_file)}_")
            pbar = tqdm(total=total, unit='行', desc='下载中')
            pbar_lock = threading.Lock()
            session_id = download_session.id
            tmp_files = {}

            def download_chunk(chunk_info):
                # Re-attach to the shared download session from this thread
                # and write the slice out as its own CSV (with header).
                idx, start, count = chunk_info
                tmp_file = f"{tmp_prefix}{idx:04d}"
                session = tunnel.create_download_session(table.name, download_id=session_id)
                with session.open_arrow_reader(start, count) as reader:
                    batches = []
                    for batch in reader:
                        batches.append(batch)
                        with pbar_lock:
                            pbar.update(batch.num_rows)
                    if batches:
                        tbl = pa.Table.from_batches(batches)
                        pa_csv.write_csv(tbl, tmp_file)
                return idx, tmp_file

            # 5. Fan the slice downloads out over the thread pool.
            with ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [executor.submit(download_chunk, chunk) for chunk in chunks]
                for future in as_completed(futures):
                    idx, tmp_file = future.result()
                    tmp_files[idx] = tmp_file
            pbar.close()
            # Merge slice files in index order, keeping only the first header.
            print("合并文件中...")
            with open(output_file, 'wb') as outf:
                for idx in range(len(chunks)):
                    tmp_file = tmp_files.get(idx)
                    if tmp_file and os.path.exists(tmp_file):
                        with open(tmp_file, 'rb') as inf:
                            if idx > 0:
                                inf.readline()  # skip this slice's header row
                            outf.write(inf.read())
                        os.remove(tmp_file)
        finally:
            # 6. Drop the temp table regardless of success or failure.
            print(f"删除临时表: {tmp_table}")
            self.odps.delete_table(tmp_table, if_exists=True)
        total_time = time.time() - start_time
        print(f"总耗时: {total_time:.1f}s")
        print(f"完成: {output_file}")
# ──────────────────────────────────────────────────────────────────────────────
# DataWorks client: fetch production code by table name
# ──────────────────────────────────────────────────────────────────────────────
# Best-practice call chain: GetTable → GetTask → Script.Content
# 1. GetTable(entity_id, include_business_metadata=True) pinpoints the table's upstream tasks
# 2. GetTask(task_id, project_env='Prod') fetches each task's SQL code
# ──────────────────────────────────────────────────────────────────────────────
# All DataWorks projects visible to this account (project_id -> name).
_DW_PROJECTS = {
    4858: "loghubods",
    11300: "DWH",
    5477: "videocdm",
    548768: "piaoquan_api",
    148813: "content_safety",
    96094: "algo",
    52578: "majin",
    5057: "useractionbi",
    5034: "user_video_action_cdm",
    4868: "usercdm",
    4859: "videoods",
    6025: "videoads",
    5535: "user_video_tag",
    19288: "RecallEmbedding",
    10762: "Test_model1",
    193831: "cost_mgt_1894469520484605",
    156474: "dyp_1",
    156475: "dyp_2",
    343868: "pq_data_space",
    343957: "pq_grafana_se",
}
# On-disk cache for fetched production code (.sql) and schemas (.json):
# the "production_code" directory next to this file's parent directory.
_CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "production_code")
- def _call_with_retry(fn, max_retries=3, base_delay=2):
- """带限流重试的 API 调用包装。"""
- for attempt in range(max_retries):
- try:
- return fn()
- except Exception as e:
- if "Throttling" in str(e) and attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f" [throttled] 等待 {delay}s 后重试...")
- time.sleep(delay)
- continue
- raise
class DataWorksClient:
    """DataWorks client: fetch production code and lineage for MaxCompute tables.

    Best-practice call chain: GetTable -> GetTask -> Script.Content
      1. GetTable(entity_id, include_business_metadata=True) pinpoints the
         table's upstream tasks.
      2. GetTask(task_id, project_env='Prod') fetches each task's SQL code.

    Fetched code (.sql) and schemas (.json) are cached under ``_CACHE_DIR``.
    """

    def __init__(self):
        # Fail fast with install instructions when the optional SDK is absent.
        if not _DW_AVAILABLE:
            raise ImportError(
                "请先安装 DataWorks SDK:\n"
                "pip install alibabacloud-dataworks-public20240518"
            )
        # Build one API client per configured AK pair -- different AKs are
        # authorized for different DataWorks projects.
        self._clients = {}
        for config_name, cfg in ODPS_CONFIGS.items():
            dw_config = _open_api_models.Config(
                access_key_id=cfg["access_id"],
                access_key_secret=cfg["access_secret"],
                endpoint="dataworks.cn-hangzhou.aliyuncs.com",
            )
            self._clients[config_name] = _DWClient(dw_config)
        # Default client used for metadata/lineage lookups.
        self.client = self._clients["default"]

    @staticmethod
    def _build_entity_id(table_name: str) -> str:
        """Build the entity ID expected by GetTable.

        Supported formats:
            - project.table -> maxcompute-table:::project::table
            - table         -> maxcompute-table:::loghubods::table
        """
        parts = table_name.split(".", 1)
        if len(parts) == 2:
            project, table = parts
        else:
            project, table = "loghubods", parts[0]
        return f"maxcompute-table:::{project}::{table}"

    def get_table_info(self, table_name: str) -> dict:
        """Fetch table metadata, including its upstream DataWorks task list.

        Returns:
            dict with keys: id, name, comment, project_id,
            dataworks_tasks (raw upstream-task dicts), partition_keys
        """
        entity_id = self._build_entity_id(table_name)
        resp = _call_with_retry(lambda: self.client.get_table(
            _dw_models.GetTableRequest(id=entity_id, include_business_metadata=True)
        ))
        table = resp.body.to_map().get("Table", {})
        biz = table.get("BusinessMetadata", {})
        return {
            "id": table.get("Id"),
            "name": table.get("Name"),
            "comment": table.get("Comment"),
            "project_id": biz.get("Extension", {}).get("ProjectId"),
            "dataworks_tasks": biz.get("UpstreamTasks", []),
            "partition_keys": table.get("PartitionKeys", []),
        }

    def _get_task_code(self, task_id: int) -> dict | None:
        """Try every configured AK to fetch a task's code; return the first hit.

        Returns:
            dict with task_id/task_name/task_type/content/config, or None
            when no AK has permission for the task.
        """
        for config_name, client in self._clients.items():
            try:
                # Bind `client` as a default argument so the retry lambda
                # captures the current loop value (avoids late binding).
                resp = _call_with_retry(lambda c=client: c.get_task(
                    _dw_models.GetTaskRequest(id=task_id, project_env="Prod")
                ))
                task = resp.body.to_map().get("Task", {})
                return {
                    "task_id": task_id,
                    "task_name": task.get("Name"),
                    "task_type": task.get("Type"),
                    "content": (task.get("Script") or {}).get("Content", ""),
                    "config": config_name,
                }
            except Exception as e:
                if "11020205003" in str(e):
                    continue  # no permission with this AK; try the next one
                raise
        return None

    @staticmethod
    def _normalize_table_name(table_name: str) -> str:
        """Prepend the default project prefix: table -> loghubods.table"""
        if "." not in table_name:
            return f"loghubods.{table_name}"
        return table_name

    @staticmethod
    def _cache_path(table_name: str) -> str:
        # Cached production code lives at production_code/<table>.sql
        return os.path.join(_CACHE_DIR, f"{table_name}.sql")

    @staticmethod
    def _schema_cache_path(table_name: str) -> str:
        # Cached schema lives at production_code/<table>.json
        return os.path.join(_CACHE_DIR, f"{table_name}.json")

    def _read_cache(self, table_name: str) -> str | None:
        """Return cached production code for the table, or None when absent."""
        path = self._cache_path(table_name)
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return f.read()
        return None

    def _write_cache(self, table_name: str, content: str):
        """Write production code to the local cache, creating the directory."""
        os.makedirs(_CACHE_DIR, exist_ok=True)
        with open(self._cache_path(table_name), "w", encoding="utf-8") as f:
            f.write(content)

    def _read_schema_cache(self, table_name: str) -> dict | None:
        """Return the cached schema dict for the table, or None when absent."""
        import json
        path = self._schema_cache_path(table_name)
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        return None

    def _write_schema_cache(self, table_name: str, schema: dict):
        """Persist a schema dict to the local JSON cache."""
        import json
        os.makedirs(_CACHE_DIR, exist_ok=True)
        with open(self._schema_cache_path(table_name), "w", encoding="utf-8") as f:
            json.dump(schema, f, ensure_ascii=False, indent=2)

    def _ensure_schema_cache(self, table_name: str, force: bool = False,
                             dataworks_tasks: list | None = None):
        """Ensure the schema cache exists; fetch and write it when missing.

        Args:
            force: when True, refresh even if a cached schema exists
            dataworks_tasks: pre-fetched upstream task list, to avoid a
                redundant get_table_info() API call
        """
        if not force:
            cached = self._read_schema_cache(table_name)
            if cached is not None:
                return
        try:
            schema = self.get_table_schema(table_name, dataworks_tasks=dataworks_tasks)
            self._write_schema_cache(table_name, schema)
            print(f"[saved] {self._schema_cache_path(table_name)}")
        except Exception as e:
            # Schema caching is best-effort: warn but don't fail the caller.
            print(f"[WARN] 获取表结构失败 {table_name}: {e}")

    def get_table_schema(self, table_name: str,
                         dataworks_tasks: list | None = None) -> dict:
        """Fetch table schema metadata via the ODPS SDK.

        Args:
            table_name: table name ("project.table" format supported)
            dataworks_tasks: pre-fetched upstream task list, to avoid a
                redundant API call

        Returns:
            dict: {name, project, comment, columns, partition_keys,
                   dataworks_tasks, upstream_tables}
        """
        table_name = self._normalize_table_name(table_name)
        parts = table_name.split(".", 1)
        project, table = parts[0], parts[1]
        # Use an ODPSClient bound to the table's own project for schema access.
        odps_client = ODPSClient(project=project)
        t = odps_client.odps.get_table(table)
        columns = [
            {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
            for c in t.table_schema.columns
        ]
        partition_keys = [
            {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
            for c in t.table_schema.partitions
        ]
        # Upstream tasks: prefer the caller-supplied list, else ask DataWorks.
        if dataworks_tasks is None:
            try:
                info = self.get_table_info(table_name)
                dataworks_tasks = [
                    {"id": task.get("Id"), "name": task.get("Name")}
                    for task in info.get("dataworks_tasks", [])
                ]
            except Exception:
                # Best-effort: the schema is still useful without task links.
                dataworks_tasks = []
        # Direct upstream tables (lineage), also best-effort.
        try:
            upstream_tables = self.get_upstream_tables(table_name)
        except Exception:
            upstream_tables = []
        return {
            "name": table,
            "project": project,
            "comment": t.comment or "",
            "columns": columns,
            "partition_keys": partition_keys,
            "dataworks_tasks": dataworks_tasks,
            "upstream_tables": upstream_tables,
        }

    def get_node_code(self, table_name: str, force: bool = False) -> list:
        """Fetch a table's production code (local cache preferred).

        Flow: local cache -> GetTable -> GetTask -> write cache -> return code

        Args:
            table_name: table name ("project.table" format supported)
            force: when True, skip the cache and always hit the API

        Returns:
            list of dict, each containing:
                task_id, task_name, task_type, content
        """
        table_name = self._normalize_table_name(table_name)
        # Cache hit: return the cached SQL as a single pseudo-task entry.
        if not force:
            cached = self._read_cache(table_name)
            if cached is not None:
                print(f"[cache] {self._cache_path(table_name)}")
                # Also backfill the schema cache when it is missing.
                self._ensure_schema_cache(table_name, force=False)
                return [{"task_id": None, "task_name": "(cached)", "task_type": None, "content": cached}]
        # Cache miss: pull from the API.
        info = self.get_table_info(table_name)
        upstream = info.get("dataworks_tasks", [])
        if not upstream:
            print(f"表 '{table_name}' 没有上游任务")
            return []
        results = []
        for task in upstream:
            task_id = task.get("Id")
            task_name = task.get("Name")
            result = self._get_task_code(task_id)
            if result:
                results.append(result)
            else:
                print(f"[WARN] 任务 {task_name}({task_id}) 所有 AK 均无权限")
        # Persist all fetched task code into one .sql cache file.
        if results:
            parts = []
            for r in results:
                header = f"-- Task: {r['task_name']} ID: {r['task_id']} Type: {r['task_type']}"
                parts.append(f"{header}\n{r['content']}")
            self._write_cache(table_name, "\n\n".join(parts))
            print(f"[saved] {self._cache_path(table_name)}")
        # Cache the schema too, reusing the upstream-task info already fetched
        # to avoid a second GetTable round-trip.
        up_tasks = [
            {"id": task.get("Id"), "name": task.get("Name")}
            for task in upstream
        ]
        self._ensure_schema_cache(table_name, force=force, dataworks_tasks=up_tasks)
        return results

    def get_upstream_tables(self, table_name: str) -> list[str]:
        """List a table's direct upstream tables via the lineage API.

        NOTE(review): only the first page (page_size=50) of lineages is read;
        confirm no table has more than 50 direct upstream entities.

        Returns:
            list of str, e.g. ["loghubods.user_share_log_flow", ...]
        """
        entity_id = self._build_entity_id(table_name)
        resp = _call_with_retry(lambda: self.client.list_lineages(
            _dw_models.ListLineagesRequest(dst_entity_id=entity_id, page_size=50)
        ))
        lineages = resp.body.to_map().get("PagingInfo", {}).get("Lineages", [])
        tables = []
        for l in lineages:
            src_id = l.get("SrcEntity", {}).get("Id", "")
            # maxcompute-table:::project::table -> project.table
            parts = src_id.replace("maxcompute-table:::", "").split("::")
            if len(parts) == 2:
                tables.append(f"{parts[0]}.{parts[1]}")
        return sorted(set(tables))

    def get_node_code_recursive(self, table_name: str, max_depth: int = 3,
                                force: bool = False) -> dict:
        """BFS over the lineage graph, fetching production code level by level.

        Uses the lineage API (ListLineages) to walk upstream dependencies;
        each level's code and upstream tables are cached under production_code/.

        Args:
            table_name: table name ("project.table" format supported)
            max_depth: maximum number of levels to trace upstream (default 3)
            force: when True, skip caches

        Returns:
            dict: {
                "project.table": {
                    "code": [...],             # get_node_code() return value
                    "upstream": ["a.b", ...],  # upstream table names
                    "depth": int
                }, ...
            }
        """
        from collections import deque
        table_name = self._normalize_table_name(table_name)
        result = {}
        queue = deque([(table_name, 0)])
        visited = {table_name}
        while queue:
            tbl, depth = queue.popleft()
            indent = " " * depth
            print(f"{indent}[depth={depth}] {tbl}")
            # Fetch this table's production code.
            code = self.get_node_code(tbl, force=force)
            # Fetch upstream tables unless the depth cap was reached.
            upstream = []
            if depth < max_depth:
                try:
                    upstream = self.get_upstream_tables(tbl)
                except Exception:
                    # Lineage lookups are best-effort; a failure just stops
                    # the walk at this node.
                    pass
            result[tbl] = {"code": code, "upstream": upstream, "depth": depth}
            # Enqueue the next level, skipping already-visited tables.
            for up_tbl in upstream:
                if up_tbl not in visited:
                    visited.add(up_tbl)
                    queue.append((up_tbl, depth + 1))
        # Print a summary of the traversal.
        print(f"\n共追溯 {len(result)} 张表:")
        for tbl, info in result.items():
            has_code = "有代码" if info["code"] else "无代码"
            n_up = len(info["upstream"])
            print(f" {' ' * info['depth']}{tbl} ({has_code}, {n_up} 个上游)")
        return result

    def print_node_code(self, table_name: str):
        """Print a table's production code in a human-readable format."""
        results = self.get_node_code(table_name)
        if not results:
            print(f"未找到 '{table_name}' 的生产代码")
            return
        for r in results:
            print(f"\n{'='*60}")
            print(f"任务: {r['task_name']} ID: {r['task_id']} "
                  f"类型: {r['task_type']}")
            print(f"{'='*60}")
            print(r["content"] or "(无内容)")
|