#!/usr/bin/env python
# coding=utf-8
import os
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from odps import ODPS, options
from odps.tunnel import TableTunnel
from tqdm import tqdm
import pyarrow as pa
from pyarrow import csv as pa_csv

# DataWorks SDK (optional dependency; only DataWorksClient needs it)
try:
    from alibabacloud_dataworks_public20240518.client import Client as _DWClient
    from alibabacloud_tea_openapi import models as _open_api_models
    from alibabacloud_dataworks_public20240518 import models as _dw_models
    _DW_AVAILABLE = True
except ImportError:
    # SDK missing: DataWorksClient.__init__ raises a helpful ImportError.
    _DW_AVAILABLE = False

# Enable the Instance Tunnel so SQL reads are not capped at 10,000 rows.
options.tunnel.use_instance_tunnel = True
options.tunnel.limit_instance_tunnel = False
  24. # ODPS 配置
  25. ODPS_CONFIGS = {
  26. "default": {
  27. "access_id": "LTAIWYUujJAm7CbH",
  28. "access_secret": "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P",
  29. "project": "loghubods",
  30. },
  31. "piaoquan_api": {
  32. "access_id": "LTAI5tKyXxh7C6349c1wbwUX",
  33. "access_secret": "H8doQDC20KugToRA3giERgRyRD1KR9",
  34. "project": "piaoquan_api",
  35. },
  36. }
  37. class ODPSClient(object):
  38. def __init__(self, project="loghubods", config="default"):
  39. """
  40. 初始化 ODPS 客户端
  41. Args:
  42. project: 项目名(可覆盖配置中的默认项目)
  43. config: 配置名,可选 "default" 或 "piaoquan_api"
  44. """
  45. cfg = ODPS_CONFIGS.get(config, ODPS_CONFIGS["default"])
  46. self.accessId = cfg["access_id"]
  47. self.accessSecret = cfg["access_secret"]
  48. self.endpoint = "http://service.odps.aliyun.com/api"
  49. self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
  50. # 如果指定了 project 且不是默认值,使用指定的;否则用配置中的
  51. actual_project = project if project != "loghubods" else cfg["project"]
  52. self.odps = ODPS(
  53. self.accessId,
  54. self.accessSecret,
  55. actual_project,
  56. self.endpoint
  57. )
  58. def execute_sql(self, sql: str, print_logview: bool = True):
  59. """执行 SQL 并返回 DataFrame"""
  60. hints = {'odps.sql.submit.mode': 'script'}
  61. instance = self.odps.execute_sql(sql, hints=hints)
  62. if print_logview:
  63. print(f"LogView: {instance.get_logview_address()}")
  64. with instance.open_reader(tunnel=True, limit=False) as reader:
  65. pd_df = reader.to_pandas()
  66. return pd_df
  67. def execute_sql_result_save_file(self, sql: str, output_file: str):
  68. """执行 SQL 并保存到文件(Arrow 直接写 CSV,速度最快)"""
  69. hints = {'odps.sql.submit.mode': 'script'}
  70. start_time = time.time()
  71. instance = self.odps.execute_sql(sql, hints=hints)
  72. sql_time = time.time() - start_time
  73. print(f"LogView: {instance.get_logview_address()}")
  74. print(f"SQL 执行耗时: {sql_time:.1f}s")
  75. with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
  76. total = reader.count
  77. # 边下载边写入,用 pyarrow 直接写 CSV
  78. with open(output_file, 'wb') as f:
  79. first = True
  80. with tqdm(total=total, unit='行', desc='下载中') as pbar:
  81. for batch in reader:
  82. # pyarrow 写 CSV(比 pandas 快很多)
  83. options = pa_csv.WriteOptions(include_header=first)
  84. pa_csv.write_csv(pa.Table.from_batches([batch]), f, write_options=options)
  85. first = False
  86. pbar.update(batch.num_rows)
  87. total_time = time.time() - start_time
  88. print(f"总耗时: {total_time:.1f}s")
  89. print(f"完成: {output_file}")
    def execute_sql_result_save_file_parallel(self, sql: str, output_file: str, workers: int = 4):
        """Run a SQL statement and save the result to CSV via parallel download.

        Strategy: materialize the query into a temporary table (LIFECYCLE 1),
        open a table-tunnel download session, split the row range into one
        chunk per worker, download chunks concurrently into per-chunk temp
        files, then concatenate them in index order (dropping repeated
        headers). The temp table is always dropped, even on failure.

        Args:
            sql: SELECT statement whose result should be downloaded.
            output_file: path of the CSV file to write.
            workers: number of download threads (and row-range chunks).
        """
        hints = {'odps.sql.submit.mode': 'script'}
        # Unique temp-table name so concurrent runs cannot collide.
        tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
        create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
        start_time = time.time()
        # 1. Create the temporary table.
        print(f"创建临时表: {tmp_table}")
        instance = self.odps.execute_sql(create_sql, hints=hints)
        print(f"LogView: {instance.get_logview_address()}")
        instance.wait_for_success()
        sql_time = time.time() - start_time
        print(f"SQL 执行耗时: {sql_time:.1f}s")
        try:
            # 2. Fetch table info and open a download session.
            table = self.odps.get_table(tmp_table)
            tunnel = TableTunnel(self.odps)
            download_session = tunnel.create_download_session(table.name)
            total = download_session.count
            print(f"总行数: {total}")
            if total == 0:
                # Empty result: write a header-only CSV and return early
                # (the finally block still drops the temp table).
                with open(output_file, 'w') as f:
                    columns = [col.name for col in table.table_schema.columns]
                    f.write(','.join(columns) + '\n')
                print(f"完成: {output_file} (空表)")
                return
            # 3. Split the row range into one chunk per worker.
            chunk_size = (total + workers - 1) // workers
            chunks = []
            for i in range(workers):
                start = i * chunk_size
                end = min((i + 1) * chunk_size, total)
                if start < end:
                    chunks.append((i, start, end - start))  # (index, start, count)
            print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")
            # 4. Download chunks concurrently into temp files beside the output.
            output_dir = os.path.dirname(output_file)
            tmp_prefix = os.path.join(output_dir, f".tmp_{os.path.basename(output_file)}_")
            pbar = tqdm(total=total, unit='行', desc='下载中')
            pbar_lock = threading.Lock()  # tqdm.update is not thread-safe
            session_id = download_session.id
            tmp_files = {}
            def download_chunk(chunk_info):
                # Download one (start, count) slice into its own CSV temp file.
                idx, start, count = chunk_info
                tmp_file = f"{tmp_prefix}{idx:04d}"
                # Re-attach to the shared session by id so every thread reads
                # the same snapshot of the table.
                session = tunnel.create_download_session(table.name, download_id=session_id)
                with session.open_arrow_reader(start, count) as reader:
                    batches = []
                    for batch in reader:
                        batches.append(batch)
                        with pbar_lock:
                            pbar.update(batch.num_rows)
                if batches:
                    tbl = pa.Table.from_batches(batches)
                    pa_csv.write_csv(tbl, tmp_file)
                return idx, tmp_file
            # Run the chunk downloads in parallel.
            with ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [executor.submit(download_chunk, chunk) for chunk in chunks]
                for future in as_completed(futures):
                    idx, tmp_file = future.result()
                    tmp_files[idx] = tmp_file
            pbar.close()
            # 5. Concatenate chunk files in index order.
            print("合并文件中...")
            with open(output_file, 'wb') as outf:
                for idx in range(len(chunks)):
                    tmp_file = tmp_files.get(idx)
                    if tmp_file and os.path.exists(tmp_file):
                        with open(tmp_file, 'rb') as inf:
                            if idx > 0:
                                inf.readline()  # skip the duplicate CSV header
                            outf.write(inf.read())
                        os.remove(tmp_file)
        finally:
            # 6. Always drop the temporary table.
            print(f"删除临时表: {tmp_table}")
            self.odps.delete_table(tmp_table, if_exists=True)
        total_time = time.time() - start_time
        print(f"总耗时: {total_time:.1f}s")
        print(f"完成: {output_file}")
