run_sql.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. """
  4. SQL 执行工具 - 输入 SQL 文件,输出查询结果到同目录下的 CSV
  5. 使用示例:
  6. python run_sql.py tasks/渠道效果分析/渠道再分享回流.sql
  7. python run_sql.py tasks/渠道效果分析/渠道再分享回流.sql --start 20251222 --end 20260103
  8. """
  9. import argparse
  10. from datetime import datetime, timedelta
  11. from pathlib import Path
  12. from lib.odps_module import ODPSClient
  13. def get_default_dates():
  14. """获取默认日期范围:最近 7 天(start=7天前, end=昨天)"""
  15. today = datetime.now()
  16. end_date = today - timedelta(days=1)
  17. start_date = today - timedelta(days=7)
  18. return start_date.strftime('%Y%m%d'), end_date.strftime('%Y%m%d')
  19. def parse_variables(var_list: list) -> dict:
  20. """解析变量列表为字典"""
  21. if not var_list:
  22. return {}
  23. variables = {}
  24. for item in var_list:
  25. if '=' in item:
  26. key, value = item.split('=', 1)
  27. variables[key.strip()] = value.strip()
  28. return variables
  29. def replace_variables(sql: str, variables: dict) -> str:
  30. """替换 SQL 中的 ${variable} 占位符"""
  31. for key, value in variables.items():
  32. sql = sql.replace(f'${{{key}}}', value)
  33. return sql
  34. def run_sql(sql_file: str, output_file: str = None, variables: dict = None,
  35. start: str = None, end: str = None, dry_run: bool = False):
  36. """
  37. 执行 SQL 文件并保存结果
  38. Args:
  39. sql_file: SQL 文件路径
  40. output_file: 输出文件路径(默认与 SQL 同目录同名)
  41. variables: 变量替换字典
  42. start: dt 分区起始日期
  43. end: dt 分区结束日期
  44. dry_run: 仅打印 SQL,不执行
  45. """
  46. sql_path = Path(sql_file)
  47. # 合并 start/end 到 variables
  48. if variables is None:
  49. variables = {}
  50. if start:
  51. variables['start'] = start
  52. if end:
  53. variables['end'] = end
  54. # 输出目录:SQL 同目录下的 output/;文件名:[sql前缀_]日期.csv
  55. if output_file is None:
  56. output_dir = sql_path.parent / "output"
  57. output_dir.mkdir(exist_ok=True)
  58. # SQL 文件名作为前缀
  59. sql_stem = sql_path.stem # 去掉 .sql 后缀
  60. prefix = f"{sql_stem}_"
  61. if start and end:
  62. output_file = output_dir / f"{prefix}{start}_{end}.csv"
  63. elif start:
  64. output_file = output_dir / f"{prefix}{start}.csv"
  65. else:
  66. output_file = output_dir / f"{prefix}result.csv"
  67. else:
  68. output_file = Path(output_file)
  69. # 读取 SQL
  70. with open(sql_path, 'r', encoding='utf-8') as f:
  71. sql = f.read()
  72. # 变量替换
  73. if variables:
  74. sql = replace_variables(sql, variables)
  75. # Dry run 模式
  76. if dry_run:
  77. print("=" * 50)
  78. print("SQL 预览 (dry-run 模式)")
  79. print("=" * 50)
  80. print(sql)
  81. print("=" * 50)
  82. print(f"输出文件: {output_file}")
  83. return
  84. # 执行 SQL
  85. print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 开始执行: {sql_path.name}")
  86. odps_client = ODPSClient()
  87. odps_client.execute_sql_result_save_file(sql, str(output_file))
  88. print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 完成,结果保存至: {output_file}")
  89. def main():
  90. parser = argparse.ArgumentParser(
  91. description='执行 SQL 文件并输出结果',
  92. formatter_class=argparse.RawDescriptionHelpFormatter,
  93. epilog="""
  94. 示例:
  95. python run_sql.py tasks/渠道效果分析/渠道再分享回流.sql
  96. python run_sql.py tasks/渠道效果分析/渠道再分享回流.sql --start 20251222 --end 20260103
  97. python run_sql.py tasks/渠道效果分析/渠道再分享回流.sql --dry-run
  98. """
  99. )
  100. parser.add_argument('sql_file', type=str, help='SQL 文件路径')
  101. parser.add_argument('--start', type=str, help='dt 分区起始日期,替换 ${start}')
  102. parser.add_argument('--end', type=str, help='dt 分区结束日期,替换 ${end}')
  103. parser.add_argument('-o', '--output', type=str, help='自定义输出路径')
  104. parser.add_argument('--vars', nargs='*', metavar='KEY=VALUE', help='额外变量,如: apptype=36')
  105. parser.add_argument('--dry-run', action='store_true', help='仅打印 SQL,不执行')
  106. args = parser.parse_args()
  107. # 解析变量
  108. variables = parse_variables(args.vars)
  109. # 默认日期
  110. start = args.start
  111. end = args.end
  112. if start is None or end is None:
  113. default_start, default_end = get_default_dates()
  114. start = start or default_start
  115. end = end or default_end
  116. print(f"使用默认日期范围: {start} ~ {end}")
  117. # 执行
  118. run_sql(
  119. sql_file=args.sql_file,
  120. output_file=args.output,
  121. variables=variables,
  122. start=start,
  123. end=end,
  124. dry_run=args.dry_run
  125. )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()