#!/usr/bin/env python
# coding=utf-8
import os
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

from odps import ODPS, options
from odps.tunnel import TableTunnel
from tqdm import tqdm
import pyarrow as pa
from pyarrow import csv as pa_csv

# DataWorks SDK (optional dependency; only DataWorksClient needs it)
try:
    from alibabacloud_dataworks_public20240518.client import Client as _DWClient
    from alibabacloud_tea_openapi import models as _open_api_models
    from alibabacloud_dataworks_public20240518 import models as _dw_models
    _DW_AVAILABLE = True
except ImportError:
    _DW_AVAILABLE = False

# Enable Instance Tunnel so SQL result downloads are not capped at 10k rows.
options.tunnel.use_instance_tunnel = True
options.tunnel.limit_instance_tunnel = False

# ODPS configurations.
# SECURITY NOTE(review): access keys are committed in plain text. They should be
# loaded from environment variables / a secrets manager, and the keys exposed
# here should be rotated.
ODPS_CONFIGS = {
    "default": {
        "access_id": "LTAIWYUujJAm7CbH",
        "access_secret": "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P",
        "project": "loghubods",
    },
    "piaoquan_api": {
        "access_id": "LTAI5tKyXxh7C6349c1wbwUX",
        "access_secret": "H8doQDC20KugToRA3giERgRyRD1KR9",
        "project": "piaoquan_api",
    },
}