# ──────────────────────────────────────────────────────────────────────────────
# DataWorks client: fetch production code for a given table name
# ──────────────────────────────────────────────────────────────────────────────
# Best-practice call chain: GetTable → GetTask → Script.Content
#   1. GetTable(entity_id, include_business_metadata=True) pinpoints the
#      table's upstream tasks.
#   2. GetTask(task_id, project_env='Prod') fetches each task's SQL code.
# ──────────────────────────────────────────────────────────────────────────────

# All DataWorks projects reachable by the account (project_id → name).
_DW_PROJECTS = {
    4858: "loghubods",
    11300: "DWH",
    5477: "videocdm",
    548768: "piaoquan_api",
    148813: "content_safety",
    96094: "algo",
    52578: "majin",
    5057: "useractionbi",
    5034: "user_video_action_cdm",
    4868: "usercdm",
    4859: "videoods",
    6025: "videoads",
    5535: "user_video_tag",
    19288: "RecallEmbedding",
    10762: "Test_model1",
    193831: "cost_mgt_1894469520484605",
    156474: "dyp_1",
    156475: "dyp_2",
    343868: "pq_data_space",
    343957: "pq_grafana_se",
}

# Local cache directory for fetched production code / schema JSON:
# production_code/ next to this file's parent directory.
_CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "production_code")
  204. def _call_with_retry(fn, max_retries=3, base_delay=2):
  205. """带限流重试的 API 调用包装。"""
  206. for attempt in range(max_retries):
  207. try:
  208. return fn()
  209. except Exception as e:
  210. if "Throttling" in str(e) and attempt < max_retries - 1:
  211. delay = base_delay * (2 ** attempt)
  212. print(f" [throttled] 等待 {delay}s 后重试...")
  213. time.sleep(delay)
  214. continue
  215. raise
  216. class DataWorksClient:
  217. def __init__(self):
  218. if not _DW_AVAILABLE:
  219. raise ImportError(
  220. "请先安装 DataWorks SDK:\n"
  221. "pip install alibabacloud-dataworks-public20240518"
  222. )
  223. # 初始化所有 AK 对应的客户端(不同 AK 对不同项目有权限)
  224. self._clients = {}
  225. for config_name, cfg in ODPS_CONFIGS.items():
  226. dw_config = _open_api_models.Config(
  227. access_key_id=cfg["access_id"],
  228. access_key_secret=cfg["access_secret"],
  229. endpoint="dataworks.cn-hangzhou.aliyuncs.com",
  230. )
  231. self._clients[config_name] = _DWClient(dw_config)
  232. self.client = self._clients["default"]
  233. @staticmethod
  234. def _build_entity_id(table_name: str) -> str:
  235. """构造 GetTable 的 entity ID。
  236. 支持格式:
  237. - project.table → maxcompute-table:::project::table
  238. - table → maxcompute-table:::loghubods::table
  239. """
  240. parts = table_name.split(".", 1)
  241. if len(parts) == 2:
  242. project, table = parts
  243. else:
  244. project, table = "loghubods", parts[0]
  245. return f"maxcompute-table:::{project}::{table}"
  246. def get_table_info(self, table_name: str) -> dict:
  247. """获取表的元信息(含上游任务列表)。
  248. Returns:
  249. dict with keys: name, comment, dataworks_tasks[{id, name}], ...
  250. """
  251. entity_id = self._build_entity_id(table_name)
  252. resp = _call_with_retry(lambda: self.client.get_table(
  253. _dw_models.GetTableRequest(id=entity_id, include_business_metadata=True)
  254. ))
  255. table = resp.body.to_map().get("Table", {})
  256. biz = table.get("BusinessMetadata", {})
  257. return {
  258. "id": table.get("Id"),
  259. "name": table.get("Name"),
  260. "comment": table.get("Comment"),
  261. "project_id": biz.get("Extension", {}).get("ProjectId"),
  262. "dataworks_tasks": biz.get("UpstreamTasks", []),
  263. "partition_keys": table.get("PartitionKeys", []),
  264. }
  265. def _get_task_code(self, task_id: int) -> dict:
  266. """尝试用所有 AK 获取任务代码,返回第一个成功的结果。"""
  267. for config_name, client in self._clients.items():
  268. try:
  269. resp = _call_with_retry(lambda c=client: c.get_task(
  270. _dw_models.GetTaskRequest(id=task_id, project_env="Prod")
  271. ))
  272. task = resp.body.to_map().get("Task", {})
  273. return {
  274. "task_id": task_id,
  275. "task_name": task.get("Name"),
  276. "task_type": task.get("Type"),
  277. "content": (task.get("Script") or {}).get("Content", ""),
  278. "config": config_name,
  279. }
  280. except Exception as e:
  281. if "11020205003" in str(e):
  282. continue # 无权限,尝试下一个 AK
  283. raise
  284. return None
  285. @staticmethod
  286. def _normalize_table_name(table_name: str) -> str:
  287. """补全 project 前缀:table → loghubods.table"""
  288. if "." not in table_name:
  289. return f"loghubods.{table_name}"
  290. return table_name
  291. @staticmethod
  292. def _cache_path(table_name: str) -> str:
  293. return os.path.join(_CACHE_DIR, f"{table_name}.sql")
  294. @staticmethod
  295. def _schema_cache_path(table_name: str) -> str:
  296. return os.path.join(_CACHE_DIR, f"{table_name}.json")
  297. def _read_cache(self, table_name: str) -> str | None:
  298. path = self._cache_path(table_name)
  299. if os.path.exists(path):
  300. with open(path, "r", encoding="utf-8") as f:
  301. return f.read()
  302. return None
  303. def _write_cache(self, table_name: str, content: str):
  304. os.makedirs(_CACHE_DIR, exist_ok=True)
  305. with open(self._cache_path(table_name), "w", encoding="utf-8") as f:
  306. f.write(content)
  307. def _read_schema_cache(self, table_name: str) -> dict | None:
  308. import json
  309. path = self._schema_cache_path(table_name)
  310. if os.path.exists(path):
  311. with open(path, "r", encoding="utf-8") as f:
  312. return json.load(f)
  313. return None
  314. def _write_schema_cache(self, table_name: str, schema: dict):
  315. import json
  316. os.makedirs(_CACHE_DIR, exist_ok=True)
  317. with open(self._schema_cache_path(table_name), "w", encoding="utf-8") as f:
  318. json.dump(schema, f, ensure_ascii=False, indent=2)
  319. def _ensure_schema_cache(self, table_name: str, force: bool = False,
  320. dataworks_tasks: list | None = None):
  321. """确保 schema 缓存存在,无则拉取并写入。
  322. Args:
  323. dataworks_tasks: 预获取的上游任务列表,避免重复调用 get_table_info()
  324. """
  325. if not force:
  326. cached = self._read_schema_cache(table_name)
  327. if cached is not None:
  328. return
  329. try:
  330. schema = self.get_table_schema(table_name, dataworks_tasks=dataworks_tasks)
  331. self._write_schema_cache(table_name, schema)
  332. print(f"[saved] {self._schema_cache_path(table_name)}")
  333. except Exception as e:
  334. print(f"[WARN] 获取表结构失败 {table_name}: {e}")
  335. def get_table_schema(self, table_name: str,
  336. dataworks_tasks: list | None = None) -> dict:
  337. """通过 ODPS SDK 获取表结构元信息。
  338. Args:
  339. table_name: 表名(支持 project.table 格式)
  340. dataworks_tasks: 预获取的上游任务列表,避免重复 API 调用
  341. Returns:
  342. dict: {name, project, comment, columns, partition_keys, dataworks_tasks}
  343. """
  344. table_name = self._normalize_table_name(table_name)
  345. parts = table_name.split(".", 1)
  346. project, table = parts[0], parts[1]
  347. # 用默认 AK 对应的 ODPSClient 获取 ODPS 表结构
  348. odps_client = ODPSClient(project=project)
  349. t = odps_client.odps.get_table(table)
  350. columns = [
  351. {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
  352. for c in t.table_schema.columns
  353. ]
  354. partition_keys = [
  355. {"name": c.name, "type": str(c.type), "comment": c.comment or ""}
  356. for c in t.table_schema.partitions
  357. ]
  358. # 上游任务:优先用传入的,否则从 DataWorks API 获取
  359. if dataworks_tasks is None:
  360. try:
  361. info = self.get_table_info(table_name)
  362. dataworks_tasks = [
  363. {"id": task.get("Id"), "name": task.get("Name")}
  364. for task in info.get("dataworks_tasks", [])
  365. ]
  366. except Exception:
  367. dataworks_tasks = []
  368. # 直接上游表(血缘)
  369. try:
  370. upstream_tables = self.get_upstream_tables(table_name)
  371. except Exception:
  372. upstream_tables = []
  373. return {
  374. "name": table,
  375. "project": project,
  376. "comment": t.comment or "",
  377. "columns": columns,
  378. "partition_keys": partition_keys,
  379. "dataworks_tasks": dataworks_tasks,
  380. "upstream_tables": upstream_tables,
  381. }
    def get_node_code(self, table_name: str, force: bool = False) -> list:
        """Fetch the production code for a table (local cache first).

        Flow: local cache → GetTable → GetTask → write cache → return code.

        Args:
            table_name: table name ("project.table" or bare "table").
            force: when True, skip the cache and always pull from the API.
        Returns:
            list of dicts, each with: task_id, task_name, task_type, content.
        """
        table_name = self._normalize_table_name(table_name)
        # Cache lookup.
        if not force:
            cached = self._read_cache(table_name)
            if cached is not None:
                print(f"[cache] {self._cache_path(table_name)}")
                # Also backfill the schema cache when it is missing.
                self._ensure_schema_cache(table_name, force=False)
                return [{"task_id": None, "task_name": "(cached)", "task_type": None, "content": cached}]
        # Pull from the API.
        info = self.get_table_info(table_name)
        upstream = info.get("dataworks_tasks", [])
        if not upstream:
            print(f"表 '{table_name}' 没有上游任务")
            return []
        results = []
        for task in upstream:
            task_id = task.get("Id")
            task_name = task.get("Name")
            result = self._get_task_code(task_id)
            if result:
                results.append(result)
            else:
                # None means every configured AK lacked permission.
                print(f"[WARN] 任务 {task_name}({task_id}) 所有 AK 均无权限")
        # Write the code cache (one file, tasks separated by headers).
        if results:
            parts = []
            for r in results:
                header = f"-- Task: {r['task_name']} ID: {r['task_id']} Type: {r['task_type']}"
                parts.append(f"{header}\n{r['content']}")
            self._write_cache(table_name, "\n\n".join(parts))
            print(f"[saved] {self._cache_path(table_name)}")
        # Fetch and cache the schema, reusing the upstream task info we
        # already hold to avoid a duplicate API call.
        up_tasks = [
            {"id": task.get("Id"), "name": task.get("Name")}
            for task in upstream
        ]
        self._ensure_schema_cache(table_name, force=force, dataworks_tasks=up_tasks)
        return results
  431. def get_upstream_tables(self, table_name: str) -> list[str]:
  432. """通过血缘 API 获取表的直接上游表列表。
  433. Returns:
  434. list of str,如 ["loghubods.user_share_log_flow", ...]
  435. """
  436. entity_id = self._build_entity_id(table_name)
  437. resp = _call_with_retry(lambda: self.client.list_lineages(
  438. _dw_models.ListLineagesRequest(dst_entity_id=entity_id, page_size=50)
  439. ))
  440. lineages = resp.body.to_map().get("PagingInfo", {}).get("Lineages", [])
  441. tables = []
  442. for l in lineages:
  443. src_id = l.get("SrcEntity", {}).get("Id", "")
  444. # maxcompute-table:::project::table → project.table
  445. parts = src_id.replace("maxcompute-table:::", "").split("::")
  446. if len(parts) == 2:
  447. tables.append(f"{parts[0]}.{parts[1]}")
  448. return sorted(set(tables))
  449. def get_node_code_recursive(self, table_name: str, max_depth: int = 3,
  450. force: bool = False) -> dict:
  451. """BFS 逐层获取表及其所有上游表的生产代码。
  452. 通过血缘 API(ListLineages)逐层追溯上游依赖,
  453. 每层的代码和上游表都会被缓存到 production_code/。
  454. Args:
  455. table_name: 表名(支持 project.table 格式)
  456. max_depth: 最大追溯层数,默认 3
  457. force: True 时跳过缓存
  458. Returns:
  459. dict: {
  460. "project.table": {
  461. "code": [...], # get_node_code 返回值
  462. "upstream": ["a.b", ...], # 上游表名列表
  463. "depth": int
  464. }, ...
  465. }
  466. """
  467. from collections import deque
  468. table_name = self._normalize_table_name(table_name)
  469. result = {}
  470. queue = deque([(table_name, 0)])
  471. visited = {table_name}
  472. while queue:
  473. tbl, depth = queue.popleft()
  474. indent = " " * depth
  475. print(f"{indent}[depth={depth}] {tbl}")
  476. # 获取代码
  477. code = self.get_node_code(tbl, force=force)
  478. # 获取上游表
  479. upstream = []
  480. if depth < max_depth:
  481. try:
  482. upstream = self.get_upstream_tables(tbl)
  483. except Exception:
  484. pass
  485. result[tbl] = {"code": code, "upstream": upstream, "depth": depth}
  486. # 下一层入队
  487. for up_tbl in upstream:
  488. if up_tbl not in visited:
  489. visited.add(up_tbl)
  490. queue.append((up_tbl, depth + 1))
  491. # 打印汇总
  492. print(f"\n共追溯 {len(result)} 张表:")
  493. for tbl, info in result.items():
  494. has_code = "有代码" if info["code"] else "无代码"
  495. n_up = len(info["upstream"])
  496. print(f" {' ' * info['depth']}{tbl} ({has_code}, {n_up} 个上游)")
  497. return result
  498. def print_node_code(self, table_name: str):
  499. """打印表的生产代码(人类可读格式)"""
  500. results = self.get_node_code(table_name)
  501. if not results:
  502. print(f"未找到 '{table_name}' 的生产代码")
  503. return
  504. for r in results:
  505. print(f"\n{'='*60}")
  506. print(f"任务: {r['task_name']} ID: {r['task_id']} "
  507. f"类型: {r['task_type']}")
  508. print(f"{'='*60}")
  509. print(r["content"] or "(无内容)")