spark_sql.py

# -*- coding: utf-8 -*-
import sys

from pyspark.sql import SparkSession

try:
    # Python 2: force UTF-8 as the default string encoding
    reload(sys)
    sys.setdefaultencoding('utf8')
except NameError:
    # Python 3: not needed, reload() is no longer a builtin
    pass

if __name__ == '__main__':
    spark = SparkSession.builder\
        .appName("spark sql")\
        .config("spark.sql.broadcastTimeout", 20 * 60)\
        .config("spark.sql.crossJoin.enabled", True)\
        .getOrCreate()

    tableName = "mc_test_table"
    ptTableName = "mc_test_pt_table"
    data = [i for i in range(0, 100)]

    # Drop and recreate the test tables
    spark.sql("DROP TABLE IF EXISTS %s" % tableName)
    spark.sql("DROP TABLE IF EXISTS %s" % ptTableName)
    spark.sql("CREATE TABLE %s (name STRING, num BIGINT)" % tableName)
    spark.sql("CREATE TABLE %s (name STRING, num BIGINT) PARTITIONED BY (pt1 STRING, pt2 STRING)" % ptTableName)

    df = spark.sparkContext.parallelize(data, 2).map(lambda s: ("name-%s" % s, s)).toDF("name: string, num: int")
    pt_df = spark.sparkContext.parallelize(data, 2).map(lambda s: ("name-%s" % s, s, "2018", "0601")).toDF("name: string, num: int, pt1: string, pt2: string")

    # Write to the non-partitioned table
    df.write.insertInto(tableName)               # insertInto (append) semantics
    df.writeTo(tableName).overwritePartitions()  # insertOverwrite semantics via DataSource V2

    # Write to the partitioned table
    # DataFrameWriter cannot target a specific static partition; register a temp view and insert with SQL instead
    df.createOrReplaceTempView("%s_tmp_view" % ptTableName)
    spark.sql("insert into table %s partition (pt1='2018', pt2='0601') select * from %s_tmp_view" % (ptTableName, ptTableName))
    spark.sql("insert overwrite table %s partition (pt1='2018', pt2='0601') select * from %s_tmp_view" % (ptTableName, ptTableName))

    pt_df.write.insertInto(ptTableName)        # dynamic partitioning, insertInto semantics
    pt_df.write.insertInto(ptTableName, True)  # dynamic partitioning, insertOverwrite semantics

    # Read from the non-partitioned table
    rdf = spark.sql("select name, num from %s" % tableName)
    print("rdf count, %s\n" % rdf.count())
    rdf.printSchema()

    # Read from the partitioned table
    rptdf = spark.sql("select name, num, pt1, pt2 from %s where pt1 = '2018' and pt2 = '0601'" % ptTableName)
    print("rptdf count, %s" % (rptdf.count()))
    rptdf.printSchema()