class ODPSClient(object):
    def __init__(self, project="loghubods", config="default"):
        """Initialize an ODPS client.

        Args:
            project: project name (overrides the default project of the config)
            config: config name, one of "default" or "piaoquan_api"
        """
        cfg = ODPS_CONFIGS.get(config, ODPS_CONFIGS["default"])
        self.accessId = cfg["access_id"]
        self.accessSecret = cfg["access_secret"]
        self.endpoint = "http://service.odps.aliyun.com/api"
        self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
        # If a non-default project was passed explicitly, use it; otherwise
        # fall back to the project declared in the chosen config.
        actual_project = project if project != "loghubods" else cfg["project"]
        self.odps = ODPS(
            self.accessId,
            self.accessSecret,
            actual_project,
            self.endpoint,
        )

    def execute_sql(self, sql: str, print_logview: bool = True):
        """Execute SQL and return the full result as a pandas DataFrame.

        Args:
            sql: SQL statement (script mode is enabled via hints)
            print_logview: when True, print the instance's LogView URL

        Returns:
            pandas.DataFrame with the complete result set (instance tunnel,
            no row limit).
        """
        hints = {'odps.sql.submit.mode': 'script'}
        instance = self.odps.execute_sql(sql, hints=hints)
        if print_logview:
            print(f"LogView: {instance.get_logview_address()}")
        with instance.open_reader(tunnel=True, limit=False) as reader:
            pd_df = reader.to_pandas()
        return pd_df

    def execute_sql_result_save_file(self, sql: str, output_file: str):
        """Execute SQL and stream the result to a CSV file via Arrow.

        Batches are written as they arrive, so the whole result never has to
        fit in memory; pyarrow's CSV writer is used for speed.
        """
        hints = {'odps.sql.submit.mode': 'script'}
        start_time = time.time()
        instance = self.odps.execute_sql(sql, hints=hints)
        sql_time = time.time() - start_time
        print(f"LogView: {instance.get_logview_address()}")
        print(f"SQL 执行耗时: {sql_time:.1f}s")
        with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
            total = reader.count
            # Download and write incrementally; pyarrow writes CSV much faster
            # than pandas.
            with open(output_file, 'wb') as f:
                first = True
                with tqdm(total=total, unit='行', desc='下载中') as pbar:
                    for batch in reader:
                        # FIX: renamed from `options`, which shadowed the
                        # module-level `odps.options` import.
                        write_opts = pa_csv.WriteOptions(include_header=first)
                        pa_csv.write_csv(pa.Table.from_batches([batch]), f,
                                         write_options=write_opts)
                        first = False
                        pbar.update(batch.num_rows)
        total_time = time.time() - start_time
        print(f"总耗时: {total_time:.1f}s")
        print(f"完成: {output_file}")

    def execute_sql_result_save_file_parallel(self, sql: str, output_file: str,
                                              workers: int = 4):
        """Execute SQL and save the result to CSV using parallel downloads.

        Strategy: materialize the result into a short-lived temp table
        (LIFECYCLE 1), split the row range into `workers` shards, download the
        shards concurrently through the table tunnel into per-shard temp CSV
        files, then concatenate them in order. The temp table is always
        dropped, even on failure.
        """
        hints = {'odps.sql.submit.mode': 'script'}
        # Unique temp table name; LIFECYCLE 1 makes ODPS reap it after a day
        # even if cleanup below were to fail.
        tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
        create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
        start_time = time.time()

        # 1. Create the temp table.
        print(f"创建临时表: {tmp_table}")
        instance = self.odps.execute_sql(create_sql, hints=hints)
        print(f"LogView: {instance.get_logview_address()}")
        instance.wait_for_success()
        sql_time = time.time() - start_time
        print(f"SQL 执行耗时: {sql_time:.1f}s")

        try:
            # 2. Open a download session to learn the row count.
            table = self.odps.get_table(tmp_table)
            tunnel = TableTunnel(self.odps)
            download_session = tunnel.create_download_session(table.name)
            total = download_session.count
            print(f"总行数: {total}")

            if total == 0:
                # Empty result: write a header-only CSV and stop.
                with open(output_file, 'w') as f:
                    columns = [col.name for col in table.table_schema.columns]
                    f.write(','.join(columns) + '\n')
                print(f"完成: {output_file} (空表)")
                return

            # 3. Split the row range into shards.
            chunk_size = (total + workers - 1) // workers
            chunks = []
            for i in range(workers):
                start = i * chunk_size
                end = min((i + 1) * chunk_size, total)
                if start < end:
                    chunks.append((i, start, end - start))  # (index, start, count)
            print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")

            # 4. Download shards in parallel into temp files next to the output.
            output_dir = os.path.dirname(output_file)
            tmp_prefix = os.path.join(output_dir,
                                      f".tmp_{os.path.basename(output_file)}_")
            pbar = tqdm(total=total, unit='行', desc='下载中')
            pbar_lock = threading.Lock()
            session_id = download_session.id
            tmp_files = {}

            def download_chunk(chunk_info):
                """Download one [start, start+count) slice into its own CSV."""
                idx, start, count = chunk_info
                tmp_file = f"{tmp_prefix}{idx:04d}"
                # Re-attach to the shared session so shard offsets stay
                # consistent across threads.
                session = tunnel.create_download_session(table.name,
                                                         download_id=session_id)
                with session.open_arrow_reader(start, count) as reader:
                    batches = []
                    for batch in reader:
                        batches.append(batch)
                        with pbar_lock:
                            pbar.update(batch.num_rows)
                if batches:
                    pa_csv.write_csv(pa.Table.from_batches(batches), tmp_file)
                return idx, tmp_file

            try:
                with ThreadPoolExecutor(max_workers=workers) as executor:
                    futures = [executor.submit(download_chunk, chunk)
                               for chunk in chunks]
                    for future in as_completed(futures):
                        idx, tmp_file = future.result()
                        tmp_files[idx] = tmp_file
            except BaseException:
                # FIX: best-effort removal of shard temp files when any
                # download fails, so aborted runs don't litter the directory.
                for leftover in tmp_files.values():
                    try:
                        os.remove(leftover)
                    except OSError:
                        pass
                raise
            finally:
                # FIX: always close the progress bar (the original leaked it
                # on failure).
                pbar.close()

            # 5. Concatenate shards in index order, keeping only one header.
            print("合并文件中...")
            with open(output_file, 'wb') as outf:
                for idx in range(len(chunks)):
                    tmp_file = tmp_files.get(idx)
                    if tmp_file and os.path.exists(tmp_file):
                        with open(tmp_file, 'rb') as inf:
                            if idx > 0:
                                inf.readline()  # skip the shard's CSV header
                            outf.write(inf.read())
                        os.remove(tmp_file)
        finally:
            # 6. Drop the temp table no matter what happened above.
            print(f"删除临时表: {tmp_table}")
            self.odps.delete_table(tmp_table, if_exists=True)

        total_time = time.time() - start_time
        print(f"总耗时: {total_time:.1f}s")
        print(f"完成: {output_file}")


