PySparkでMySQLからのデータ取得&集計方法

MySQLに対してSQLでよくやるようなデータの取得や集計などをPySparkのDataFrameだとどうやるのか調べてみましたので、備忘録として残しておきたいと思います。
検証環境は以前紹介したDockerではじめるPySparkをベースにDockerで環境を構築しいます。
こういった検証にDockerはすごく便利でいいですね

環境

  • PySpark 2.2
  • MySQL5.7

データはMySQLの公式でサンプルとして提供されているworldデータベースを利用します。

環境の構築

利用するDockerイメージ

以下の通りdocker-compose.ymlを作成します。

version: '2'
services:
 pyspark:
    image: cloudfish/pyspark-notebook
    volumes:
       - LOCAL_PATH:/home/jovyan/work
    ports:
      - "8888:8888"
    command: bash -c "start-notebook.sh --NotebookApp.token=''"
    links:
      - dbserver
    environment:
      GRANT_SUDO: "yes"
  dbserver:
    image: kakakakakku/mysql-57-world-database
    environment:
      MYSQL_ALLOW_EMPTY_PASSWORD: "yes"
  phpmyadmi:
    image: phpmyadmin/phpmyadmin
    ports: 
      - "18080:80"
    links:
      - "dbserver"
    environment:
      PMA_HOST: dbserver
      PMA_USER: root 
      PMA_PASSWORD: ""

Docker起動

docker-compose up

Jupyter Notebook画面確認
http://localhost:8888
f:id:cloudfish:20180727134630p:plain
phpmyadmin画面確認
http://localhost:18080
f:id:cloudfish:20180730141054p:plain

PySparkの実行確認

早速サンプルデータベースで実行確認を進めていきます。worldデータベースは以下のようなテーブルが含まれています。
これらのテーブルを使ってデータを取得してみたいと思います。

Tables_in_world
city
country
countrylanguage

画面右端のNewボタンをクリックしPython3を選択し、開いた画面で以下を入力していきます。
f:id:cloudfish:20180727134709p:plain

以下のコードはコードセルごとに入力してください。入力後Shift + Enterでコードが実行されます。

Sparkの初期化処理

from pyspark.sql import SQLContext, Row
from pyspark import SparkContext

sc = SparkContext("local", "First App")

※2回実行するとエラーになります。

JDBC接続処理

JDBCに接続しています。

sqlContext = SQLContext(sc)
jdbc_url="jdbc:mysql://dbserver/mysql"
driver_class="com.mysql.jdbc.Driver"

DB_USER="root"
DB_PASSWORD=""

def load_dataframe(table):
  df=sqlContext.read.format("jdbc").options(
    url =jdbc_url,
    driver=driver_class,
   dbtable=table,
   user=DB_USER,
    password=DB_PASSWORD
  ).load()
  return df

データ取得処理

データをDataFrameに取得します
テーブル指定でデータ取得する場合(countryテーブル、cityテーブルを取得)

df_country = load_dataframe("world.country")
# 実行されるSQL:SELECT * FROM world.country WHERE 1=0
df_city = load_dataframe("world.city")
# 実行されるSQL:SELECT * FROM world.city WHERE 1=0

SQLを指定してデータ取得する場合(cityテーブルの国コードがJPNのものだけを取得)

df_city_japan = load_dataframe("(select * from world.city where CountryCode='JPN') city_japan")
# 実行されるSQL:SELECT * FROM (select * from world.city where CountryCode='JPN') city_japan WHERE 1=0

MySQL側でどのようなSQLが流れるのか見てみたところ、テーブルにセットした内容がFROM句の後にセットされるようです。

カラム指定

Nameカラムを表示

df_country.select("Name").show()

条件検索

国名がJapanのデータを抽出。

df_country.filter(df_country["Name"] == "Japan").show()
df_country.where(df_country["Name"] == "Japan").select("Code","GNP").show()

isNull

独立年がNullのデータを抽出

df_country.where(df_country["IndepYear"].isNull()).show()

like

国名がJで始まるデータを抽出

df_country.where(df_country["Name"].like("J%")).show()

Case When式

人口が100000人より大きい場合は「Big」、小さい場合は「Small」を表示

from pyspark.sql import functions as F
df_country.select(df_country["Name"], F.when(df_country["Population"] > 100000,"Big").otherwise("Small").alias("CountryDiv")).show()

substr

国名を3文字切り出して表示

df_country.select(df_country["Name"].substr(1,3)).show()

limit

5件のみ抽出

df_country.limit(5).show()

Join

countryテーブルとcityテーブルを結合し国名がJapanのものを抽出

df_join = df_country.alias('country').join(df_city.alias('city'),(df_city["countryCode"] == df_country["Code"]) & (df_country["Name"]=="Japan")).show()

OrderBy

GNPが高い国順に表示

df_country.orderBy("GNP" , ascending=False).select("Code","Name","GNP").show()

GroupBy

国ごとにグループ化し、cityの数、人口の平均と合計を集計

from pyspark.sql import functions as F

df_city.groupBy("countryCode") \
 .agg( \
    F.count(df_city["Name"]).alias("total_count"), \
    F.avg(df_city["Population"]).alias("avg_population"), \
    F.sum(df_city["Population"]).alias("sum_population") \
).show()

まとめ

書き方は少し慣れる必要がありますが、かなりSQLに近いイメージでデータ取得が可能なことが分かりました。
今回はDBに対して実行しましたが、ファイルに対しても同様に実行可能です。