odps_module.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. import os
  4. import time
  5. import uuid
  6. import threading
  7. from concurrent.futures import ThreadPoolExecutor, as_completed
  8. from odps import ODPS, options
  9. from odps.tunnel import TableTunnel
  10. from tqdm import tqdm
  11. import pyarrow as pa
  12. from pyarrow import csv as pa_csv
  13. # DataWorks SDK(可选依赖,仅 DataWorksClient 用到)
  14. try:
  15. from alibabacloud_dataworks_public20240518.client import Client as _DWClient
  16. from alibabacloud_tea_openapi import models as _open_api_models
  17. from alibabacloud_dataworks_public20240518 import models as _dw_models
  18. _DW_AVAILABLE = True
  19. except ImportError:
  20. _DW_AVAILABLE = False
  21. # 开启 Instance Tunnel,解除 1 万条限制
  22. options.tunnel.use_instance_tunnel = True
  23. options.tunnel.limit_instance_tunnel = False
  24. # ODPS 配置
  25. ODPS_CONFIGS = {
  26. "default": {
  27. "access_id": "LTAIWYUujJAm7CbH",
  28. "access_secret": "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P",
  29. "project": "loghubods",
  30. },
  31. "piaoquan_api": {
  32. "access_id": "LTAI5tKyXxh7C6349c1wbwUX",
  33. "access_secret": "H8doQDC20KugToRA3giERgRyRD1KR9",
  34. "project": "piaoquan_api",
  35. },
  36. }
  37. class ODPSClient(object):
  38. def __init__(self, project="loghubods", config="default"):
  39. """
  40. 初始化 ODPS 客户端
  41. Args:
  42. project: 项目名(可覆盖配置中的默认项目)
  43. config: 配置名,可选 "default" 或 "piaoquan_api"
  44. """
  45. cfg = ODPS_CONFIGS.get(config, ODPS_CONFIGS["default"])
  46. self.accessId = cfg["access_id"]
  47. self.accessSecret = cfg["access_secret"]
  48. self.endpoint = "http://service.odps.aliyun.com/api"
  49. self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
  50. # 如果指定了 project 且不是默认值,使用指定的;否则用配置中的
  51. actual_project = project if project != "loghubods" else cfg["project"]
  52. self.odps = ODPS(
  53. self.accessId,
  54. self.accessSecret,
  55. actual_project,
  56. self.endpoint
  57. )
  58. def execute_sql(self, sql: str, print_logview: bool = True):
  59. """执行 SQL 并返回 DataFrame(异步提交,LogView 秒出)"""
  60. hints = {'odps.sql.submit.mode': 'script'}
  61. instance = self.odps.run_sql(sql, hints=hints)
  62. if print_logview:
  63. print(f"LogView: {instance.get_logview_address()}")
  64. instance.wait_for_success()
  65. with instance.open_reader(tunnel=True, limit=False) as reader:
  66. pd_df = reader.to_pandas()
  67. return pd_df
  68. def execute_sql_result_save_file(self, sql: str, output_file: str):
  69. """执行 SQL 并保存到文件(Arrow 直接写 CSV,速度最快;异步提交,LogView 秒出)"""
  70. hints = {'odps.sql.submit.mode': 'script'}
  71. start_time = time.time()
  72. instance = self.odps.run_sql(sql, hints=hints)
  73. print(f"LogView: {instance.get_logview_address()}")
  74. instance.wait_for_success()
  75. sql_time = time.time() - start_time
  76. print(f"SQL 执行耗时: {sql_time:.1f}s")
  77. with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
  78. total = reader.count
  79. # 边下载边写入,用 pyarrow 直接写 CSV
  80. with open(output_file, 'wb') as f:
  81. first = True
  82. with tqdm(total=total, unit='行', desc='下载中') as pbar:
  83. for batch in reader:
  84. # pyarrow 写 CSV(比 pandas 快很多)
  85. options = pa_csv.WriteOptions(include_header=first)
  86. pa_csv.write_csv(pa.Table.from_batches([batch]), f, write_options=options)
  87. first = False
  88. pbar.update(batch.num_rows)
  89. total_time = time.time() - start_time
  90. print(f"总耗时: {total_time:.1f}s")
  91. print(f"完成: {output_file}")
  92. def execute_sql_result_save_file_parallel(self, sql: str, output_file: str, workers: int = 4):
  93. """执行 SQL 并保存到文件(多线程并行下载,速度最快)"""
  94. hints = {'odps.sql.submit.mode': 'script'}
  95. # 生成临时表名
  96. tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
  97. create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
  98. start_time = time.time()
  99. # 1. 创建临时表(异步提交,LogView 秒出)
  100. print(f"创建临时表: {tmp_table}")
  101. instance = self.odps.run_sql(create_sql, hints=hints)
  102. print(f"LogView: {instance.get_logview_address()}")
  103. instance.wait_for_success()
  104. sql_time = time.time() - start_time
  105. print(f"SQL 执行耗时: {sql_time:.1f}s")
  106. try:
  107. # 2. 获取表信息
  108. table = self.odps.get_table(tmp_table)
  109. tunnel = TableTunnel(self.odps)
  110. download_session = tunnel.create_download_session(table.name)
  111. total = download_session.count
  112. print(f"总行数: {total}")
  113. if total == 0:
  114. # 空表,直接写入空 CSV
  115. with open(output_file, 'w') as f:
  116. columns = [col.name for col in table.table_schema.columns]
  117. f.write(','.join(columns) + '\n')
  118. print(f"完成: {output_file} (空表)")
  119. return
  120. # 3. 分段
  121. chunk_size = (total + workers - 1) // workers
  122. chunks = []
  123. for i in range(workers):
  124. start = i * chunk_size
  125. end = min((i + 1) * chunk_size, total)
  126. if start < end:
  127. chunks.append((i, start, end - start)) # (index, start, count)
  128. print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")
  129. # 4. 多线程下载到临时文件(放在输出目录)
  130. output_dir = os.path.dirname(output_file)
  131. tmp_prefix = os.path.join(output_dir, f".tmp_{os.path.basename(output_file)}_")
  132. pbar = tqdm(total=total, unit='行', desc='下载中')
  133. pbar_lock = threading.Lock()
  134. session_id = download_session.id
  135. tmp_files = {}
  136. def download_chunk(chunk_info):
  137. idx, start, count = chunk_info
  138. tmp_file = f"{tmp_prefix}{idx:04d}"
  139. session = tunnel.create_download_session(table.name, download_id=session_id)
  140. with session.open_arrow_reader(start, count) as reader:
  141. batches = []
  142. for batch in reader:
  143. batches.append(batch)
  144. with pbar_lock:
  145. pbar.update(batch.num_rows)
  146. if batches:
  147. tbl = pa.Table.from_batches(batches)
  148. pa_csv.write_csv(tbl, tmp_file)
  149. return idx, tmp_file
  150. # 并行下载
  151. with ThreadPoolExecutor(max_workers=workers) as executor:
  152. futures = [executor.submit(download_chunk, chunk) for chunk in chunks]
  153. for future in as_completed(futures):
  154. idx, tmp_file = future.result()
  155. tmp_files[idx] = tmp_file
  156. pbar.close()
  157. # 按顺序合并
  158. print("合并文件中...")
  159. with open(output_file, 'wb') as outf:
  160. for idx in range(len(chunks)):
  161. tmp_file = tmp_files.get(idx)
  162. if tmp_file and os.path.exists(tmp_file):
  163. with open(tmp_file, 'rb') as inf:
  164. if idx > 0:
  165. inf.readline() # 跳过表头
  166. outf.write(inf.read())
  167. os.remove(tmp_file)
  168. finally:
  169. # 6. 删除临时表
  170. print(f"删除临时表: {tmp_table}")
  171. self.odps.delete_table(tmp_table, if_exists=True)
  172. total_time = time.time() - start_time
  173. print(f"总耗时: {total_time:.1f}s")
  174. print(f"完成: {output_file}")
  175. # ──────────────────────────────────────────────────────────────────────────────
  176. # DataWorks 客户端:根据表名获取生产代码
  177. # ──────────────────────────────────────────────────────────────────────────────
  178. # 最佳实践链路:GetTable → GetTask → Script.Content
  179. # 1. GetTable(entity_id, include_business_metadata=True) 精确获取表的上游任务
  180. # 2. GetTask(task_id, project_env='Prod') 获取任务的 SQL 代码
  181. # ──────────────────────────────────────────────────────────────────────────────
  182. # 账号下所有可访问的 DataWorks 项目(project_id → name)
  183. _DW_PROJECTS = {
  184. 4858: "loghubods",
  185. 11300: "DWH",
  186. 5477: "videocdm",
  187. 548768: "piaoquan_api",
  188. 148813: "content_safety",
  189. 96094: "algo",
  190. 52578: "majin",
  191. 5057: "useractionbi",
  192. 5034: "user_video_action_cdm",
  193. 4868: "usercdm",
  194. 4859: "videoods",
  195. 6025: "videoads",
  196. 5535: "user_video_tag",
  197. 19288: "RecallEmbedding",
  198. 10762: "Test_model1",
  199. 193831: "cost_mgt_1894469520484605",
  200. 156474: "dyp_1",
  201. 156475: "dyp_2",
  202. 343868: "pq_data_space",
  203. 343957: "pq_grafana_se",
  204. }
  205. _CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "production_code")
  206. def _call_with_retry(fn, max_retries=3, base_delay=2):
  207. """带限流重试的 API 调用包装。"""
  208. for attempt in range(max_retries):
  209. try:
  210. return fn()
  211. except Exception as e:
  212. if "Throttling" in str(e) and attempt < max_retries - 1:
  213. delay = base_delay * (2 ** attempt)
  214. print(f" [throttled] 等待 {delay}s 后重试...")
  215. time.sleep(delay)
  216. continue
  217. raise
  218. class DataWorksClient:
  219. def __init__(self):
  220. if not _DW_AVAILABLE:
  221. raise ImportError(
  222. "请先安装 DataWorks SDK:\n"
  223. "pip install alibabacloud-dataworks-public20240518"
  224. )
  225. # 初始化所有 AK 对应的客户端(不同 AK 对不同项目有权限)
  226. self._clients = {}
  227. for config_name, cfg in ODPS_CONFIGS.items():
  228. dw_config = _open_api_models.Config(
  229. access_key_id=cfg["access_id"],
  230. access_key_secret=cfg["access_secret"],
  231. endpoint="dataworks.cn-hangzhou.aliyuncs.com",
  232. )
  233. self._clients[config_name] = _DWClient(dw_config)
  234. self.client = self._clients["default"]
  235. @staticmethod
  236. def _build_entity_id(table_name: str) -> str:
  237. """构造 GetTable 的 entity ID。
  238. 支持格式:
  239. - project.table → maxcompute-table:::project::table
  240. - table → maxcompute-table:::loghubods::table
  241. """
  242. parts = table_name.split(".", 1)
  243. if len(parts) == 2:
  244. project, table = parts
  245. else:
  246. project, table = "loghubods", parts[0]
  247. return f"maxcompute-table:::{project}::{table}"
  248. def get_table_info(self, table_name: str) -> dict:
  249. """获取表的元信息(含上游任务列表)。
  250. Returns:
  251. dict with keys: name, comment, dataworks_tasks[{id, name}], ...
  252. """
  253. entity_id = self._build_entity_id(table_name)
  254. resp = _call_with_retry(lambda: self.client.get_table(
  255. _dw_models.GetTableRequest(id=entity_id, include_business_metadata=True)
  256. ))
  257. table = resp.body.to_map().get("Table", {})
  258. biz = table.get("BusinessMetadata", {})
  259. return {
  260. "id": table.get("Id"),
  261. "name": table.get("Name"),
  262. "comment": table.get("Comment"),
  263. "project_id": biz.get("Extension", {}).get("ProjectId"),
  264. "dataworks_tasks": biz.get("UpstreamTasks", []),
  265. "partition_keys": table.get("PartitionKeys", []),
  266. }
  267. def _get_task_code(self, task_id: int) -> dict:
  268. """尝试用所有 AK 获取任务代码,返回第一个成功的结果。"""
  269. for config_name, client in self._clients.items():
  270. try:
  271. resp = _call_with_retry(lambda c=client: c.get_task(
  272. _dw_models.GetTaskRequest(id=task_id, project_env="Prod")
  273. ))
  274. task = resp.body.to_map().get("Task", {})
  275. return {
  276. "task_id": task_id,
  277. "task_name": task.get("Name"),
  278. "task_type": task.get("Type"),
  279. "content": (task.get("Script") or {}).get("Content", ""),
  280. "config": config_name,
  281. }
  282. except Exception as e:
  283. if "11020205003" in str(e):
  284. continue # 无权限,尝试下一个 AK
  285. raise
  286. return None
  287. @staticmethod
  288. def _normalize_table_name(table_name: str) -> str:
  289. """补全 project 前缀:table → loghubods.table"""
  290. if "." not in table_name:
  291. return f"loghubods.{table_name}"
  292. return table_name
  293. @staticmethod
  294. def _cache_path(table_name: str) -> str:
  295. return os.path.join(_CACHE_DIR, f"{table_name}.sql")
  296. @staticmethod
  297. def _schema_cache_path(table_name: str) -> str:
  298. return os.path.join(_CACHE_DIR, f"{table_name}.json")
  299. def _read_cache(self, table_name: str) -> str | None:
  300. path = self._cache_path(table_name)
  301. if os.path.exists(path):
  302. with open(path, "r", encoding="utf-8") as f:
  303. return f.read()
  304. return None
  305. def _write_cache(self, table_name: str, content: str):
  306. os.makedirs(_CACHE_DIR, exist_ok=True)
  307. with open(self._cache_path(table_name), "w", encoding="utf-8") as f:
  308. f.write(content)
  309. def _read_schema_cache(self, table_name: str) -> dict | None:
  310. import json
  311. path = self._schema_cache_path(table_name)
  312. if os.path.exists(path):
  313. with open(path, "r", encoding="utf-8") as f:
  314. return json.load(f)
  315. return None
  316. def _write_schema_cache(self, table_name: str, schema: dict):
  317. import json
  318. os.makedirs(_CACHE_DIR, exist_ok=True)
  319. with open(self._schema_cache_path(table_name), "w", encoding="utf-8") as f:
  320. json.dump(schema, f, ensure_ascii=False, indent=2)
  321. def _ensure_schema_cache(self, table_name: str, force: bool = False,
  322. dataworks_tasks: list | None = None):
  323. """确保 schema 缓存存在,无则拉取并写入。
  324. Args:
  325. dataworks_tasks: 预获取的上游任务列表,避免重复调用 get_table_info()
  326. """
  327. if not force:
  328. cached = self._read_schema_cache(table_name)
  329. if cached is not None:
  330. return
  331. try:
  332. schema = self.get_table_schema(table_name, dataworks_tasks=dataworks_tasks)
  333. self._write_schema_cache(table_name, schema)
  334. print(f"[saved] {self._schema_cache_path(table_name)}")
  335. except Exception as e:
  336. print(f"[WARN] 获取表结构失败 {table_name}: {e}")
  337. def get_table_schema(self, table_name: str,
  338. dataworks_tasks: list | None = None) -> dict:
  339. """通过 ODPS SDK 获取表结构元信息。
  340. Args:
  341. table_name: 表名(支持 project.table 格式)
  342. dataworks_tasks: 预获取的上游任务列表,避免重复 API 调用
  343. Returns:
  344. dict: {name, project, comment, columns, partition_keys, dataworks_tasks}
  345. """
  346. table_name = self._normalize_table_name(table_name)
  347. parts = table_name.split(".", 1)
  348. project, table = parts[0], parts[1]
  349. # 用默认 AK 对应的 ODPSClient 获取 ODPS 表结构
  350. odps_client = ODPSClient(project=project)
  351. t = odps_client.odps.get_table(table)
  352. columns = [
  353. {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
  354. for c in t.table_schema.columns
  355. ]
  356. partition_keys = [
  357. {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
  358. for c in t.table_schema.partitions
  359. ]
  360. # 上游任务:优先用传入的,否则从 DataWorks API 获取
  361. if dataworks_tasks is None:
  362. try:
  363. info = self.get_table_info(table_name)
  364. dataworks_tasks = [
  365. {"id": task.get("Id"), "name": task.get("Name")}
  366. for task in info.get("dataworks_tasks", [])
  367. ]
  368. except Exception:
  369. dataworks_tasks = []
  370. # 直接上游表(血缘)
  371. try:
  372. upstream_tables = self.get_upstream_tables(table_name)
  373. except Exception:
  374. upstream_tables = []
  375. return {
  376. "name": table,
  377. "project": project,
  378. "comment": t.comment or "",
  379. "columns": columns,
  380. "partition_keys": partition_keys,
  381. "dataworks_tasks": dataworks_tasks,
  382. "upstream_tables": upstream_tables,
  383. }
  384. def get_node_code(self, table_name: str, force: bool = False) -> list:
  385. """根据表名获取生产代码(优先读本地缓存)。
  386. 流程:本地缓存 → GetTable → GetTask → 写缓存 → 返回代码
  387. Args:
  388. table_name: 表名(支持 project.table 格式)
  389. force: True 时跳过缓存,强制从 API 拉取
  390. Returns:
  391. list of dict,每条包含:
  392. task_id, task_name, task_type, content
  393. """
  394. table_name = self._normalize_table_name(table_name)
  395. # 读缓存
  396. if not force:
  397. cached = self._read_cache(table_name)
  398. if cached is not None:
  399. print(f"[cache] {self._cache_path(table_name)}")
  400. # 同时检查 schema 缓存,无则补拉
  401. self._ensure_schema_cache(table_name, force=False)
  402. return [{"task_id": None, "task_name": "(cached)", "task_type": None, "content": cached}]
  403. # API 拉取
  404. info = self.get_table_info(table_name)
  405. upstream = info.get("dataworks_tasks", [])
  406. if not upstream:
  407. print(f"表 '{table_name}' 没有上游任务")
  408. return []
  409. results = []
  410. for task in upstream:
  411. task_id = task.get("Id")
  412. task_name = task.get("Name")
  413. result = self._get_task_code(task_id)
  414. if result:
  415. results.append(result)
  416. else:
  417. print(f"[WARN] 任务 {task_name}({task_id}) 所有 AK 均无权限")
  418. # 写缓存
  419. if results:
  420. parts = []
  421. for r in results:
  422. header = f"-- Task: {r['task_name']} ID: {r['task_id']} Type: {r['task_type']}"
  423. parts.append(f"{header}\n{r['content']}")
  424. self._write_cache(table_name, "\n\n".join(parts))
  425. print(f"[saved] {self._cache_path(table_name)}")
  426. # 获取并缓存 schema(复用已有的上游任务信息,避免重复 API 调用)
  427. up_tasks = [
  428. {"id": task.get("Id"), "name": task.get("Name")}
  429. for task in upstream
  430. ]
  431. self._ensure_schema_cache(table_name, force=force, dataworks_tasks=up_tasks)
  432. return results
  433. def get_upstream_tables(self, table_name: str) -> list[str]:
  434. """通过血缘 API 获取表的直接上游表列表。
  435. Returns:
  436. list of str,如 ["loghubods.user_share_log_flow", ...]
  437. """
  438. entity_id = self._build_entity_id(table_name)
  439. resp = _call_with_retry(lambda: self.client.list_lineages(
  440. _dw_models.ListLineagesRequest(dst_entity_id=entity_id, page_size=50)
  441. ))
  442. lineages = resp.body.to_map().get("PagingInfo", {}).get("Lineages", [])
  443. tables = []
  444. for l in lineages:
  445. src_id = l.get("SrcEntity", {}).get("Id", "")
  446. # maxcompute-table:::project::table → project.table
  447. parts = src_id.replace("maxcompute-table:::", "").split("::")
  448. if len(parts) == 2:
  449. tables.append(f"{parts[0]}.{parts[1]}")
  450. return sorted(set(tables))
  451. def get_node_code_recursive(self, table_name: str, max_depth: int = 3,
  452. force: bool = False) -> dict:
  453. """BFS 逐层获取表及其所有上游表的生产代码。
  454. 通过血缘 API(ListLineages)逐层追溯上游依赖,
  455. 每层的代码和上游表都会被缓存到 production_code/。
  456. Args:
  457. table_name: 表名(支持 project.table 格式)
  458. max_depth: 最大追溯层数,默认 3
  459. force: True 时跳过缓存
  460. Returns:
  461. dict: {
  462. "project.table": {
  463. "code": [...], # get_node_code 返回值
  464. "upstream": ["a.b", ...], # 上游表名列表
  465. "depth": int
  466. }, ...
  467. }
  468. """
  469. from collections import deque
  470. table_name = self._normalize_table_name(table_name)
  471. result = {}
  472. queue = deque([(table_name, 0)])
  473. visited = {table_name}
  474. while queue:
  475. tbl, depth = queue.popleft()
  476. indent = " " * depth
  477. print(f"{indent}[depth={depth}] {tbl}")
  478. # 获取代码
  479. code = self.get_node_code(tbl, force=force)
  480. # 获取上游表
  481. upstream = []
  482. if depth < max_depth:
  483. try:
  484. upstream = self.get_upstream_tables(tbl)
  485. except Exception:
  486. pass
  487. result[tbl] = {"code": code, "upstream": upstream, "depth": depth}
  488. # 下一层入队
  489. for up_tbl in upstream:
  490. if up_tbl not in visited:
  491. visited.add(up_tbl)
  492. queue.append((up_tbl, depth + 1))
  493. # 打印汇总
  494. print(f"\n共追溯 {len(result)} 张表:")
  495. for tbl, info in result.items():
  496. has_code = "有代码" if info["code"] else "无代码"
  497. n_up = len(info["upstream"])
  498. print(f" {' ' * info['depth']}{tbl} ({has_code}, {n_up} 个上游)")
  499. return result
  500. def print_node_code(self, table_name: str):
  501. """打印表的生产代码(人类可读格式)"""
  502. results = self.get_node_code(table_name)
  503. if not results:
  504. print(f"未找到 '{table_name}' 的生产代码")
  505. return
  506. for r in results:
  507. print(f"\n{'='*60}")
  508. print(f"任务: {r['task_name']} ID: {r['task_id']} "
  509. f"类型: {r['task_type']}")
  510. print(f"{'='*60}")
  511. print(r["content"] or "(无内容)")