odps_module.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
#!/usr/bin/env python
# coding=utf-8
import os
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from odps import ODPS, options
from odps.tunnel import TableTunnel
from tqdm import tqdm
import pyarrow as pa
from pyarrow import csv as pa_csv
# Enable Instance Tunnel and lift the 10,000-row result limit on readers.
options.tunnel.use_instance_tunnel = True
options.tunnel.limit_instance_tunnel = False
  16. class ODPSClient(object):
  17. def __init__(self, project="loghubods"):
  18. self.accessId = "LTAIWYUujJAm7CbH"
  19. self.accessSecret = "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P"
  20. self.endpoint = "http://service.odps.aliyun.com/api"
  21. self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
  22. self.odps = ODPS(
  23. self.accessId,
  24. self.accessSecret,
  25. project,
  26. self.endpoint
  27. )
  28. def execute_sql(self, sql: str, print_logview: bool = True):
  29. """执行 SQL 并返回 DataFrame"""
  30. hints = {'odps.sql.submit.mode': 'script'}
  31. instance = self.odps.execute_sql(sql, hints=hints)
  32. if print_logview:
  33. print(f"LogView: {instance.get_logview_address()}")
  34. with instance.open_reader(tunnel=True, limit=False) as reader:
  35. pd_df = reader.to_pandas()
  36. return pd_df
  37. def execute_sql_result_save_file(self, sql: str, output_file: str):
  38. """执行 SQL 并保存到文件(Arrow 直接写 CSV,速度最快)"""
  39. hints = {'odps.sql.submit.mode': 'script'}
  40. start_time = time.time()
  41. instance = self.odps.execute_sql(sql, hints=hints)
  42. sql_time = time.time() - start_time
  43. print(f"LogView: {instance.get_logview_address()}")
  44. print(f"SQL 执行耗时: {sql_time:.1f}s")
  45. with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
  46. total = reader.count
  47. # 边下载边写入,用 pyarrow 直接写 CSV
  48. with open(output_file, 'wb') as f:
  49. first = True
  50. with tqdm(total=total, unit='行', desc='下载中') as pbar:
  51. for batch in reader:
  52. # pyarrow 写 CSV(比 pandas 快很多)
  53. options = pa_csv.WriteOptions(include_header=first)
  54. pa_csv.write_csv(pa.Table.from_batches([batch]), f, write_options=options)
  55. first = False
  56. pbar.update(batch.num_rows)
  57. total_time = time.time() - start_time
  58. print(f"总耗时: {total_time:.1f}s")
  59. print(f"完成: {output_file}")
  60. def execute_sql_result_save_file_parallel(self, sql: str, output_file: str, workers: int = 4):
  61. """执行 SQL 并保存到文件(多线程并行下载,速度最快)"""
  62. hints = {'odps.sql.submit.mode': 'script'}
  63. # 生成临时表名
  64. tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
  65. create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
  66. start_time = time.time()
  67. # 1. 创建临时表
  68. print(f"创建临时表: {tmp_table}")
  69. instance = self.odps.execute_sql(create_sql, hints=hints)
  70. print(f"LogView: {instance.get_logview_address()}")
  71. instance.wait_for_success()
  72. sql_time = time.time() - start_time
  73. print(f"SQL 执行耗时: {sql_time:.1f}s")
  74. try:
  75. # 2. 获取表信息
  76. table = self.odps.get_table(tmp_table)
  77. tunnel = TableTunnel(self.odps)
  78. download_session = tunnel.create_download_session(table.name)
  79. total = download_session.count
  80. print(f"总行数: {total}")
  81. if total == 0:
  82. # 空表,直接写入空 CSV
  83. with open(output_file, 'w') as f:
  84. columns = [col.name for col in table.table_schema.columns]
  85. f.write(','.join(columns) + '\n')
  86. print(f"完成: {output_file} (空表)")
  87. return
  88. # 3. 分段
  89. chunk_size = (total + workers - 1) // workers
  90. chunks = []
  91. for i in range(workers):
  92. start = i * chunk_size
  93. end = min((i + 1) * chunk_size, total)
  94. if start < end:
  95. chunks.append((i, start, end - start)) # (index, start, count)
  96. print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")
  97. # 4. 多线程下载到临时文件(放在输出目录)
  98. output_dir = os.path.dirname(output_file)
  99. tmp_prefix = os.path.join(output_dir, f".tmp_{os.path.basename(output_file)}_")
  100. pbar = tqdm(total=total, unit='行', desc='下载中')
  101. pbar_lock = threading.Lock()
  102. session_id = download_session.id
  103. tmp_files = {}
  104. def download_chunk(chunk_info):
  105. idx, start, count = chunk_info
  106. tmp_file = f"{tmp_prefix}{idx:04d}"
  107. session = tunnel.create_download_session(table.name, download_id=session_id)
  108. with session.open_arrow_reader(start, count) as reader:
  109. batches = []
  110. for batch in reader:
  111. batches.append(batch)
  112. with pbar_lock:
  113. pbar.update(batch.num_rows)
  114. if batches:
  115. tbl = pa.Table.from_batches(batches)
  116. pa_csv.write_csv(tbl, tmp_file)
  117. return idx, tmp_file
  118. # 并行下载
  119. with ThreadPoolExecutor(max_workers=workers) as executor:
  120. futures = [executor.submit(download_chunk, chunk) for chunk in chunks]
  121. for future in as_completed(futures):
  122. idx, tmp_file = future.result()
  123. tmp_files[idx] = tmp_file
  124. pbar.close()
  125. # 按顺序合并
  126. print("合并文件中...")
  127. with open(output_file, 'wb') as outf:
  128. for idx in range(len(chunks)):
  129. tmp_file = tmp_files.get(idx)
  130. if tmp_file and os.path.exists(tmp_file):
  131. with open(tmp_file, 'rb') as inf:
  132. if idx > 0:
  133. inf.readline() # 跳过表头
  134. outf.write(inf.read())
  135. os.remove(tmp_file)
  136. finally:
  137. # 6. 删除临时表
  138. print(f"删除临时表: {tmp_table}")
  139. self.odps.delete_table(tmp_table, if_exists=True)
  140. total_time = time.time() - start_time
  141. print(f"总耗时: {total_time:.1f}s")
  142. print(f"完成: {output_file}")