Using flatten/unflatten
Transforming nested fields
Warning
The use case presented on this page is deprecated, but is kept to illustrate what flatten/unflatten can do. The spark_frame.nested module is much more powerful for manipulating nested data because, unlike flatten/unflatten, it also works with arrays. We recommend checking this use case to see the spark_frame.nested module in action.
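To make that limitation concrete, here is a small sketch (our own toy example, not taken from the library docs): flatten only unnests struct columns, so a column holding an array of structs should pass through unchanged, since flatten cannot reach inside arrays.
>>> from pyspark.sql import SparkSession
>>> from spark_frame.transformations import flatten
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> array_df = spark.createDataFrame(
...     [(1, [{"a": 1, "b": 2}])],
...     "id INT, s ARRAY<STRUCT<a:INT, b:INT>>"
... )
>>> flatten(array_df).printSchema()  # `s` should come back as-is: an untouched array of structs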
This example demonstrates how the spark_frame.transformations.flatten
and spark_frame.transformations.unflatten
methods can be used to
make data cleaning pipelines easier with PySpark.
Let's take a sample DataFrame using our favorite example, Pokémon:
>>> from spark_frame.examples.flatten_unflatten import _get_sample_pokemon_data
>>> df = _get_sample_pokemon_data()
>>> df.printSchema()
root
|-- base_stats: struct (nullable = true)
| |-- Attack: long (nullable = true)
| |-- Defense: long (nullable = true)
| |-- HP: long (nullable = true)
| |-- Sp Attack: long (nullable = true)
| |-- Sp Defense: long (nullable = true)
| |-- Speed: long (nullable = true)
|-- id: long (nullable = true)
|-- name: struct (nullable = true)
| |-- english: string (nullable = true)
| |-- french: string (nullable = true)
|-- types: array (nullable = true)
| |-- element: string (containsNull = true)
>>> df.show(vertical=True, truncate=False)
-RECORD 0------------------------------
base_stats | {49, 49, 45, 65, 65, 45}
id | 1
name | {Bulbasaur, Bulbizarre}
types | [Grass, Poison]
Let's say we want to enrich the "base_stats" struct with a new field named "Total".
Without spark-frame
Of course, we could write something like this in SQL:
>>> df.createOrReplaceTempView("df")
>>> df.sparkSession.sql('''
... SELECT
... STRUCT(
... base_stats.*,
... base_stats.Attack + base_stats.Defense + base_stats.HP +
... base_stats.`Sp Attack` + base_stats.`Sp Defense` + base_stats.Speed as Total
... ) as base_stats,
... id,
... name,
... types
... FROM df
... ''').show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
base_stats | {49, 49, 45, 65, 65, 45, 318}
id | 1
name | {Bulbasaur, Bulbizarre}
types | [Grass, Poison]
It works, but it is a little cumbersome. Imagine how ugly the query would look with a much bigger table, with hundreds of columns and three levels of nesting or more...
With spark-frame
Instead, we can use the spark_frame.transformations.flatten
and spark_frame.transformations.unflatten
methods to reduce
the boilerplate significantly.
>>> from spark_frame.transformations import flatten, unflatten
>>> from pyspark.sql import functions as f
>>> flat_df = flatten(df)
>>> flat_df = flat_df.withColumn("base_stats.Total",
... f.col("`base_stats.Attack`") + f.col("`base_stats.Defense`") + f.col("`base_stats.HP`") +
... f.col("`base_stats.Sp Attack`") + f.col("`base_stats.Sp Defense`") + f.col("`base_stats.Speed`")
... )
>>> new_df = unflatten(flat_df)
>>> new_df.show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
base_stats | {49, 49, 45, 65, 65, 45, 318}
id | 1
name | {Bulbasaur, Bulbizarre}
types | [Grass, Poison]
This yields the same result, and we did not have to mention the names of the columns we did not care about. This makes pipelines much easier to maintain: if a new column is added to your source table, it is propagated automatically, without any change to this enrichment code. With the first SQL solution, on the other hand, you would have had to add the new field to the query explicitly to propagate it.
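To push the maintainability point further, the enrichment can even be made fully generic, since after flatten the stats are ordinary top-level columns that can be discovered by name. A minimal sketch (the helper name add_total_base_stat is ours, not part of spark-frame):
from pyspark.sql import DataFrame
from pyspark.sql import functions as f
from spark_frame.transformations import flatten, unflatten

def add_total_base_stat(df: DataFrame) -> DataFrame:
    # Flatten, discover every base_stats field dynamically, sum them, then rebuild the structs.
    # Any stat column added upstream would be picked up and summed automatically.
    flat_df = flatten(df)
    stat_cols = [c for c in flat_df.columns if c.startswith("base_stats.")]
    total = sum((f.col(f"`{c}`") for c in stat_cols), f.lit(0))
    return unflatten(flat_df.withColumn("base_stats.Total", total))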
We can even use DataFrame.transform to inline everything!
>>> df.transform(flatten).withColumn(
... "base_stats.Total",
... f.col("`base_stats.Attack`") + f.col("`base_stats.Defense`") + f.col("`base_stats.HP`") +
... f.col("`base_stats.Sp Attack`") + f.col("`base_stats.Sp Defense`") + f.col("`base_stats.Speed`")
... ).transform(unflatten).show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
base_stats | {49, 49, 45, 65, 65, 45, 318}
id | 1
name | {Bulbasaur, Bulbizarre}
types | [Grass, Poison]
Update: since version 0.0.4, the same result can be achieved with an even simpler and more powerful transformation:
>>> from spark_frame import nested
>>> nested.print_schema(df)
root
|-- base_stats.Attack: long (nullable = true)
|-- base_stats.Defense: long (nullable = true)
|-- base_stats.HP: long (nullable = true)
|-- base_stats.Sp Attack: long (nullable = true)
|-- base_stats.Sp Defense: long (nullable = true)
|-- base_stats.Speed: long (nullable = true)
|-- id: long (nullable = true)
|-- name.english: string (nullable = true)
|-- name.french: string (nullable = true)
|-- types!: string (nullable = true)
>>> df.transform(nested.with_fields, {
... "base_stats.Total":
... f.col("base_stats.Attack") + f.col("base_stats.Defense") + f.col("base_stats.HP") +
... f.col("base_stats.`Sp Attack`") + f.col("base_stats.`Sp Defense`") + f.col("base_stats.Speed")
... }).show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
base_stats | {49, 49, 45, 65, 65, 45, 318}
id | 1
name | {Bulbasaur, Bulbizarre}
types | [Grass, Poison]
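As a side note, the `!` suffix that nested.print_schema displays on "types!" marks a repeated (array) field. The nested module can transform those elements too, which is exactly what flatten/unflatten cannot do. A sketch, assuming (as the spark-frame documentation describes) that repeated fields are transformed with higher-order lambdas:
>>> # `types` should become [GRASS, POISON] while all other columns pass through untouched
>>> df.transform(nested.with_fields, {"types!": lambda t: f.upper(t)}).show(vertical=True, truncate=False)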
Info
This example uses data taken from https://raw.githubusercontent.com/fanzeyi/pokemon.json/master/pokedex.json.
Methods used in this example
transformations.flatten
Flatten all the struct columns of a Spark DataFrame. Nested field names will be joined together using the specified separator.
Parameters:
Name | Type | Description | Default
---|---|---|---
df | DataFrame | A Spark DataFrame | required
struct_separator | str | A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots | '.'
Returns:
Type | Description
---|---
DataFrame | A flattened DataFrame
Examples:
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame(
... [(1, {"a": 1, "b": {"c": 1, "d": 1}})],
... "id INT, s STRUCT<a:INT, b:STRUCT<c:INT, d:INT>>"
... )
>>> df.printSchema()
root
|-- id: integer (nullable = true)
|-- s: struct (nullable = true)
| |-- a: integer (nullable = true)
| |-- b: struct (nullable = true)
| | |-- c: integer (nullable = true)
| | |-- d: integer (nullable = true)
>>> flatten(df).printSchema()
root
|-- id: integer (nullable = true)
|-- s.a: integer (nullable = true)
|-- s.b.c: integer (nullable = true)
|-- s.b.d: integer (nullable = true)
>>> df = spark.createDataFrame(
... [(1, {"a.a1": 1, "b.b1": {"c.c1": 1, "d.d1": 1}})],
... "id INT, `s.s1` STRUCT<`a.a1`:INT, `b.b1`:STRUCT<`c.c1`:INT, `d.d1`:INT>>"
... )
>>> df.printSchema()
root
|-- id: integer (nullable = true)
|-- s.s1: struct (nullable = true)
| |-- a.a1: integer (nullable = true)
| |-- b.b1: struct (nullable = true)
| | |-- c.c1: integer (nullable = true)
| | |-- d.d1: integer (nullable = true)
>>> flatten(df, "?").printSchema()
root
|-- id: integer (nullable = true)
|-- s.s1?a.a1: integer (nullable = true)
|-- s.s1?b.b1?c.c1: integer (nullable = true)
|-- s.s1?b.b1?d.d1: integer (nullable = true)
Source code in spark_frame/transformations_impl/flatten.py
transformations.unflatten
Reverse of the flatten operation. Nested field names will be separated from each other using the specified separator.
Parameters:
Name | Type | Description | Default
---|---|---|---
df | DataFrame | A Spark DataFrame | required
separator | str | A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots | '.'
Returns:
Type | Description
---|---
DataFrame | An unflattened DataFrame
Examples:
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame([(1, 1, 1, 1)], "id INT, `s.a` INT, `s.b.c` INT, `s.b.d` INT")
>>> df.printSchema()
root
|-- id: integer (nullable = true)
|-- s.a: integer (nullable = true)
|-- s.b.c: integer (nullable = true)
|-- s.b.d: integer (nullable = true)
>>> unflatten(df).printSchema()
root
|-- id: integer (nullable = true)
|-- s: struct (nullable = true)
| |-- a: integer (nullable = true)
| |-- b: struct (nullable = true)
| | |-- c: integer (nullable = true)
| | |-- d: integer (nullable = true)
>>> df = spark.createDataFrame([(1, 1, 1)], "id INT, `s.s1?a.a1` INT, `s.s1?b.b1` INT")
>>> df.printSchema()
root
|-- id: integer (nullable = true)
|-- s.s1?a.a1: integer (nullable = true)
|-- s.s1?b.b1: integer (nullable = true)
>>> unflatten(df, "?").printSchema()
root
|-- id: integer (nullable = true)
|-- s.s1: struct (nullable = true)
| |-- a.a1: integer (nullable = true)
| |-- b.b1: integer (nullable = true)
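Finally, since unflatten is the reverse of flatten, a round-trip makes a handy sanity check for struct-only schemas. A sketch (note: nullability flags on the rebuilt structs may not exactly match the originals, so compare printed schemas rather than asserting strict equality):
>>> round_trip = unflatten(flatten(df))  # df: any DataFrame without array columns
>>> round_trip.printSchema()  # expected to mirror df.printSchema()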