Using flatten/unflatten

Transforming nested fields

Warning

The use case presented on this page is deprecated, but it is kept to illustrate what flatten/unflatten can do. The spark_frame.nested module is much more powerful for manipulating nested data because, unlike flatten/unflatten, it also works with arrays. We recommend checking this use case to see the spark_frame.nested module in action.

This example demonstrates how the spark_frame.transformations.flatten and spark_frame.transformations.unflatten methods can be used to make data cleaning pipelines easier with PySpark.

Let's take a sample DataFrame with our favorite example: Pokémon

>>> from spark_frame.examples.flatten_unflatten import _get_sample_pokemon_data
>>> df = _get_sample_pokemon_data()
>>> df.printSchema()
root
 |-- base_stats: struct (nullable = true)
 |    |-- Attack: long (nullable = true)
 |    |-- Defense: long (nullable = true)
 |    |-- HP: long (nullable = true)
 |    |-- Sp Attack: long (nullable = true)
 |    |-- Sp Defense: long (nullable = true)
 |    |-- Speed: long (nullable = true)
 |-- id: long (nullable = true)
 |-- name: struct (nullable = true)
 |    |-- english: string (nullable = true)
 |    |-- french: string (nullable = true)
 |-- types: array (nullable = true)
 |    |-- element: string (containsNull = true)

>>> df.show(vertical=True, truncate=False)
-RECORD 0------------------------------
 base_stats | {49, 49, 45, 65, 65, 45}
 id         | 1
 name       | {Bulbasaur, Bulbizarre}
 types      | [Grass, Poison]

Let's say we want to enrich the "base_stats" struct with a new field named "Total".

Without spark-frame

Of course, we could write something like this with the DataFrame API or in SQL:

>>> df.createOrReplaceTempView("df")
>>> df.sparkSession.sql('''
... SELECT
...   STRUCT(
...     base_stats.*,
...     base_stats.Attack + base_stats.Defense + base_stats.HP +
...     base_stats.`Sp Attack` + base_stats.`Sp Defense` + base_stats.Speed as Total
...   ) as base_stats,
...   id,
...   name,
...   types
... FROM df
... ''').show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
 base_stats | {49, 49, 45, 65, 65, 45, 318}
 id         | 1
 name       | {Bulbasaur, Bulbizarre}
 types      | [Grass, Poison]

It works, but it is a little cumbersome. Imagine how ugly the query would look with a much bigger table containing hundreds of columns and three levels of nesting or more...
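
As a side note, since Spark 3.1 the DataFrame API also offers Column.withField, which at least spares us from re-listing all the columns and fields we do not touch. A minimal sketch (not part of the original example; every field used in the sum must still be spelled out with its full path):

>>> from pyspark.sql import functions as f
>>> new_df = df.withColumn(
...     "base_stats",
...     f.col("base_stats").withField(
...         "Total",
...         f.col("base_stats.Attack") + f.col("base_stats.Defense") + f.col("base_stats.HP") +
...         f.col("base_stats.`Sp Attack`") + f.col("base_stats.`Sp Defense`") + f.col("base_stats.Speed")
...     )
... )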

With spark-frame

Instead, we can use the spark_frame.transformations.flatten and spark_frame.transformations.unflatten methods to reduce the boilerplate significantly.

>>> from spark_frame.transformations import flatten, unflatten
>>> from pyspark.sql import functions as f
>>> flat_df = flatten(df)
>>> flat_df = flat_df.withColumn("base_stats.Total",
...     f.col("`base_stats.Attack`") + f.col("`base_stats.Defense`") + f.col("`base_stats.HP`") +
...     f.col("`base_stats.Sp Attack`") + f.col("`base_stats.Sp Defense`") + f.col("`base_stats.Speed`")
... )
>>> new_df = unflatten(flat_df)
>>> new_df.show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
 base_stats | {49, 49, 45, 65, 65, 45, 318}
 id         | 1
 name       | {Bulbasaur, Bulbizarre}
 types      | [Grass, Poison]
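
If you are wondering about all those backticks: after flatten, the intermediate DataFrame has flat top-level columns whose names contain dots (and even spaces here), so they must be quoted when referenced. This is what the flattened schema looks like; note that the types column is left untouched, since flatten does not go inside arrays:

>>> flatten(df).printSchema()
root
 |-- base_stats.Attack: long (nullable = true)
 |-- base_stats.Defense: long (nullable = true)
 |-- base_stats.HP: long (nullable = true)
 |-- base_stats.Sp Attack: long (nullable = true)
 |-- base_stats.Sp Defense: long (nullable = true)
 |-- base_stats.Speed: long (nullable = true)
 |-- id: long (nullable = true)
 |-- name.english: string (nullable = true)
 |-- name.french: string (nullable = true)
 |-- types: array (nullable = true)
 |    |-- element: string (containsNull = true)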

This yields the same result, and we did not have to mention the names of the columns we did not care about, which makes pipelines much easier to maintain. If a new column is added to your source table, it will be propagated automatically without any change to this enrichment code. With the first SQL solution, on the other hand, you would have had to explicitly add the new field to the query to propagate it.

We can even use DataFrame.transform to inline everything!

>>> df.transform(flatten).withColumn(
...     "base_stats.Total",
...     f.col("`base_stats.Attack`") + f.col("`base_stats.Defense`") + f.col("`base_stats.HP`") +
...     f.col("`base_stats.Sp Attack`") + f.col("`base_stats.Sp Defense`") + f.col("`base_stats.Speed`")
...   ).transform(unflatten).show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
 base_stats | {49, 49, 45, 65, 65, 45, 318}
 id         | 1
 name       | {Bulbasaur, Bulbizarre}
 types      | [Grass, Poison]

Update: since version 0.0.4, the same result can be achieved with an even simpler and more powerful transformation:

>>> from spark_frame import nested
>>> nested.print_schema(df)
root
 |-- base_stats.Attack: long (nullable = true)
 |-- base_stats.Defense: long (nullable = true)
 |-- base_stats.HP: long (nullable = true)
 |-- base_stats.Sp Attack: long (nullable = true)
 |-- base_stats.Sp Defense: long (nullable = true)
 |-- base_stats.Speed: long (nullable = true)
 |-- id: long (nullable = true)
 |-- name.english: string (nullable = true)
 |-- name.french: string (nullable = true)
 |-- types!: string (nullable = true)
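
(Note the types! entry: spark_frame.nested uses ! as a repetition marker for array elements, which is what allows it to address fields inside arrays.)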

>>> df.transform(nested.with_fields, {
...     "base_stats.Total":
...     f.col("base_stats.Attack") + f.col("base_stats.Defense") + f.col("base_stats.HP") +
...     f.col("base_stats.`Sp Attack`") + f.col("base_stats.`Sp Defense`") + f.col("base_stats.Speed")
... }).show(vertical=True, truncate=False)
-RECORD 0-----------------------------------
 base_stats | {49, 49, 45, 65, 65, 45, 318}
 id         | 1
 name       | {Bulbasaur, Bulbizarre}
 types      | [Grass, Poison]
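
Since nested.with_fields understands this repetition marker, it can also transform fields inside arrays, which flatten/unflatten cannot do. A hedged sketch, assuming (as described in the spark_frame.nested documentation) that repeated fields accept a Column -> Column function:

>>> enriched_df = df.transform(nested.with_fields, {
...     "types!": lambda t: f.upper(t)  # applied to each element of the "types" array
... })  # "types" would become [GRASS, POISON]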

Methods used in this example

transformations.flatten

Flatten all the struct columns of a Spark DataFrame. Nested field names will be joined together using the specified separator.

Parameters:

    df (DataFrame, required):
        A Spark DataFrame

    struct_separator (str, default '.'):
        A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots.

Returns:

    DataFrame:
        A flattened DataFrame

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame(
...         [(1, {"a": 1, "b": {"c": 1, "d": 1}})],
...         "id INT, s STRUCT<a:INT, b:STRUCT<c:INT, d:INT>>"
...      )
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s: struct (nullable = true)
 |    |-- a: integer (nullable = true)
 |    |-- b: struct (nullable = true)
 |    |    |-- c: integer (nullable = true)
 |    |    |-- d: integer (nullable = true)

>>> flatten(df).printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.a: integer (nullable = true)
 |-- s.b.c: integer (nullable = true)
 |-- s.b.d: integer (nullable = true)

>>> df = spark.createDataFrame(
...         [(1, {"a.a1": 1, "b.b1": {"c.c1": 1, "d.d1": 1}})],
...         "id INT, `s.s1` STRUCT<`a.a1`:INT, `b.b1`:STRUCT<`c.c1`:INT, `d.d1`:INT>>"
... )
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1: struct (nullable = true)
 |    |-- a.a1: integer (nullable = true)
 |    |-- b.b1: struct (nullable = true)
 |    |    |-- c.c1: integer (nullable = true)
 |    |    |-- d.d1: integer (nullable = true)

>>> flatten(df, "?").printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1?a.a1: integer (nullable = true)
 |-- s.s1?b.b1?c.c1: integer (nullable = true)
 |-- s.s1?b.b1?d.d1: integer (nullable = true)
Source code in spark_frame/transformations_impl/flatten.py
def flatten(df: DataFrame, struct_separator: str = ".") -> DataFrame:
    """Flatten all the struct columns of a Spark [DataFrame][pyspark.sql.DataFrame].
    Nested fields names will be joined together using the specified separator

    Args:
        df: A Spark DataFrame
        struct_separator: A string used to separate the structs names from their elements.
            It might be useful to change the separator when some DataFrame's column names already contain dots

    Returns:
        A flattened DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame(
        ...         [(1, {"a": 1, "b": {"c": 1, "d": 1}})],
        ...         "id INT, s STRUCT<a:INT, b:STRUCT<c:INT, d:INT>>"
        ...      )
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s: struct (nullable = true)
         |    |-- a: integer (nullable = true)
         |    |-- b: struct (nullable = true)
         |    |    |-- c: integer (nullable = true)
         |    |    |-- d: integer (nullable = true)
        <BLANKLINE>
        >>> flatten(df).printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.a: integer (nullable = true)
         |-- s.b.c: integer (nullable = true)
         |-- s.b.d: integer (nullable = true)
        <BLANKLINE>
        >>> df = spark.createDataFrame(
        ...         [(1, {"a.a1": 1, "b.b1": {"c.c1": 1, "d.d1": 1}})],
        ...         "id INT, `s.s1` STRUCT<`a.a1`:INT, `b.b1`:STRUCT<`c.c1`:INT, `d.d1`:INT>>"
        ... )
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1: struct (nullable = true)
         |    |-- a.a1: integer (nullable = true)
         |    |-- b.b1: struct (nullable = true)
         |    |    |-- c.c1: integer (nullable = true)
         |    |    |-- d.d1: integer (nullable = true)
        <BLANKLINE>
        >>> flatten(df, "?").printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1?a.a1: integer (nullable = true)
         |-- s.s1?b.b1?c.c1: integer (nullable = true)
         |-- s.s1?b.b1?d.d1: integer (nullable = true)
        <BLANKLINE>

    """
    # The idea is to recursively write a "SELECT s.b.c as `s.b.c`" for each nested column.
    cols = []

    def expand_struct(struct: StructType, col_stack: List[str]) -> None:
        for field in struct:
            if isinstance(field.dataType, StructType):
                struct_field = field.dataType
                expand_struct(struct_field, [*col_stack, field.name])
            else:
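                # quote_columns wraps each name in backticks, so names containing dots or spaces are referenced safely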
                column = f.col(".".join(quote_columns([*col_stack, field.name])))
                cols.append(column.alias(struct_separator.join([*col_stack, field.name])))

    expand_struct(df.schema, col_stack=[])
    return df.select(cols)
transformations.unflatten

Reverse of the flatten operation. Nested field names will be separated from each other using the specified separator.

Parameters:

    df (DataFrame, required):
        A Spark DataFrame

    separator (str, default '.'):
        A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots.

Returns:

    DataFrame:
        An unflattened DataFrame

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame([(1, 1, 1, 1)], "id INT, `s.a` INT, `s.b.c` INT, `s.b.d` INT")
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.a: integer (nullable = true)
 |-- s.b.c: integer (nullable = true)
 |-- s.b.d: integer (nullable = true)

>>> unflatten(df).printSchema()
root
 |-- id: integer (nullable = true)
 |-- s: struct (nullable = true)
 |    |-- a: integer (nullable = true)
 |    |-- b: struct (nullable = true)
 |    |    |-- c: integer (nullable = true)
 |    |    |-- d: integer (nullable = true)

>>> df = spark.createDataFrame([(1, 1, 1)], "id INT, `s.s1?a.a1` INT, `s.s1?b.b1` INT")
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1?a.a1: integer (nullable = true)
 |-- s.s1?b.b1: integer (nullable = true)

>>> unflatten(df, "?").printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1: struct (nullable = true)
 |    |-- a.a1: integer (nullable = true)
 |    |-- b.b1: integer (nullable = true)
Source code in spark_frame/transformations_impl/unflatten.py
def unflatten(df: DataFrame, separator: str = ".") -> DataFrame:
    """Reverse of the flatten operation
    Nested fields names will be separated from each other using the specified separator

    Args:
        df: A Spark DataFrame
        separator: A string used to separate the structs names from their elements.
                   It might be useful to change the separator when some DataFrame's column names already contain dots

    Returns:
        An unflattened DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame([(1, 1, 1, 1)], "id INT, `s.a` INT, `s.b.c` INT, `s.b.d` INT")
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.a: integer (nullable = true)
         |-- s.b.c: integer (nullable = true)
         |-- s.b.d: integer (nullable = true)
        <BLANKLINE>
        >>> unflatten(df).printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s: struct (nullable = true)
         |    |-- a: integer (nullable = true)
         |    |-- b: struct (nullable = true)
         |    |    |-- c: integer (nullable = true)
         |    |    |-- d: integer (nullable = true)
        <BLANKLINE>
        >>> df = spark.createDataFrame([(1, 1, 1)], "id INT, `s.s1?a.a1` INT, `s.s1?b.b1` INT")
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1?a.a1: integer (nullable = true)
         |-- s.s1?b.b1: integer (nullable = true)
        <BLANKLINE>
        >>> unflatten(df, "?").printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1: struct (nullable = true)
         |    |-- a.a1: integer (nullable = true)
         |    |-- b.b1: integer (nullable = true)
        <BLANKLINE>
    """
    # The idea is to recursively write a "SELECT struct(a, struct(s.b.c, s.b.d)) as s" for each nested column.
    # There is a little twist as we don't want to rebuild the struct if all its fields are NULL, so we add a CASE WHEN

    def has_structs(df: DataFrame) -> bool:
        struct_fields = [field for field in df.schema if is_struct(field)]
        return len(struct_fields) > 0

    if has_structs(df):
        df = flatten(df)

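    # _build_nested_struct_tree groups the flat column names into a tree by splitting them on the separator,
    # and _build_struct_from_tree turns that tree back into nested struct() column expressions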
    tree = _build_nested_struct_tree(df.columns, separator)
    cols = _build_struct_from_tree(tree, separator)
    return df.select(cols)