|
|
@@ -1,7 +1,20 @@
|
|
|
#!/usr/bin/env python
|
|
|
# coding=utf-8
|
|
|
|
|
|
-from odps import ODPS
|
|
|
+import os
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+import threading
|
|
|
+from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
+from odps import ODPS, options
|
|
|
+from odps.tunnel import TableTunnel
|
|
|
+from tqdm import tqdm
|
|
|
+import pyarrow as pa
|
|
|
+from pyarrow import csv as pa_csv
|
|
|
+
|
|
|
+# 开启 Instance Tunnel,解除 1 万条限制
|
|
|
+options.tunnel.use_instance_tunnel = True
|
|
|
+options.tunnel.limit_instance_tunnel = False
|
|
|
|
|
|
|
|
|
class ODPSClient(object):
|
|
|
@@ -18,14 +31,141 @@ class ODPSClient(object):
|
|
|
self.endpoint
|
|
|
)
|
|
|
|
|
|
- def execute_sql(self, sql: str):
|
|
|
- hints = {
|
|
|
- 'odps.sql.submit.mode': 'script'
|
|
|
- }
|
|
|
- with self.odps.execute_sql(sql, hints=hints).open_reader(tunnel=True) as reader:
|
|
|
+ def execute_sql(self, sql: str, print_logview: bool = True):
|
|
|
+ """执行 SQL 并返回 DataFrame"""
|
|
|
+ hints = {'odps.sql.submit.mode': 'script'}
|
|
|
+ instance = self.odps.execute_sql(sql, hints=hints)
|
|
|
+
|
|
|
+ if print_logview:
|
|
|
+ print(f"LogView: {instance.get_logview_address()}")
|
|
|
+
|
|
|
+ with instance.open_reader(tunnel=True, limit=False) as reader:
|
|
|
pd_df = reader.to_pandas()
|
|
|
return pd_df
|
|
|
|
|
|
def execute_sql_result_save_file(self, sql: str, output_file: str):
|
|
|
- data_df = self.execute_sql(sql)
|
|
|
- data_df.to_csv(output_file, index=False)
|
|
|
+        """执行 SQL 并保存到文件(Arrow 直接写 CSV,单线程流式下载)"""
|
|
|
+ hints = {'odps.sql.submit.mode': 'script'}
|
|
|
+
|
|
|
+ start_time = time.time()
|
|
|
+ instance = self.odps.execute_sql(sql, hints=hints)
|
|
|
+ sql_time = time.time() - start_time
|
|
|
+
|
|
|
+ print(f"LogView: {instance.get_logview_address()}")
|
|
|
+ print(f"SQL 执行耗时: {sql_time:.1f}s")
|
|
|
+
|
|
|
+ with instance.open_reader(tunnel=True, limit=False, arrow=True) as reader:
|
|
|
+ total = reader.count
|
|
|
+
|
|
|
+ # 边下载边写入,用 pyarrow 直接写 CSV
|
|
|
+ with open(output_file, 'wb') as f:
|
|
|
+ first = True
|
|
|
+ with tqdm(total=total, unit='行', desc='下载中') as pbar:
|
|
|
+ for batch in reader:
|
|
|
+ # pyarrow 写 CSV(比 pandas 快很多)
|
|
|
+                    write_opts = pa_csv.WriteOptions(include_header=first)
|
|
|
+                    pa_csv.write_csv(pa.Table.from_batches([batch]), f, write_options=write_opts)
|
|
|
+ first = False
|
|
|
+ pbar.update(batch.num_rows)
|
|
|
+
|
|
|
+ total_time = time.time() - start_time
|
|
|
+ print(f"总耗时: {total_time:.1f}s")
|
|
|
+ print(f"完成: {output_file}")
|
|
|
+
|
|
|
+ def execute_sql_result_save_file_parallel(self, sql: str, output_file: str, workers: int = 4):
|
|
|
+ """执行 SQL 并保存到文件(多线程并行下载,速度最快)"""
|
|
|
+ hints = {'odps.sql.submit.mode': 'script'}
|
|
|
+
|
|
|
+ # 生成临时表名
|
|
|
+ tmp_table = f"tmp_download_{uuid.uuid4().hex[:8]}"
|
|
|
+ create_sql = f"CREATE TABLE {tmp_table} LIFECYCLE 1 AS {sql}"
|
|
|
+
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 1. 创建临时表
|
|
|
+ print(f"创建临时表: {tmp_table}")
|
|
|
+ instance = self.odps.execute_sql(create_sql, hints=hints)
|
|
|
+ print(f"LogView: {instance.get_logview_address()}")
|
|
|
+ instance.wait_for_success()
|
|
|
+ sql_time = time.time() - start_time
|
|
|
+ print(f"SQL 执行耗时: {sql_time:.1f}s")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 2. 获取表信息
|
|
|
+ table = self.odps.get_table(tmp_table)
|
|
|
+ tunnel = TableTunnel(self.odps)
|
|
|
+ download_session = tunnel.create_download_session(table.name)
|
|
|
+ total = download_session.count
|
|
|
+ print(f"总行数: {total}")
|
|
|
+
|
|
|
+ if total == 0:
|
|
|
+ # 空表,直接写入空 CSV
|
|
|
+ with open(output_file, 'w') as f:
|
|
|
+ columns = [col.name for col in table.table_schema.columns]
|
|
|
+ f.write(','.join(columns) + '\n')
|
|
|
+ print(f"完成: {output_file} (空表)")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 3. 分段
|
|
|
+ chunk_size = (total + workers - 1) // workers
|
|
|
+ chunks = []
|
|
|
+ for i in range(workers):
|
|
|
+ start = i * chunk_size
|
|
|
+ end = min((i + 1) * chunk_size, total)
|
|
|
+ if start < end:
|
|
|
+ chunks.append((i, start, end - start)) # (index, start, count)
|
|
|
+
|
|
|
+ print(f"并行下载: {len(chunks)} 个分片, {workers} 线程")
|
|
|
+
|
|
|
+ # 4. 多线程下载到临时文件(放在输出目录)
|
|
|
+ output_dir = os.path.dirname(output_file)
|
|
|
+ tmp_prefix = os.path.join(output_dir, f".tmp_{os.path.basename(output_file)}_")
|
|
|
+ pbar = tqdm(total=total, unit='行', desc='下载中')
|
|
|
+ pbar_lock = threading.Lock()
|
|
|
+ session_id = download_session.id
|
|
|
+ tmp_files = {}
|
|
|
+
|
|
|
+ def download_chunk(chunk_info):
|
|
|
+ idx, start, count = chunk_info
|
|
|
+ tmp_file = f"{tmp_prefix}{idx:04d}"
|
|
|
+ session = tunnel.create_download_session(table.name, download_id=session_id)
|
|
|
+ with session.open_arrow_reader(start, count) as reader:
|
|
|
+ batches = []
|
|
|
+ for batch in reader:
|
|
|
+ batches.append(batch)
|
|
|
+ with pbar_lock:
|
|
|
+ pbar.update(batch.num_rows)
|
|
|
+ if batches:
|
|
|
+ tbl = pa.Table.from_batches(batches)
|
|
|
+ pa_csv.write_csv(tbl, tmp_file)
|
|
|
+ return idx, tmp_file
|
|
|
+
|
|
|
+ # 并行下载
|
|
|
+ with ThreadPoolExecutor(max_workers=workers) as executor:
|
|
|
+ futures = [executor.submit(download_chunk, chunk) for chunk in chunks]
|
|
|
+ for future in as_completed(futures):
|
|
|
+ idx, tmp_file = future.result()
|
|
|
+ tmp_files[idx] = tmp_file
|
|
|
+
|
|
|
+ pbar.close()
|
|
|
+
|
|
|
+            # 5. 按顺序合并
|
|
|
+ print("合并文件中...")
|
|
|
+ with open(output_file, 'wb') as outf:
|
|
|
+ for idx in range(len(chunks)):
|
|
|
+ tmp_file = tmp_files.get(idx)
|
|
|
+ if tmp_file and os.path.exists(tmp_file):
|
|
|
+ with open(tmp_file, 'rb') as inf:
|
|
|
+ if idx > 0:
|
|
|
+ inf.readline() # 跳过表头
|
|
|
+ outf.write(inf.read())
|
|
|
+ os.remove(tmp_file)
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # 6. 删除临时表
|
|
|
+ print(f"删除临时表: {tmp_table}")
|
|
|
+ self.odps.delete_table(tmp_table, if_exists=True)
|
|
|
+
|
|
|
+ total_time = time.time() - start_time
|
|
|
+ print(f"总耗时: {total_time:.1f}s")
|
|
|
+ print(f"完成: {output_file}")
|