# ──────────────────────────────────────────────────────────────────────────────
# DataWorks client: fetch production code by table name
# ──────────────────────────────────────────────────────────────────────────────
# Best-practice chain: GetTable → GetTask → Script.Content
GetTable(entity_id, include_business_metadata=True) 精确获取表的上游任务 # 2. GetTask(task_id, project_env='Prod') 获取任务的 SQL 代码 # ────────────────────────────────────────────────────────────────────────────── # 账号下所有可访问的 DataWorks 项目(project_id → name) _DW_PROJECTS = { 4858: "loghubods", 11300: "DWH", 5477: "videocdm", 548768: "piaoquan_api", 148813: "content_safety", 96094: "algo", 52578: "majin", 5057: "useractionbi", 5034: "user_video_action_cdm", 4868: "usercdm", 4859: "videoods", 6025: "videoads", 5535: "user_video_tag", 19288: "RecallEmbedding", 10762: "Test_model1", 193831: "cost_mgt_1894469520484605", 156474: "dyp_1", 156475: "dyp_2", 343868: "pq_data_space", 343957: "pq_grafana_se", } _CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "production_code") def _call_with_retry(fn, max_retries=3, base_delay=2): """带限流重试的 API 调用包装。""" for attempt in range(max_retries): try: return fn() except Exception as e: if "Throttling" in str(e) and attempt < max_retries - 1: delay = base_delay * (2 ** attempt) print(f" [throttled] 等待 {delay}s 后重试...") time.sleep(delay) continue raise class DataWorksClient: def __init__(self): if not _DW_AVAILABLE: raise ImportError( "请先安装 DataWorks SDK:\n" "pip install alibabacloud-dataworks-public20240518" ) # 初始化所有 AK 对应的客户端(不同 AK 对不同项目有权限) self._clients = {} for config_name, cfg in ODPS_CONFIGS.items(): dw_config = _open_api_models.Config( access_key_id=cfg["access_id"], access_key_secret=cfg["access_secret"], endpoint="dataworks.cn-hangzhou.aliyuncs.com", ) self._clients[config_name] = _DWClient(dw_config) self.client = self._clients["default"] @staticmethod def _build_entity_id(table_name: str) -> str: """构造 GetTable 的 entity ID。 支持格式: - project.table → maxcompute-table:::project::table - table → maxcompute-table:::loghubods::table """ parts = table_name.split(".", 1) if len(parts) == 2: project, table = parts else: project, table = "loghubods", parts[0] return f"maxcompute-table:::{project}::{table}" def 
get_table_info(self, table_name: str) -> dict: """获取表的元信息(含上游任务列表)。 Returns: dict with keys: name, comment, dataworks_tasks[{id, name}], ... """ entity_id = self._build_entity_id(table_name) resp = _call_with_retry(lambda: self.client.get_table( _dw_models.GetTableRequest(id=entity_id, include_business_metadata=True) )) table = resp.body.to_map().get("Table", {}) biz = table.get("BusinessMetadata", {}) return { "id": table.get("Id"), "name": table.get("Name"), "comment": table.get("Comment"), "project_id": biz.get("Extension", {}).get("ProjectId"), "dataworks_tasks": biz.get("UpstreamTasks", []), "partition_keys": table.get("PartitionKeys", []), } def _get_task_code(self, task_id: int) -> dict: """尝试用所有 AK 获取任务代码,返回第一个成功的结果。""" for config_name, client in self._clients.items(): try: resp = _call_with_retry(lambda c=client: c.get_task( _dw_models.GetTaskRequest(id=task_id, project_env="Prod") )) task = resp.body.to_map().get("Task", {}) return { "task_id": task_id, "task_name": task.get("Name"), "task_type": task.get("Type"), "content": (task.get("Script") or {}).get("Content", ""), "config": config_name, } except Exception as e: if "11020205003" in str(e): continue # 无权限,尝试下一个 AK raise return None @staticmethod def _normalize_table_name(table_name: str) -> str: """补全 project 前缀:table → loghubods.table""" if "." 
not in table_name: return f"loghubods.{table_name}" return table_name @staticmethod def _cache_path(table_name: str) -> str: return os.path.join(_CACHE_DIR, f"{table_name}.sql") @staticmethod def _schema_cache_path(table_name: str) -> str: return os.path.join(_CACHE_DIR, f"{table_name}.json") def _read_cache(self, table_name: str) -> str | None: path = self._cache_path(table_name) if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: return f.read() return None def _write_cache(self, table_name: str, content: str): os.makedirs(_CACHE_DIR, exist_ok=True) with open(self._cache_path(table_name), "w", encoding="utf-8") as f: f.write(content) def _read_schema_cache(self, table_name: str) -> dict | None: import json path = self._schema_cache_path(table_name) if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: return json.load(f) return None def _write_schema_cache(self, table_name: str, schema: dict): import json os.makedirs(_CACHE_DIR, exist_ok=True) with open(self._schema_cache_path(table_name), "w", encoding="utf-8") as f: json.dump(schema, f, ensure_ascii=False, indent=2) def _ensure_schema_cache(self, table_name: str, force: bool = False, dataworks_tasks: list | None = None): """确保 schema 缓存存在,无则拉取并写入。 Args: dataworks_tasks: 预获取的上游任务列表,避免重复调用 get_table_info() """ if not force: cached = self._read_schema_cache(table_name) if cached is not None: return try: schema = self.get_table_schema(table_name, dataworks_tasks=dataworks_tasks) self._write_schema_cache(table_name, schema) print(f"[saved] {self._schema_cache_path(table_name)}") except Exception as e: print(f"[WARN] 获取表结构失败 {table_name}: {e}") def get_table_schema(self, table_name: str, dataworks_tasks: list | None = None) -> dict: """通过 ODPS SDK 获取表结构元信息。 Args: table_name: 表名(支持 project.table 格式) dataworks_tasks: 预获取的上游任务列表,避免重复 API 调用 Returns: dict: {name, project, comment, columns, partition_keys, dataworks_tasks} """ table_name = self._normalize_table_name(table_name) parts = 
table_name.split(".", 1) project, table = parts[0], parts[1] # 用默认 AK 对应的 ODPSClient 获取 ODPS 表结构 odps_client = ODPSClient(project=project) t = odps_client.odps.get_table(table) columns = [ {"name": c.name, "type": str(c.type), "comment": c.comment or ""} for c in t.table_schema.columns ] partition_keys = [ {"name": c.name, "type": str(c.type), "comment": c.comment or ""} for c in t.table_schema.partitions ] # 上游任务:优先用传入的,否则从 DataWorks API 获取 if dataworks_tasks is None: try: info = self.get_table_info(table_name) dataworks_tasks = [ {"id": task.get("Id"), "name": task.get("Name")} for task in info.get("dataworks_tasks", []) ] except Exception: dataworks_tasks = [] # 直接上游表(血缘) try: upstream_tables = self.get_upstream_tables(table_name) except Exception: upstream_tables = [] return { "name": table, "project": project, "comment": t.comment or "", "columns": columns, "partition_keys": partition_keys, "dataworks_tasks": dataworks_tasks, "upstream_tables": upstream_tables, } def get_node_code(self, table_name: str, force: bool = False) -> list: """根据表名获取生产代码(优先读本地缓存)。 流程:本地缓存 → GetTable → GetTask → 写缓存 → 返回代码 Args: table_name: 表名(支持 project.table 格式) force: True 时跳过缓存,强制从 API 拉取 Returns: list of dict,每条包含: task_id, task_name, task_type, content """ table_name = self._normalize_table_name(table_name) # 读缓存 if not force: cached = self._read_cache(table_name) if cached is not None: print(f"[cache] {self._cache_path(table_name)}") # 同时检查 schema 缓存,无则补拉 self._ensure_schema_cache(table_name, force=False) return [{"task_id": None, "task_name": "(cached)", "task_type": None, "content": cached}] # API 拉取 info = self.get_table_info(table_name) upstream = info.get("dataworks_tasks", []) if not upstream: print(f"表 '{table_name}' 没有上游任务") return [] results = [] for task in upstream: task_id = task.get("Id") task_name = task.get("Name") result = self._get_task_code(task_id) if result: results.append(result) else: print(f"[WARN] 任务 {task_name}({task_id}) 所有 AK 均无权限") # 写缓存 if results: 
parts = [] for r in results: header = f"-- Task: {r['task_name']} ID: {r['task_id']} Type: {r['task_type']}" parts.append(f"{header}\n{r['content']}") self._write_cache(table_name, "\n\n".join(parts)) print(f"[saved] {self._cache_path(table_name)}") # 获取并缓存 schema(复用已有的上游任务信息,避免重复 API 调用) up_tasks = [ {"id": task.get("Id"), "name": task.get("Name")} for task in upstream ] self._ensure_schema_cache(table_name, force=force, dataworks_tasks=up_tasks) return results def get_upstream_tables(self, table_name: str) -> list[str]: """通过血缘 API 获取表的直接上游表列表。 Returns: list of str,如 ["loghubods.user_share_log_flow", ...] """ entity_id = self._build_entity_id(table_name) resp = _call_with_retry(lambda: self.client.list_lineages( _dw_models.ListLineagesRequest(dst_entity_id=entity_id, page_size=50) )) lineages = resp.body.to_map().get("PagingInfo", {}).get("Lineages", []) tables = [] for l in lineages: src_id = l.get("SrcEntity", {}).get("Id", "") # maxcompute-table:::project::table → project.table parts = src_id.replace("maxcompute-table:::", "").split("::") if len(parts) == 2: tables.append(f"{parts[0]}.{parts[1]}") return sorted(set(tables)) def get_node_code_recursive(self, table_name: str, max_depth: int = 3, force: bool = False) -> dict: """BFS 逐层获取表及其所有上游表的生产代码。 通过血缘 API(ListLineages)逐层追溯上游依赖, 每层的代码和上游表都会被缓存到 production_code/。 Args: table_name: 表名(支持 project.table 格式) max_depth: 最大追溯层数,默认 3 force: True 时跳过缓存 Returns: dict: { "project.table": { "code": [...], # get_node_code 返回值 "upstream": ["a.b", ...], # 上游表名列表 "depth": int }, ... 
} """ from collections import deque table_name = self._normalize_table_name(table_name) result = {} queue = deque([(table_name, 0)]) visited = {table_name} while queue: tbl, depth = queue.popleft() indent = " " * depth print(f"{indent}[depth={depth}] {tbl}") # 获取代码 code = self.get_node_code(tbl, force=force) # 获取上游表 upstream = [] if depth < max_depth: try: upstream = self.get_upstream_tables(tbl) except Exception: pass result[tbl] = {"code": code, "upstream": upstream, "depth": depth} # 下一层入队 for up_tbl in upstream: if up_tbl not in visited: visited.add(up_tbl) queue.append((up_tbl, depth + 1)) # 打印汇总 print(f"\n共追溯 {len(result)} 张表:") for tbl, info in result.items(): has_code = "有代码" if info["code"] else "无代码" n_up = len(info["upstream"]) print(f" {' ' * info['depth']}{tbl} ({has_code}, {n_up} 个上游)") return result def print_node_code(self, table_name: str): """打印表的生产代码(人类可读格式)""" results = self.get_node_code(table_name) if not results: print(f"未找到 '{table_name}' 的生产代码") return for r in results: print(f"\n{'='*60}") print(f"任务: {r['task_name']} ID: {r['task_id']} " f"类型: {r['task_type']}") print(f"{'='*60}") print(r["content"] or "(无内容)")