
spark_frame.transformations

Unlike those in spark_frame.functions, the methods in this module all take at least one DataFrame as argument and return a new transformed DataFrame. These methods generally offer higher-order transformations that require inspecting the schema or even the content of the input DataFrame(s) before generating the next transformation. They are typically generic operations that cannot be implemented with a single SQL query.

Tip

Since Spark 3.3.0, all transformations can be inlined using DataFrame.transform, like this:

from pyspark.sql import functions as f
from spark_frame.transformations import flatten, unflatten

df.transform(flatten).withColumn(
    "base_stats.Total",
    f.col("`base_stats.Attack`") + f.col("`base_stats.Defense`") + f.col("`base_stats.HP`") +
    f.col("`base_stats.Sp Attack`") + f.col("`base_stats.Sp Defense`") + f.col("`base_stats.Speed`")
).transform(unflatten).show(vertical=True, truncate=False)

analyze(df: DataFrame, group_by: Optional[Union[str, List[str]]] = None, group_alias: str = 'group', _aggs: Optional[List[Callable[[str, StructField, int], Column]]] = None) -> DataFrame

Analyze a DataFrame by computing various stats for each column.

By default, it returns a DataFrame with one row per column and the following columns (but the columns computed can be customized, see the Customization section below):

  • column_number: Number of the column (useful for sorting)
  • column_name: Name of the column
  • column_type: Type of the column
  • count: Number of rows in the column. It is equal to the number of rows in the table, except for columns nested inside arrays, for which it may be different
  • count_distinct: Number of distinct values
  • count_null: Number of null values
  • min: smallest value
  • max: largest value
Implementation details
  • Structs are flattened with a . after their name.
  • Arrays are unnested with a ! character after their name, which is why they may have a different count.
  • Null values are not counted in the count_distinct column.

Limitation: Map type is not supported

This method currently does not work on columns of type Map. A possible workaround is to use spark_frame.transformations.convert_all_maps_to_arrays before using it.
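
As a quick illustration, here is a minimal, untested sketch of that workaround. It assumes an existing SparkSession named spark and that both functions are exposed in spark_frame.transformations, as documented on this page; the map column is made up:

from spark_frame.transformations import analyze, convert_all_maps_to_arrays

# A hypothetical DataFrame containing a MAP column, which analyze() cannot handle directly
df = spark.sql('SELECT 1 as id, MAP("a", 1, "b", 2) as m')

# Convert every map into an ARRAY<STRUCT<key, value>> first, then analyze the result
analyzed_df = analyze(convert_all_maps_to_arrays(df))
analyzed_df.show(truncate=False)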

Grouping

With the group_by option, users can specify one or multiple columns by which the statistics will be grouped. If this option is used, an extra column "group" of type struct will be added to the output DataFrame. See the examples below.

Limitation: group_by only works on non-repeated fields

Currently, the group_by option only works with non-repeated fields. Using it on repeated fields will lead to an unspecified error.

Customization

By default, this method will compute for each column the aggregations listed in spark_frame.transformations_impl.analyze.default_aggs, but users can change this and even add their own custom aggregations by passing the argument _aggs, a list of aggregation functions with the following signature: (col: str, schema_field: StructField, col_num: int) -> Column

Examples of aggregation methods can be found in the module spark_frame.transformations_impl.analyze_aggs.
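
For illustration, here is a minimal sketch of what such a custom aggregation could look like (the name approx_distinct and the use of approx_count_distinct are purely illustrative, not part of the library):

from pyspark.sql import Column, functions as f
from pyspark.sql.types import StructField

def approx_distinct(col: str, schema_field: StructField, col_num: int) -> Column:
    # Illustrative custom aggregation: approximate number of distinct values in the column
    return f.approx_count_distinct(col).alias("approx_count_distinct")

Such a function can then be appended to the list passed via the _aggs argument, as shown in the examples below.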

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • group_by (Optional[Union[str, List[str]]]): A list of column names on which the aggregations will be grouped. Default: None
  • group_alias (str): The alias to use for the struct column that will contain the group_by columns, if any. Default: 'group'
  • _aggs (Optional[List[Callable[[str, StructField, int], Column]]]): A list of aggregations to override the default aggregations made by the function. Default: None

Returns:

  • DataFrame: A new DataFrame containing descriptive statistics about the input DataFrame.

Examples:

>>> from spark_frame.transformations_impl.analyze import __get_test_df
>>> df = __get_test_df()
>>> df.show()
+---+----------+---------------+------------+
| id|      name|          types|   evolution|
+---+----------+---------------+------------+
|  1| Bulbasaur|[Grass, Poison]|{true, NULL}|
|  2|   Ivysaur|[Grass, Poison]|   {true, 1}|
|  3|  Venusaur|[Grass, Poison]|  {false, 2}|
|  4|Charmander|         [Fire]|{true, NULL}|
|  5|Charmeleon|         [Fire]|   {true, 4}|
|  6| Charizard| [Fire, Flying]|  {false, 5}|
|  7|  Squirtle|        [Water]|{true, NULL}|
|  8| Wartortle|        [Water]|   {true, 7}|
|  9| Blastoise|        [Water]|  {false, 8}|
+---+----------+---------------+------------+
>>> analyzed_df = analyze(df)
Analyzing 5 columns ...
>>> analyzed_df.show(truncate=False)  # noqa: E501
+-------------+----------------------+-----------+-----+--------------+----------+---------+---------+
|column_number|column_name           |column_type|count|count_distinct|count_null|min      |max      |
+-------------+----------------------+-----------+-----+--------------+----------+---------+---------+
|0            |id                    |INTEGER    |9    |9             |0         |1        |9        |
|1            |name                  |STRING     |9    |9             |0         |Blastoise|Wartortle|
|2            |types!                |STRING     |13   |5             |0         |Fire     |Water    |
|3            |evolution.can_evolve  |BOOLEAN    |9    |2             |0         |false    |true     |
|4            |evolution.evolves_from|INTEGER    |9    |6             |3         |1        |8        |
+-------------+----------------------+-----------+-----+--------------+----------+---------+---------+

Analyze a DataFrame with custom aggregation methods

Custom aggregation methods can be defined as functions that take three arguments:

  • col: the name of the column that will be analyzed
  • struct_field: the StructField describing the analyzed column, as found in the DataFrame's schema
  • col_num: the number of the column

>>> from spark_frame.transformations_impl import analyze_aggs
>>> from pyspark.sql.types import IntegerType
>>> def total(col: str, struct_field: StructField, _: int) -> Column:
...     if struct_field.dataType == IntegerType():
...         return f.sum(col).alias("total")
...     else:
...         return f.lit(None).alias("total")
>>> aggs = [
...     analyze_aggs.column_number,
...     analyze_aggs.column_name,
...     analyze_aggs.count,
...     analyze_aggs.count_distinct,
...     analyze_aggs.count_null,
...     total
... ]
>>> analyzed_df = analyze(df, _aggs=aggs)
Analyzing 5 columns ...
>>> analyzed_df.show(truncate=False)  # noqa: E501
+-------------+----------------------+-----+--------------+----------+-----+
|column_number|column_name           |count|count_distinct|count_null|total|
+-------------+----------------------+-----+--------------+----------+-----+
|0            |id                    |9    |9             |0         |45   |
|1            |name                  |9    |9             |0         |NULL |
|2            |types!                |13   |5             |0         |NULL |
|3            |evolution.can_evolve  |9    |2             |0         |NULL |
|4            |evolution.evolves_from|9    |6             |3         |27   |
+-------------+----------------------+-----+--------------+----------+-----+

Analyze a DataFrame grouped by a specific column

Use the group_by option to group the result by one or multiple columns

>>> df = __get_test_df().withColumn("main_type", f.expr("types[0]"))
>>> df.show()
+---+----------+---------------+------------+---------+
| id|      name|          types|   evolution|main_type|
+---+----------+---------------+------------+---------+
|  1| Bulbasaur|[Grass, Poison]|{true, NULL}|    Grass|
|  2|   Ivysaur|[Grass, Poison]|   {true, 1}|    Grass|
|  3|  Venusaur|[Grass, Poison]|  {false, 2}|    Grass|
|  4|Charmander|         [Fire]|{true, NULL}|     Fire|
|  5|Charmeleon|         [Fire]|   {true, 4}|     Fire|
|  6| Charizard| [Fire, Flying]|  {false, 5}|     Fire|
|  7|  Squirtle|        [Water]|{true, NULL}|    Water|
|  8| Wartortle|        [Water]|   {true, 7}|    Water|
|  9| Blastoise|        [Water]|  {false, 8}|    Water|
+---+----------+---------------+------------+---------+

>>> analyzed_df = analyze(df, group_by="main_type", _aggs=aggs)
Analyzing 5 columns ...
>>> analyzed_df.orderBy("`group`.main_type", "column_number").show(truncate=False)
+-------+-------------+----------------------+-----+--------------+----------+-----+
|group  |column_number|column_name           |count|count_distinct|count_null|total|
+-------+-------------+----------------------+-----+--------------+----------+-----+
|{Fire} |0            |id                    |3    |3             |0         |15   |
|{Fire} |1            |name                  |3    |3             |0         |NULL |
|{Fire} |2            |types!                |4    |2             |0         |NULL |
|{Fire} |3            |evolution.can_evolve  |3    |2             |0         |NULL |
|{Fire} |4            |evolution.evolves_from|3    |2             |1         |9    |
|{Grass}|0            |id                    |3    |3             |0         |6    |
|{Grass}|1            |name                  |3    |3             |0         |NULL |
|{Grass}|2            |types!                |6    |2             |0         |NULL |
|{Grass}|3            |evolution.can_evolve  |3    |2             |0         |NULL |
|{Grass}|4            |evolution.evolves_from|3    |2             |1         |3    |
|{Water}|0            |id                    |3    |3             |0         |24   |
|{Water}|1            |name                  |3    |3             |0         |NULL |
|{Water}|2            |types!                |3    |1             |0         |NULL |
|{Water}|3            |evolution.can_evolve  |3    |2             |0         |NULL |
|{Water}|4            |evolution.evolves_from|3    |2             |1         |15   |
+-------+-------------+----------------------+-----+--------------+----------+-----+
Source code in spark_frame/transformations_impl/analyze.py
def analyze(
    df: DataFrame,
    group_by: Optional[Union[str, List[str]]] = None,
    group_alias: str = "group",
    _aggs: Optional[List[Callable[[str, StructField, int], Column]]] = None,
) -> DataFrame:
    """Analyze a DataFrame by computing various stats for each column.

    By default, it returns a DataFrame with one row per column and the following columns
    (but the columns computed can be customized, see the Customization section below):

    - `column_number`: Number of the column (useful for sorting)
    - `column_name`: Name of the column
    - `column_type`: Type of the column
    - `count`: Number of rows in the column. It is equal to the number of rows in the table, except for columns
      nested inside arrays, for which it may be different
    - `count_distinct`: Number of distinct values
    - `count_null`: Number of null values
    - `min`: smallest value
    - `max`: largest value

    Implementation details
    ----------------------
    - Structs are flattened with a `.` after their name.
    - Arrays are unnested with a `!` character after their name, which is why they may have a different count.
    - Null values are not counted in the count_distinct column.

    !!! warning "Limitation: Map type is not supported"
        This method currently does not work on columns of type Map.
        A possible workaround is to use [`spark_frame.transformations.convert_all_maps_to_arrays`]
        [spark_frame.transformations_impl.convert_all_maps_to_arrays.convert_all_maps_to_arrays]
        before using it.

    Grouping
    --------
    With the `group_by` option, users can specify one or multiple columns for which the statistics will be grouped.
    If this option is used, an extra column "group" of type struct will be added to output DataFrame.
    See the examples below.

    !!! warning "Limitation: group_by only works on non-repeated fields"
        Currently, the `group_by` option only works with non-repeated fields.
        Using it on repeated fields will lead to an unspecified error.

    Customization
    -------------
    By default, this method will compute for each column the aggregations listed in
    `spark_frame.transformations_impl.analyze.default_aggs`, but users can change this and even add their
    own custom aggregations by passing the argument `_aggs`, a list of aggregation functions with the following
    signature: `(col: str, schema_field: StructField, col_num: int) -> Column`

    Examples of aggregation methods can be found in the module `spark_frame.transformations_impl.analyze_aggs`

    Args:
        df: A Spark DataFrame
        group_by: A list of column names on which the aggregations will be grouped
        group_alias: The alias to use for the struct column that will contain the `group_by` columns, if any.
        _aggs: A list of aggregations to override the default aggregations made by the function

    Returns:
        A new DataFrame containing descriptive statistics about the input DataFrame

    Examples:
        >>> from spark_frame.transformations_impl.analyze import __get_test_df

        >>> df = __get_test_df()
        >>> df.show()
        +---+----------+---------------+------------+
        | id|      name|          types|   evolution|
        +---+----------+---------------+------------+
        |  1| Bulbasaur|[Grass, Poison]|{true, NULL}|
        |  2|   Ivysaur|[Grass, Poison]|   {true, 1}|
        |  3|  Venusaur|[Grass, Poison]|  {false, 2}|
        |  4|Charmander|         [Fire]|{true, NULL}|
        |  5|Charmeleon|         [Fire]|   {true, 4}|
        |  6| Charizard| [Fire, Flying]|  {false, 5}|
        |  7|  Squirtle|        [Water]|{true, NULL}|
        |  8| Wartortle|        [Water]|   {true, 7}|
        |  9| Blastoise|        [Water]|  {false, 8}|
        +---+----------+---------------+------------+
        <BLANKLINE>

        >>> analyzed_df = analyze(df)
        Analyzing 5 columns ...
        >>> analyzed_df.show(truncate=False)  # noqa: E501
        +-------------+----------------------+-----------+-----+--------------+----------+---------+---------+
        |column_number|column_name           |column_type|count|count_distinct|count_null|min      |max      |
        +-------------+----------------------+-----------+-----+--------------+----------+---------+---------+
        |0            |id                    |INTEGER    |9    |9             |0         |1        |9        |
        |1            |name                  |STRING     |9    |9             |0         |Blastoise|Wartortle|
        |2            |types!                |STRING     |13   |5             |0         |Fire     |Water    |
        |3            |evolution.can_evolve  |BOOLEAN    |9    |2             |0         |false    |true     |
        |4            |evolution.evolves_from|INTEGER    |9    |6             |3         |1        |8        |
        +-------------+----------------------+-----------+-----+--------------+----------+---------+---------+
        <BLANKLINE>

    Examples: Analyze a DataFrame with custom aggregation methods
        Custom aggregation methods can be defined as functions that take three arguments:
            - `col`: the name of the column that will be analyzed
            - `struct_field`: the StructField describing the analyzed column, as found in the DataFrame's schema
            - `col_num`: the number of the column
        >>> from spark_frame.transformations_impl import analyze_aggs
        >>> from pyspark.sql.types import IntegerType
        >>> def total(col: str, struct_field: StructField, _: int) -> Column:
        ...     if struct_field.dataType == IntegerType():
        ...         return f.sum(col).alias("total")
        ...     else:
        ...         return f.lit(None).alias("total")
        >>> aggs = [
        ...     analyze_aggs.column_number,
        ...     analyze_aggs.column_name,
        ...     analyze_aggs.count,
        ...     analyze_aggs.count_distinct,
        ...     analyze_aggs.count_null,
        ...     total
        ... ]
        >>> analyzed_df = analyze(df, _aggs=aggs)
        Analyzing 5 columns ...
        >>> analyzed_df.show(truncate=False)  # noqa: E501
        +-------------+----------------------+-----+--------------+----------+-----+
        |column_number|column_name           |count|count_distinct|count_null|total|
        +-------------+----------------------+-----+--------------+----------+-----+
        |0            |id                    |9    |9             |0         |45   |
        |1            |name                  |9    |9             |0         |NULL |
        |2            |types!                |13   |5             |0         |NULL |
        |3            |evolution.can_evolve  |9    |2             |0         |NULL |
        |4            |evolution.evolves_from|9    |6             |3         |27   |
        +-------------+----------------------+-----+--------------+----------+-----+
        <BLANKLINE>

    Examples: Analyze a DataFrame grouped by a specific column
        Use the `group_by` option to group the result by one or multiple columns
        >>> df = __get_test_df().withColumn("main_type", f.expr("types[0]"))
        >>> df.show()
        +---+----------+---------------+------------+---------+
        | id|      name|          types|   evolution|main_type|
        +---+----------+---------------+------------+---------+
        |  1| Bulbasaur|[Grass, Poison]|{true, NULL}|    Grass|
        |  2|   Ivysaur|[Grass, Poison]|   {true, 1}|    Grass|
        |  3|  Venusaur|[Grass, Poison]|  {false, 2}|    Grass|
        |  4|Charmander|         [Fire]|{true, NULL}|     Fire|
        |  5|Charmeleon|         [Fire]|   {true, 4}|     Fire|
        |  6| Charizard| [Fire, Flying]|  {false, 5}|     Fire|
        |  7|  Squirtle|        [Water]|{true, NULL}|    Water|
        |  8| Wartortle|        [Water]|   {true, 7}|    Water|
        |  9| Blastoise|        [Water]|  {false, 8}|    Water|
        +---+----------+---------------+------------+---------+
        <BLANKLINE>
        >>> analyzed_df = analyze(df, group_by="main_type", _aggs=aggs)
        Analyzing 5 columns ...
        >>> analyzed_df.orderBy("`group`.main_type", "column_number").show(truncate=False)
        +-------+-------------+----------------------+-----+--------------+----------+-----+
        |group  |column_number|column_name           |count|count_distinct|count_null|total|
        +-------+-------------+----------------------+-----+--------------+----------+-----+
        |{Fire} |0            |id                    |3    |3             |0         |15   |
        |{Fire} |1            |name                  |3    |3             |0         |NULL |
        |{Fire} |2            |types!                |4    |2             |0         |NULL |
        |{Fire} |3            |evolution.can_evolve  |3    |2             |0         |NULL |
        |{Fire} |4            |evolution.evolves_from|3    |2             |1         |9    |
        |{Grass}|0            |id                    |3    |3             |0         |6    |
        |{Grass}|1            |name                  |3    |3             |0         |NULL |
        |{Grass}|2            |types!                |6    |2             |0         |NULL |
        |{Grass}|3            |evolution.can_evolve  |3    |2             |0         |NULL |
        |{Grass}|4            |evolution.evolves_from|3    |2             |1         |3    |
        |{Water}|0            |id                    |3    |3             |0         |24   |
        |{Water}|1            |name                  |3    |3             |0         |NULL |
        |{Water}|2            |types!                |3    |1             |0         |NULL |
        |{Water}|3            |evolution.can_evolve  |3    |2             |0         |NULL |
        |{Water}|4            |evolution.evolves_from|3    |2             |1         |15   |
        +-------+-------------+----------------------+-----+--------------+----------+-----+
        <BLANKLINE>
    """
    if _aggs is None:
        _aggs = default_aggs
    if group_by is None:
        group_by = []
    if isinstance(group_by, str):
        group_by = [group_by]

    flat_fields = nested.fields(df)
    fields_to_drop = [field for field in flat_fields if is_sub_field_or_equal_to_any(field, group_by)]
    nb_cols = len(flat_fields) - len(fields_to_drop)
    print(f"Analyzing {nb_cols} columns ...")

    if len(group_by) > 0:
        df = df.withColumn(group_alias, f.struct(*group_by))
        group = [group_alias]
    else:
        group = []

    flattened_dfs = unnest_all_fields(df, keep_columns=group)
    index_by_field = {field: index for index, field in enumerate(flat_fields)}
    analyzed_dfs = [
        _analyze_flat_df(flat_df.drop(*fields_to_drop), index_by_field, group_by=group, aggs=_aggs)
        for flat_df in flattened_dfs.values()
    ]

    union_df = union_dataframes(*analyzed_dfs)
    return union_df.orderBy("column_number")

convert_all_maps_to_arrays(df: DataFrame) -> DataFrame

Transform all columns of type Map<K,V> inside the given DataFrame into ARRAY<STRUCT<key: K, value: V>>. This transformation works recursively on every nesting level.

Info

This method is compatible with any schema. It recursively applies on structs, arrays and maps and is compatible with field names containing special characters.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.

Returns:

  • DataFrame: A new DataFrame in which all maps have been replaced with arrays of entries.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('SELECT 1 as id, ARRAY(MAP(1, STRUCT(MAP(1, "a") as m2))) as m1')
>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- m1: array (nullable = false)
 |    |-- element: map (containsNull = false)
 |    |    |-- key: integer
 |    |    |-- value: struct (valueContainsNull = false)
 |    |    |    |-- m2: map (nullable = false)
 |    |    |    |    |-- key: integer
 |    |    |    |    |-- value: string (valueContainsNull = false)

>>> df.show()
+---+-------------------+
| id|                 m1|
+---+-------------------+
|  1|[{1 -> {{1 -> a}}}]|
+---+-------------------+

>>> res_df = convert_all_maps_to_arrays(df)
>>> res_df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- m1: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: struct (containsNull = false)
 |    |    |    |-- key: integer (nullable = false)
 |    |    |    |-- value: struct (nullable = false)
 |    |    |    |    |-- m2: array (nullable = false)
 |    |    |    |    |    |-- element: struct (containsNull = false)
 |    |    |    |    |    |    |-- key: integer (nullable = false)
 |    |    |    |    |    |    |-- value: string (nullable = false)

>>> res_df.show()
+---+-------------------+
| id|                 m1|
+---+-------------------+
|  1|[[{1, {[{1, a}]}}]]|
+---+-------------------+
Source code in spark_frame/transformations_impl/convert_all_maps_to_arrays.py
def convert_all_maps_to_arrays(df: DataFrame) -> DataFrame:
    """Transform all columns of type `Map<K,V>` inside the given DataFrame into `ARRAY<STRUCT<key: K, value: V>>`.
    This transformation works recursively on every nesting level.

    !!! info
        This method is compatible with any schema. It recursively applies on structs, arrays and maps
        and is compatible with field names containing special characters.

    Args:
        df: A Spark DataFrame

    Returns:
        A new DataFrame in which all maps have been replaced with arrays of entries.

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('SELECT 1 as id, ARRAY(MAP(1, STRUCT(MAP(1, "a") as m2))) as m1')
        >>> df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- m1: array (nullable = false)
         |    |-- element: map (containsNull = false)
         |    |    |-- key: integer
         |    |    |-- value: struct (valueContainsNull = false)
         |    |    |    |-- m2: map (nullable = false)
         |    |    |    |    |-- key: integer
         |    |    |    |    |-- value: string (valueContainsNull = false)
        <BLANKLINE>
        >>> df.show()
        +---+-------------------+
        | id|                 m1|
        +---+-------------------+
        |  1|[{1 -> {{1 -> a}}}]|
        +---+-------------------+
        <BLANKLINE>
        >>> res_df = convert_all_maps_to_arrays(df)
        >>> res_df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- m1: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: struct (containsNull = false)
         |    |    |    |-- key: integer (nullable = false)
         |    |    |    |-- value: struct (nullable = false)
         |    |    |    |    |-- m2: array (nullable = false)
         |    |    |    |    |    |-- element: struct (containsNull = false)
         |    |    |    |    |    |    |-- key: integer (nullable = false)
         |    |    |    |    |    |    |-- value: string (nullable = false)
        <BLANKLINE>
        >>> res_df.show()
        +---+-------------------+
        | id|                 m1|
        +---+-------------------+
        |  1|[[{1, {[{1, a}]}}]]|
        +---+-------------------+
        <BLANKLINE>
    """

    def map_to_arrays(col: Column, data_type: DataType) -> Optional[Column]:
        if isinstance(data_type, MapType):
            return f.map_entries(col)
        else:
            return None

    return transform_all_fields(df, map_to_arrays)

flatten(df: DataFrame, struct_separator: str = '.') -> DataFrame

Flatten all the struct columns of a Spark DataFrame. Nested field names will be joined together using the specified separator.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • struct_separator (str): A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots. Default: '.'

Returns:

  • DataFrame: A flattened DataFrame.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame(
...         [(1, {"a": 1, "b": {"c": 1, "d": 1}})],
...         "id INT, s STRUCT<a:INT, b:STRUCT<c:INT, d:INT>>"
...      )
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s: struct (nullable = true)
 |    |-- a: integer (nullable = true)
 |    |-- b: struct (nullable = true)
 |    |    |-- c: integer (nullable = true)
 |    |    |-- d: integer (nullable = true)

>>> flatten(df).printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.a: integer (nullable = true)
 |-- s.b.c: integer (nullable = true)
 |-- s.b.d: integer (nullable = true)

>>> df = spark.createDataFrame(
...         [(1, {"a.a1": 1, "b.b1": {"c.c1": 1, "d.d1": 1}})],
...         "id INT, `s.s1` STRUCT<`a.a1`:INT, `b.b1`:STRUCT<`c.c1`:INT, `d.d1`:INT>>"
... )
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1: struct (nullable = true)
 |    |-- a.a1: integer (nullable = true)
 |    |-- b.b1: struct (nullable = true)
 |    |    |-- c.c1: integer (nullable = true)
 |    |    |-- d.d1: integer (nullable = true)

>>> flatten(df, "?").printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1?a.a1: integer (nullable = true)
 |-- s.s1?b.b1?c.c1: integer (nullable = true)
 |-- s.s1?b.b1?d.d1: integer (nullable = true)
Source code in spark_frame/transformations_impl/flatten.py
def flatten(df: DataFrame, struct_separator: str = ".") -> DataFrame:
    """Flatten all the struct columns of a Spark [DataFrame][pyspark.sql.DataFrame].
    Nested field names will be joined together using the specified separator

    Args:
        df: A Spark DataFrame
        struct_separator: A string used to separate the struct names from their elements.
            It might be useful to change the separator when some DataFrame's column names already contain dots

    Returns:
        A flattened DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame(
        ...         [(1, {"a": 1, "b": {"c": 1, "d": 1}})],
        ...         "id INT, s STRUCT<a:INT, b:STRUCT<c:INT, d:INT>>"
        ...      )
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s: struct (nullable = true)
         |    |-- a: integer (nullable = true)
         |    |-- b: struct (nullable = true)
         |    |    |-- c: integer (nullable = true)
         |    |    |-- d: integer (nullable = true)
        <BLANKLINE>
        >>> flatten(df).printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.a: integer (nullable = true)
         |-- s.b.c: integer (nullable = true)
         |-- s.b.d: integer (nullable = true)
        <BLANKLINE>
        >>> df = spark.createDataFrame(
        ...         [(1, {"a.a1": 1, "b.b1": {"c.c1": 1, "d.d1": 1}})],
        ...         "id INT, `s.s1` STRUCT<`a.a1`:INT, `b.b1`:STRUCT<`c.c1`:INT, `d.d1`:INT>>"
        ... )
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1: struct (nullable = true)
         |    |-- a.a1: integer (nullable = true)
         |    |-- b.b1: struct (nullable = true)
         |    |    |-- c.c1: integer (nullable = true)
         |    |    |-- d.d1: integer (nullable = true)
        <BLANKLINE>
        >>> flatten(df, "?").printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1?a.a1: integer (nullable = true)
         |-- s.s1?b.b1?c.c1: integer (nullable = true)
         |-- s.s1?b.b1?d.d1: integer (nullable = true)
        <BLANKLINE>

    """
    # The idea is to recursively write a "SELECT s.b.c as `s.b.c`" for each nested column.
    cols = []

    def expand_struct(struct: StructType, col_stack: List[str]) -> None:
        for field in struct:
            if isinstance(field.dataType, StructType):
                struct_field = field.dataType
                expand_struct(struct_field, [*col_stack, field.name])
            else:
                column = f.col(".".join(quote_columns([*col_stack, field.name])))
                cols.append(column.alias(struct_separator.join([*col_stack, field.name])))

    expand_struct(df.schema, col_stack=[])
    return df.select(cols)

flatten_all_arrays(df: DataFrame) -> DataFrame

Flatten all columns of type ARRAY<ARRAY<T>> inside the given DataFrame into ARRAY<T>. This transformation works recursively on every nesting level.

Info

This method is compatible with any schema. It recursively applies on structs, arrays and maps and accepts field names containing dots (.), exclamation marks (!) or percentage (%).

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.

Returns:

  • DataFrame: A new DataFrame in which all arrays of arrays have been flattened.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('SELECT 1 as id, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3)), ARRAY(ARRAY(4), ARRAY(5))) as a')
>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- a: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: array (containsNull = false)
 |    |    |    |-- element: integer (containsNull = false)

>>> df.show(truncate=False)
+---+---------------------------+
|id |a                          |
+---+---------------------------+
|1  |[[[1, 2], [3]], [[4], [5]]]|
+---+---------------------------+

>>> res_df = flatten_all_arrays(df)
>>> res_df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- a: array (nullable = false)
 |    |-- element: integer (containsNull = false)

>>> res_df.show(truncate=False)
+---+---------------+
|id |a              |
+---+---------------+
|1  |[1, 2, 3, 4, 5]|
+---+---------------+
Source code in spark_frame/transformations_impl/flatten_all_arrays.py
def flatten_all_arrays(df: DataFrame) -> DataFrame:
    """Flatten all columns of type `ARRAY<ARRAY<T>>` inside the given DataFrame into `ARRAY<<T>>>`.
    This transformation works recursively on every nesting level.

    !!! info
        This method is compatible with any schema. It recursively applies on structs, arrays and maps
        and accepts field names containing dots (`.`), exclamation marks (`!`) or percentage (`%`).

    Args:
        df: A Spark DataFrame

    Returns:
        A new DataFrame in which all arrays of arrays have been flattened

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('SELECT 1 as id, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3)), ARRAY(ARRAY(4), ARRAY(5))) as a')
        >>> df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- a: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: array (containsNull = false)
         |    |    |    |-- element: integer (containsNull = false)
        <BLANKLINE>
        >>> df.show(truncate=False)
        +---+---------------------------+
        |id |a                          |
        +---+---------------------------+
        |1  |[[[1, 2], [3]], [[4], [5]]]|
        +---+---------------------------+
        <BLANKLINE>
        >>> res_df = flatten_all_arrays(df)
        >>> res_df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- a: array (nullable = false)
         |    |-- element: integer (containsNull = false)
        <BLANKLINE>
        >>> res_df.show(truncate=False)
        +---+---------------+
        |id |a              |
        +---+---------------+
        |1  |[1, 2, 3, 4, 5]|
        +---+---------------+
        <BLANKLINE>
    """

    def flatten_array(col: Column, data_type: DataType) -> Optional[Column]:
        if isinstance(data_type, ArrayType) and isinstance(data_type.elementType, ArrayType):
            return f.flatten(col)
        else:
            return None

    return transform_all_fields(df, flatten_array)

harmonize_dataframes(left_df: DataFrame, right_df: DataFrame, common_columns: Optional[Dict[str, Optional[str]]] = None, keep_missing_columns: bool = False) -> Tuple[DataFrame, DataFrame]

Given two DataFrames, returns two new corresponding DataFrames with the same schemas by applying the following changes:

  • Only common columns are kept
  • Columns of type MAP<key, value> are cast into ARRAY<STRUCT<key, value>>
  • Columns are re-ordered to have the same ordering in both DataFrames
  • When matching columns have different types, their type is widened to their most narrow common type. This transformation is applied recursively on nested columns, including those inside repeated records (a.k.a. ARRAY<STRUCT<>>).

Parameters:

  • left_df (DataFrame): A Spark DataFrame. Required.
  • right_df (DataFrame): A Spark DataFrame. Required.
  • common_columns (Optional[Dict[str, Optional[str]]]): A dict of (column name, type). Column names must appear in both DataFrames, and each column will be cast into the corresponding type. Default: None
  • keep_missing_columns (bool): If set to true, the root columns of each DataFrame that do not exist in the other one are kept. Default: False

Returns:

  • Tuple[DataFrame, DataFrame]: Two new Spark DataFrames with the same schema.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df1 = spark.sql('SELECT 1 as id, STRUCT(2 as a, ARRAY(STRUCT(3 as b, 4 as c)) as s2) as s1')
>>> df2 = spark.sql('SELECT 1 as id, STRUCT(2 as a, ARRAY(STRUCT(3.0 as b, "4" as c, 5 as d)) as s2) as s1')
>>> df1.union(df2).show(truncate=False)
Traceback (most recent call last):
    ...
AnalysisException: ... UNION can only be performed on tables with compatible column types.
>>> df1, df2 = harmonize_dataframes(df1, df2)
>>> df1.union(df2).show()
+---+---------------+
| id|             s1|
+---+---------------+
|  1|{2, [{3.0, 4}]}|
|  1|{2, [{3.0, 4}]}|
+---+---------------+

>>> df1, df2 = harmonize_dataframes(df1, df2, common_columns={"id": None, "s1.s2!.b": "int"})
>>> df1.union(df2).show()
+---+-------+
| id|     s1|
+---+-------+
|  1|{[{3}]}|
|  1|{[{3}]}|
+---+-------+
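
The examples above do not cover keep_missing_columns; the following is a minimal, untested sketch of how it could be used (the DataFrames and column names are made up for illustration, and spark is assumed to be an existing SparkSession):

from spark_frame.transformations import harmonize_dataframes

df_left = spark.sql('SELECT 1 as id, "x" as left_only')
df_right = spark.sql('SELECT 1 as id, 2 as right_only')

# With keep_missing_columns=True, each harmonized DataFrame keeps its own root columns
# that are absent from the other one, so the two schemas may still differ on those columns.
h_left, h_right = harmonize_dataframes(df_left, df_right, keep_missing_columns=True)
h_left.printSchema()   # contains: id, left_only
h_right.printSchema()  # contains: id, right_only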
Source code in spark_frame/transformations_impl/harmonize_dataframes.py
def harmonize_dataframes(
    left_df: DataFrame,
    right_df: DataFrame,
    common_columns: Optional[Dict[str, Optional[str]]] = None,
    keep_missing_columns: bool = False,
) -> Tuple[DataFrame, DataFrame]:
    """Given two DataFrames, returns two new corresponding DataFrames with the same schemas by applying the following
    changes:

    - Only common columns are kept
    - Columns of type MAP<key, value> are cast into ARRAY<STRUCT<key, value>>
    - Columns are re-ordered to have the same ordering in both DataFrames
    - When matching columns have different types, their type is widened to their most narrow common type.
    This transformation is applied recursively on nested columns, including those inside
    repeated records (a.k.a. ARRAY<STRUCT<>>).

    Args:
        left_df: A Spark DataFrame
        right_df: A Spark DataFrame
        common_columns: A dict of (column name, type).
            Column names must appear in both DataFrames, and each column will be cast into the corresponding type.
        keep_missing_columns: If set to true, the root columns of each DataFrame that do not exist in the other
            one are kept.

    Returns:
        Two new Spark DataFrames with the same schema

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df1 = spark.sql('SELECT 1 as id, STRUCT(2 as a, ARRAY(STRUCT(3 as b, 4 as c)) as s2) as s1')
        >>> df2 = spark.sql('SELECT 1 as id, STRUCT(2 as a, ARRAY(STRUCT(3.0 as b, "4" as c, 5 as d)) as s2) as s1')
        >>> df1.union(df2).show(truncate=False) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        AnalysisException: ... UNION can only be performed on tables with compatible column types.
        >>> df1, df2 = harmonize_dataframes(df1, df2)
        >>> df1.union(df2).show()
        +---+---------------+
        | id|             s1|
        +---+---------------+
        |  1|{2, [{3.0, 4}]}|
        |  1|{2, [{3.0, 4}]}|
        +---+---------------+
        <BLANKLINE>
        >>> df1, df2 = harmonize_dataframes(df1, df2, common_columns={"id": None, "s1.s2!.b": "int"})
        >>> df1.union(df2).show()
        +---+-------+
        | id|     s1|
        +---+-------+
        |  1|{[{3}]}|
        |  1|{[{3}]}|
        +---+-------+
        <BLANKLINE>
    """
    left_schema_flat = flatten_schema(left_df.schema, explode=True)
    right_schema_flat = flatten_schema(right_df.schema, explode=True)
    if common_columns is None:
        common_columns = get_common_columns(left_schema_flat, right_schema_flat)

    left_only_columns = {}
    right_only_columns = {}
    if keep_missing_columns:
        left_cols = [field.name for field in left_schema_flat.fields]
        right_cols = [field.name for field in right_schema_flat.fields]
        left_cols_set = set(left_cols)
        right_cols_set = set(right_cols)
        left_only_columns = {col: None for col in left_cols if col not in right_cols_set}
        right_only_columns = {col: None for col in right_cols if col not in left_cols_set}

    def build_col(col_name: str, col_type: Optional[str]) -> PrintableFunction:
        parent_structs = _deepest_granularity(col_name)
        if col_type is not None:
            tpe = col_type
            f1 = PrintableFunction(lambda s: s.cast(tpe), lambda s: f"{s}.cast({tpe})")
        else:
            f1 = higher_order.identity
        f2 = higher_order.recursive_struct_get(parent_structs)
        return fp.compose(f1, f2)

    left_columns = {**common_columns, **left_only_columns}
    right_columns = {**common_columns, **right_only_columns}
    left_columns_dict = {col_name: build_col(col_name, col_type) for (col_name, col_type) in left_columns.items()}
    right_columns_dict = {col_name: build_col(col_name, col_type) for (col_name, col_type) in right_columns.items()}
    left_tree = _build_nested_struct_tree(left_columns_dict)
    right_tree = _build_nested_struct_tree(right_columns_dict)
    left_root_transformation = _build_transformation_from_tree(left_tree)
    right_root_transformation = _build_transformation_from_tree(right_tree)
    return (
        left_df.select(*left_root_transformation([left_df])),
        right_df.select(*right_root_transformation([right_df])),
    )

parse_json_columns(df: DataFrame, columns: Union[str, List[str], Dict[str, str]]) -> DataFrame

Transform the specified columns containing json strings in the given DataFrame into structs containing the equivalent parsed information.

This method is similar to Spark's from_json function, with one main difference: from_json requires the user to pass the expected json schema, while this method performs a first pass on the DataFrame to automatically detect the json schema of each column.

By default, the output columns will have the same name as the input columns, but if you want to keep the input columns you can pass a dict(input_col_name, output_col_name) to specify different output column names.

Please be aware that automatic schema detection is not very robust, and while this method can be quite helpful for quick prototyping and data exploration, it is recommended to use a fixed schema and make sure the schema of the input json data is properly enforced, or at the very least to use a schema drift detection mechanism.
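
For reference, here is a minimal sketch of the fixed-schema alternative using Spark's built-in from_json; the DDL schema string below is only an example matching the data of Example 1, and df refers to the DataFrame from that example:

from pyspark.sql import functions as f

# Parse with an explicit schema instead of relying on automatic inference
fixed_schema = "ARRAY<STRUCT<a: BIGINT>>"
df_with_fixed_schema = df.withColumn("json1", f.from_json("json1", fixed_schema))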

Warning

This method's performance is not optimal, as it has to perform a Python operation on the executors' side.

Warning

When you use this method on a column that is inside a struct (e.g. column "a.b.c"), instead of replacing that column it will create a new top-level column whose name contains dots (e.g. a column literally named "a.b.c", which must be referenced using backquotes) (see Example 3).

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • columns (Union[str, List[str], Dict[str, str]]): A column name, list of column names, or dict(column_name, parsed_column_name). Required.

Returns:

  • DataFrame: A new DataFrame.

Example 1

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame([
...         (1, '[{"a": 1}, {"a": 2}]'),
...         (1, '[{"a": 2}, {"a": 4}]'),
...         (2, None)
...     ], "id INT, json1 STRING"
... )
>>> df.show()
+---+--------------------+
| id|               json1|
+---+--------------------+
|  1|[{"a": 1}, {"a": 2}]|
|  1|[{"a": 2}, {"a": 4}]|
|  2|                NULL|
+---+--------------------+

>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- json1: string (nullable = true)

>>> parse_json_columns(df, 'json1').printSchema()
root
 |-- id: integer (nullable = true)
 |-- json1: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: long (nullable = true)

Example 2 : different output column name

>>> parse_json_columns(df, {'json1': 'parsed_json1'}).printSchema()
root
 |-- id: integer (nullable = true)
 |-- json1: string (nullable = true)
 |-- parsed_json1: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: long (nullable = true)

Example 3 : json inside a struct

>>> df = spark.createDataFrame([
...         (1, {'json1': '[{"a": 1}, {"a": 2}]'}),
...         (1, {'json1': '[{"a": 2}, {"a": 4}]'}),
...         (2, None)
...     ], "id INT, struct STRUCT<json1: STRING>"
... )
>>> df.show(10, False)
+---+----------------------+
|id |struct                |
+---+----------------------+
|1  |{[{"a": 1}, {"a": 2}]}|
|1  |{[{"a": 2}, {"a": 4}]}|
|2  |NULL                  |
+---+----------------------+

>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- struct: struct (nullable = true)
 |    |-- json1: string (nullable = true)

>>> res = parse_json_columns(df, 'struct.json1')
>>> res.printSchema()
root
 |-- id: integer (nullable = true)
 |-- struct: struct (nullable = true)
 |    |-- json1: string (nullable = true)
 |-- struct.json1: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: long (nullable = true)

>>> res.show(10, False)
+---+----------------------+------------+
|id |struct                |struct.json1|
+---+----------------------+------------+
|1  |{[{"a": 1}, {"a": 2}]}|[{1}, {2}]  |
|1  |{[{"a": 2}, {"a": 4}]}|[{2}, {4}]  |
|2  |NULL                  |NULL        |
+---+----------------------+------------+
Source code in spark_frame/transformations_impl/parse_json_columns.py
def parse_json_columns(df: DataFrame, columns: Union[str, List[str], Dict[str, str]]) -> DataFrame:
    """Transform the specified columns containing json strings in the given DataFrame into structs containing
    the equivalent parsed information.

    This method is similar to Spark's `from_json` function, with one main difference: `from_json` requires the user
    to pass the expected json schema, while this method performs a first pass on the DataFrame to automatically detect
    the json schema of each column.

    By default, the output columns will have the same name as the input columns, but if you want to keep the input
    columns you can pass a dict(input_col_name, output_col_name) to specify different output column names.

    Please be aware that automatic schema detection is not very robust, and while this method can be quite helpful
    for quick prototyping and data exploration, it is recommended to use a fixed schema and make sure the schema
    of the input json data is properly enforced, or at the very least to use a schema drift detection mechanism.

    !!! warning
        This method's performance is not optimal, as it has to perform a Python operation on the executors' side.

    !!! warning
        When you use this method on a column that is inside a struct (e.g. column "a.b.c"),
        instead of replacing that column it will create a new column outside the struct (e.g. "`a.b.c`")
        (See Example 3).

    Args:
        df: A Spark DataFrame
        columns: A column name, list of column names, or dict(column_name, parsed_column_name)

    Returns:
        A new DataFrame

    Examples: Example 1
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame([
        ...         (1, '[{"a": 1}, {"a": 2}]'),
        ...         (1, '[{"a": 2}, {"a": 4}]'),
        ...         (2, None)
        ...     ], "id INT, json1 STRING"
        ... )
        >>> df.show()
        +---+--------------------+
        | id|               json1|
        +---+--------------------+
        |  1|[{"a": 1}, {"a": 2}]|
        |  1|[{"a": 2}, {"a": 4}]|
        |  2|                NULL|
        +---+--------------------+
        <BLANKLINE>
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- json1: string (nullable = true)
        <BLANKLINE>
        >>> parse_json_columns(df, 'json1').printSchema()
        root
         |-- id: integer (nullable = true)
         |-- json1: array (nullable = true)
         |    |-- element: struct (containsNull = true)
         |    |    |-- a: long (nullable = true)
        <BLANKLINE>

    Examples: Example 2 : different output column name
        >>> parse_json_columns(df, {'json1': 'parsed_json1'}).printSchema()
        root
         |-- id: integer (nullable = true)
         |-- json1: string (nullable = true)
         |-- parsed_json1: array (nullable = true)
         |    |-- element: struct (containsNull = true)
         |    |    |-- a: long (nullable = true)
        <BLANKLINE>

    Examples: Example 3 : json inside a struct
        >>> df = spark.createDataFrame([
        ...         (1, {'json1': '[{"a": 1}, {"a": 2}]'}),
        ...         (1, {'json1': '[{"a": 2}, {"a": 4}]'}),
        ...         (2, None)
        ...     ], "id INT, struct STRUCT<json1: STRING>"
        ... )
        >>> df.show(10, False)
        +---+----------------------+
        |id |struct                |
        +---+----------------------+
        |1  |{[{"a": 1}, {"a": 2}]}|
        |1  |{[{"a": 2}, {"a": 4}]}|
        |2  |NULL                  |
        +---+----------------------+
        <BLANKLINE>
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- struct: struct (nullable = true)
         |    |-- json1: string (nullable = true)
        <BLANKLINE>
        >>> res = parse_json_columns(df, 'struct.json1')
        >>> res.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- struct: struct (nullable = true)
         |    |-- json1: string (nullable = true)
         |-- struct.json1: array (nullable = true)
         |    |-- element: struct (containsNull = true)
         |    |    |-- a: long (nullable = true)
        <BLANKLINE>
        >>> res.show(10, False)
        +---+----------------------+------------+
        |id |struct                |struct.json1|
        +---+----------------------+------------+
        |1  |{[{"a": 1}, {"a": 2}]}|[{1}, {2}]  |
        |1  |{[{"a": 2}, {"a": 4}]}|[{2}, {4}]  |
        |2  |NULL                  |NULL        |
        +---+----------------------+------------+
        <BLANKLINE>

    """
    if isinstance(columns, str):
        columns = [columns]
    if isinstance(columns, list):
        columns = {col: col for col in columns}

    wrapped_df = __wrap_json_columns(df, columns)
    schema_per_col = __infer_schema_per_column(wrapped_df, list(columns.values()))
    res = __parse_json_columns(wrapped_df, schema_per_col)
    return res

sort_all_arrays(df: DataFrame) -> DataFrame

Given a DataFrame, sort all fields of type ARRAY in a canonical order, making them comparable. This also applies to nested fields, even those inside other arrays.

Limitation

  • Arrays containing sub-fields of type Map cannot be sorted, as the Map type is not comparable.
  • A possible workaround is to first use the transformation spark_frame.transformations.convert_all_maps_to_arrays.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.

Returns:

  • DataFrame: A new DataFrame where all arrays have been sorted.

Example 1: with a simple `ARRAY<INT>`

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('SELECT 1 as id, ARRAY(3, 2, 1) as a')
>>> df.show()
+---+---------+
| id|        a|
+---+---------+
|  1|[3, 2, 1]|
+---+---------+

>>> sort_all_arrays(df).show()
+---+---------+
| id|        a|
+---+---------+
|  1|[1, 2, 3]|
+---+---------+

Example 2: with an `ARRAY<STRUCT<...>>`

>>> df = spark.sql('SELECT ARRAY(STRUCT(2 as a, 1 as b), STRUCT(1 as a, 2 as b), STRUCT(1 as a, 1 as b)) as s')
>>> df.show(truncate=False)
+------------------------+
|s                       |
+------------------------+
|[{2, 1}, {1, 2}, {1, 1}]|
+------------------------+

>>> df.transform(sort_all_arrays).show(truncate=False)
+------------------------+
|s                       |
+------------------------+
|[{1, 1}, {1, 2}, {2, 1}]|
+------------------------+

Example 3: with an `ARRAY<STRUCT<STRUCT<...>>>`

>>> df = spark.sql('''SELECT ARRAY(
...         STRUCT(STRUCT(2 as a, 2 as b) as s),
...         STRUCT(STRUCT(1 as a, 2 as b) as s)
...     ) as l1
... ''')
>>> df.show(truncate=False)
+--------------------+
|l1                  |
+--------------------+
|[{{2, 2}}, {{1, 2}}]|
+--------------------+

>>> df.transform(sort_all_arrays).show(truncate=False)
+--------------------+
|l1                  |
+--------------------+
|[{{1, 2}}, {{2, 2}}]|
+--------------------+

Example 4: with an `ARRAY<ARRAY<ARRAY<INT>>>`

As this example shows, the innermost arrays are sorted before the outermost arrays.

>>> df = spark.sql('''SELECT ARRAY(
...         ARRAY(ARRAY(4, 1), ARRAY(3, 2)),
...         ARRAY(ARRAY(2, 2), ARRAY(2, 1))
...     ) as l1
... ''')
>>> df.show(truncate=False)
+------------------------------------+
|l1                                  |
+------------------------------------+
|[[[4, 1], [3, 2]], [[2, 2], [2, 1]]]|
+------------------------------------+

>>> df.transform(sort_all_arrays).show(truncate=False)
+------------------------------------+
|l1                                  |
+------------------------------------+
|[[[1, 2], [2, 2]], [[1, 4], [2, 3]]]|
+------------------------------------+
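
Because the sorted order is canonical, this transformation can be useful when comparing DataFrames whose arrays may come in different orders. The following is a minimal, untested sketch assuming two such DataFrames df1 and df2 with identical schemas:

from spark_frame.transformations import sort_all_arrays

# Sort arrays on both sides so that the comparison ignores array ordering
sorted_df1 = df1.transform(sort_all_arrays)
sorted_df2 = df2.transform(sort_all_arrays)

# Rows that differ only by the ordering of their arrays no longer appear here
diff_df = sorted_df1.exceptAll(sorted_df2)
diff_df.show(truncate=False)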
Source code in spark_frame/transformations_impl/sort_all_arrays.py
def sort_all_arrays(df: DataFrame) -> DataFrame:
    """Given a DataFrame, sort all fields of type `ARRAY` in a canonical order, making them comparable.
    This also applies to nested fields, even those inside other arrays.

    !!! warning "Limitation"
        - Arrays containing sub-fields of type Map cannot be sorted, as the Map type is not comparable.
        - A possible workaround is to first use the transformation
        [`spark_frame.transformations.convert_all_maps_to_arrays`]
        [spark_frame.transformations_impl.convert_all_maps_to_arrays.convert_all_maps_to_arrays]

    Args:
        df: A Spark DataFrame

    Returns:
        A new DataFrame where all arrays have been sorted.

    Examples: Example 1: with a simple `ARRAY<INT>`
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('SELECT 1 as id, ARRAY(3, 2, 1) as a')
        >>> df.show()
        +---+---------+
        | id|        a|
        +---+---------+
        |  1|[3, 2, 1]|
        +---+---------+
        <BLANKLINE>
        >>> sort_all_arrays(df).show()
        +---+---------+
        | id|        a|
        +---+---------+
        |  1|[1, 2, 3]|
        +---+---------+
        <BLANKLINE>

    Examples: Example 2: with an `ARRAY<STRUCT<...>>`
        >>> df = spark.sql('SELECT ARRAY(STRUCT(2 as a, 1 as b), STRUCT(1 as a, 2 as b), STRUCT(1 as a, 1 as b)) as s')
        >>> df.show(truncate=False)
        +------------------------+
        |s                       |
        +------------------------+
        |[{2, 1}, {1, 2}, {1, 1}]|
        +------------------------+
        <BLANKLINE>
        >>> df.transform(sort_all_arrays).show(truncate=False)
        +------------------------+
        |s                       |
        +------------------------+
        |[{1, 1}, {1, 2}, {2, 1}]|
        +------------------------+
        <BLANKLINE>

    Examples: Example 3: with an `ARRAY<STRUCT<STRUCT<...>>>`
        >>> df = spark.sql('''SELECT ARRAY(
        ...         STRUCT(STRUCT(2 as a, 2 as b) as s),
        ...         STRUCT(STRUCT(1 as a, 2 as b) as s)
        ...     ) as l1
        ... ''')
        >>> df.show(truncate=False)
        +--------------------+
        |l1                  |
        +--------------------+
        |[{{2, 2}}, {{1, 2}}]|
        +--------------------+
        <BLANKLINE>
        >>> df.transform(sort_all_arrays).show(truncate=False)
        +--------------------+
        |l1                  |
        +--------------------+
        |[{{1, 2}}, {{2, 2}}]|
        +--------------------+
        <BLANKLINE>

    Examples: Example 4: with an `ARRAY<ARRAY<ARRAY<INT>>>`
        As this example shows, the innermost arrays are sorted before the outermost arrays.
        >>> df = spark.sql('''SELECT ARRAY(
        ...         ARRAY(ARRAY(4, 1), ARRAY(3, 2)),
        ...         ARRAY(ARRAY(2, 2), ARRAY(2, 1))
        ...     ) as l1
        ... ''')
        >>> df.show(truncate=False)
        +------------------------------------+
        |l1                                  |
        +------------------------------------+
        |[[[4, 1], [3, 2]], [[2, 2], [2, 1]]]|
        +------------------------------------+
        <BLANKLINE>
        >>> df.transform(sort_all_arrays).show(truncate=False)
        +------------------------------------+
        |l1                                  |
        +------------------------------------+
        |[[[1, 2], [2, 2]], [[1, 4], [2, 3]]]|
        +------------------------------------+
        <BLANKLINE>
    """

    def sort_array(col: Column, data_type: DataType) -> Optional[Column]:
        if isinstance(data_type, ArrayType):
            return f.sort_array(col)
        else:
            return None

    return transform_all_fields(df, sort_array)

transform_all_field_names(df: DataFrame, transformation: Callable[[str], str]) -> DataFrame

Apply a transformation to all nested field names of a DataFrame.

Info

This method is compatible with any schema. It applies recursively to structs, arrays and maps, and is compatible with field names containing special characters.

Parameters:

    df (DataFrame): A Spark DataFrame. Required.
    transformation (Callable[[str], str]): Transformation to apply to all field names in the DataFrame. Required.

Returns:

    DataFrame: A new DataFrame

Example 1: with a nested schema structure

In this example we cast all the field names of the schema to uppercase:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT
...     "John" as name,
...     ARRAY(STRUCT(1 as a), STRUCT(2 as a)) as s1,
...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s2,
...     ARRAY(ARRAY(STRUCT(1 as a)), ARRAY(STRUCT(2 as a))) as s3,
...     ARRAY(STRUCT(ARRAY(1, 2) as a), STRUCT(ARRAY(3, 4) as a)) as s4,
...     ARRAY(
...         STRUCT(ARRAY(STRUCT(STRUCT(1 as c) as b), STRUCT(STRUCT(2 as c) as b)) as a),
...         STRUCT(ARRAY(STRUCT(STRUCT(3 as c) as b), STRUCT(STRUCT(4 as c) as b)) as a)
...     ) as s5
... ''')
>>> nested.print_schema(df)
root
 |-- name: string (nullable = false)
 |-- s1!.a: integer (nullable = false)
 |-- s2!!: integer (nullable = false)
 |-- s3!!.a: integer (nullable = false)
 |-- s4!.a!: integer (nullable = false)
 |-- s5!.a!.b.c: integer (nullable = false)
>>> new_df = df.transform(transform_all_field_names, lambda s: s.upper())
>>> nested.print_schema(new_df)
root
 |-- NAME: string (nullable = false)
 |-- S1!.A: integer (nullable = false)
 |-- S2!!: integer (nullable = false)
 |-- S3!!.A: integer (nullable = false)
 |-- S4!.A!: integer (nullable = false)
 |-- S5!.A!.B.C: integer (nullable = false)

Example 2: sanitizing field names

In this example we replace all dots and exclamation marks in field names with underscores. This is useful to make a DataFrame compatible with the spark_frame.nested module.

>>> df = spark.sql('''SELECT
...     ARRAY(STRUCT(
...         ARRAY(STRUCT(
...             STRUCT(1 as `d.d!d`) as `c.c!c`
...         )) as `b.b!b`
...    )) as `a.a!a`
... ''')
>>> df.printSchema()
root
 |-- a.a!a: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- b.b!b: array (nullable = false)
 |    |    |    |-- element: struct (containsNull = false)
 |    |    |    |    |-- c.c!c: struct (nullable = false)
 |    |    |    |    |    |-- d.d!d: integer (nullable = false)

>>> new_df = df.transform(transform_all_field_names, lambda s: s.replace(".","_").replace("!", "_"))
>>> new_df.printSchema()
root
 |-- a_a_a: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- b_b_b: array (nullable = false)
 |    |    |    |-- element: struct (containsNull = false)
 |    |    |    |    |-- c_c_c: struct (nullable = false)
 |    |    |    |    |    |-- d_d_d: integer (nullable = false)

This also works on fields of type MAP<K,V>.

>>> df = spark.sql('SELECT MAP(STRUCT(1 as `a.a!a`), STRUCT(2 as `b.b!b`)) as `m.m!m`')
>>> df.printSchema()
root
 |-- m.m!m: map (nullable = false)
 |    |-- key: struct
 |    |    |-- a.a!a: integer (nullable = false)
 |    |-- value: struct (valueContainsNull = false)
 |    |    |-- b.b!b: integer (nullable = false)

>>> new_df = df.transform(transform_all_field_names, lambda s: s.replace(".","_").replace("!", "_"))
>>> new_df.printSchema()
root
 |-- m_m_m: map (nullable = false)
 |    |-- key: struct
 |    |    |-- a_a_a: integer (nullable = false)
 |    |-- value: struct (valueContainsNull = false)
 |    |    |-- b_b_b: integer (nullable = false)
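
As a further illustration, here is a minimal sketch (not part of the library's documented examples) that converts camelCase field names to snake_case. The regex-based helper and the sample schema are assumptions, and the import path assumes transform_all_field_names is exposed by spark_frame.transformations.

import re

from pyspark.sql import SparkSession

from spark_frame.transformations import transform_all_field_names  # assumed import path

spark = SparkSession.builder.appName("doc_example").getOrCreate()
df = spark.sql('SELECT STRUCT(1 as userId, STRUCT(75000 as zipCode) as homeAddress) as userInfo')

def camel_to_snake(name: str) -> str:
    # Insert an underscore before each upper-case letter, then lower-case the whole name.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

new_df = df.transform(transform_all_field_names, camel_to_snake)
new_df.printSchema()
# Expected schema (roughly):
# root
#  |-- user_info: struct
#  |    |-- user_id: integer
#  |    |-- home_address: struct
#  |    |    |-- zip_code: integer
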
Source code in spark_frame/transformations_impl/transform_all_field_names.py
def transform_all_field_names(df: DataFrame, transformation: Callable[[str], str]) -> DataFrame:
    """Apply a transformation to all nested field names of a DataFrame.

    !!! info
        This method is compatible with any schema. It applies recursively to structs, arrays and maps,
        and is compatible with field names containing special characters.

    Args:
        df: A Spark DataFrame
        transformation: Transformation to apply to all field names in the DataFrame.

    Returns:
        A new DataFrame

    Examples: Example 1: with a nested schema structure
        In this example we cast all the field names of the schema to uppercase:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT
        ...     "John" as name,
        ...     ARRAY(STRUCT(1 as a), STRUCT(2 as a)) as s1,
        ...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s2,
        ...     ARRAY(ARRAY(STRUCT(1 as a)), ARRAY(STRUCT(2 as a))) as s3,
        ...     ARRAY(STRUCT(ARRAY(1, 2) as a), STRUCT(ARRAY(3, 4) as a)) as s4,
        ...     ARRAY(
        ...         STRUCT(ARRAY(STRUCT(STRUCT(1 as c) as b), STRUCT(STRUCT(2 as c) as b)) as a),
        ...         STRUCT(ARRAY(STRUCT(STRUCT(3 as c) as b), STRUCT(STRUCT(4 as c) as b)) as a)
        ...     ) as s5
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- name: string (nullable = false)
         |-- s1!.a: integer (nullable = false)
         |-- s2!!: integer (nullable = false)
         |-- s3!!.a: integer (nullable = false)
         |-- s4!.a!: integer (nullable = false)
         |-- s5!.a!.b.c: integer (nullable = false)
        <BLANKLINE>

        >>> new_df = df.transform(transform_all_field_names, lambda s: s.upper())
        >>> nested.print_schema(new_df)
        root
         |-- NAME: string (nullable = false)
         |-- S1!.A: integer (nullable = false)
         |-- S2!!: integer (nullable = false)
         |-- S3!!.A: integer (nullable = false)
         |-- S4!.A!: integer (nullable = false)
         |-- S5!.A!.B.C: integer (nullable = false)
        <BLANKLINE>

    Examples: Example 2: sanitizing field names
        In this example we replace all dots and exclamation marks in field names with underscores.
        This is useful to make a DataFrame compatible with the [spark_frame.nested](/spark-frame/reference/nested)
        module.
        >>> df = spark.sql('''SELECT
        ...     ARRAY(STRUCT(
        ...         ARRAY(STRUCT(
        ...             STRUCT(1 as `d.d!d`) as `c.c!c`
        ...         )) as `b.b!b`
        ...    )) as `a.a!a`
        ... ''')
        >>> df.printSchema()
        root
         |-- a.a!a: array (nullable = false)
         |    |-- element: struct (containsNull = false)
         |    |    |-- b.b!b: array (nullable = false)
         |    |    |    |-- element: struct (containsNull = false)
         |    |    |    |    |-- c.c!c: struct (nullable = false)
         |    |    |    |    |    |-- d.d!d: integer (nullable = false)
        <BLANKLINE>
        >>> new_df = df.transform(transform_all_field_names, lambda s: s.replace(".","_").replace("!", "_"))
        >>> new_df.printSchema()
        root
         |-- a_a_a: array (nullable = false)
         |    |-- element: struct (containsNull = false)
         |    |    |-- b_b_b: array (nullable = false)
         |    |    |    |-- element: struct (containsNull = false)
         |    |    |    |    |-- c_c_c: struct (nullable = false)
         |    |    |    |    |    |-- d_d_d: integer (nullable = false)
        <BLANKLINE>

        This also works on fields of type `MAP<K,V>`.
        >>> df = spark.sql('SELECT MAP(STRUCT(1 as `a.a!a`), STRUCT(2 as `b.b!b`)) as `m.m!m`')
        >>> df.printSchema()
        root
         |-- m.m!m: map (nullable = false)
         |    |-- key: struct
         |    |    |-- a.a!a: integer (nullable = false)
         |    |-- value: struct (valueContainsNull = false)
         |    |    |-- b.b!b: integer (nullable = false)
        <BLANKLINE>
        >>> new_df = df.transform(transform_all_field_names, lambda s: s.replace(".","_").replace("!", "_"))
        >>> new_df.printSchema()
        root
         |-- m_m_m: map (nullable = false)
         |    |-- key: struct
         |    |    |-- a_a_a: integer (nullable = false)
         |    |-- value: struct (valueContainsNull = false)
         |    |    |-- b_b_b: integer (nullable = false)
        <BLANKLINE>
    """
    root_transformation = build_transformation_from_schema(df.schema, name_transformation=transformation)
    return df.select(*root_transformation(df))

transform_all_fields(df: DataFrame, transformation: Callable[[Column, DataType], Optional[Column]]) -> DataFrame

Apply a transformation to all nested fields of a DataFrame.

Info

This method is compatible with any schema. It applies recursively to structs, arrays and maps, and is compatible with field names containing special characters.

Parameters:

    df (DataFrame): A Spark DataFrame. Required.
    transformation (Callable[[Column, DataType], Optional[Column]]): Transformation to apply to all fields of the DataFrame. The transformation must take as input a Column expression and the DataType of the corresponding expression. Required.

Returns:

    DataFrame: A new DataFrame

Examples:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT
...     "John" as name,
...     ARRAY(STRUCT(1 as a), STRUCT(2 as a)) as s1,
...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s2,
...     ARRAY(ARRAY(STRUCT(1 as a)), ARRAY(STRUCT(2 as a))) as s3,
...     ARRAY(STRUCT(ARRAY(1, 2) as a), STRUCT(ARRAY(3, 4) as a)) as s4,
...     ARRAY(
...         STRUCT(ARRAY(STRUCT(STRUCT(1 as c) as b), STRUCT(STRUCT(2 as c) as b)) as a),
...         STRUCT(ARRAY(STRUCT(STRUCT(3 as c) as b), STRUCT(STRUCT(4 as c) as b)) as a)
...     ) as s5
... ''')
>>> nested.print_schema(df)
root
 |-- name: string (nullable = false)
 |-- s1!.a: integer (nullable = false)
 |-- s2!!: integer (nullable = false)
 |-- s3!!.a: integer (nullable = false)
 |-- s4!.a!: integer (nullable = false)
 |-- s5!.a!.b.c: integer (nullable = false)

>>> df.show(truncate=False)
+----+----------+----------------+--------------+--------------------+------------------------------------+
|name|s1        |s2              |s3            |s4                  |s5                                  |
+----+----------+----------------+--------------+--------------------+------------------------------------+
|John|[{1}, {2}]|[[1, 2], [3, 4]]|[[{1}], [{2}]]|[{[1, 2]}, {[3, 4]}]|[{[{{1}}, {{2}}]}, {[{{3}}, {{4}}]}]|
+----+----------+----------------+--------------+--------------------+------------------------------------+

>>> from pyspark.sql.types import IntegerType
>>> def cast_int_as_double(col: Column, data_type: DataType):
...     if isinstance(data_type, IntegerType):
...         return col.cast("DOUBLE")
>>> new_df = df.transform(transform_all_fields, cast_int_as_double)
>>> nested.print_schema(new_df)
root
 |-- name: string (nullable = false)
 |-- s1!.a: double (nullable = false)
 |-- s2!!: double (nullable = false)
 |-- s3!!.a: double (nullable = false)
 |-- s4!.a!: double (nullable = false)
 |-- s5!.a!.b.c: double (nullable = false)

>>> new_df.show(truncate=False)
+----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
|name|s1            |s2                      |s3                |s4                          |s5                                          |
+----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
|John|[{1.0}, {2.0}]|[[1.0, 2.0], [3.0, 4.0]]|[[{1.0}], [{2.0}]]|[{[1.0, 2.0]}, {[3.0, 4.0]}]|[{[{{1.0}}, {{2.0}}]}, {[{{3.0}}, {{4.0}}]}]|
+----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
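
Another common use of this method, shown here as a minimal sketch (under the assumption that transform_all_fields is importable from spark_frame.transformations), is to trim whitespace from every string field, however deeply it is nested:

from typing import Optional

from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DataType, StringType

from spark_frame.transformations import transform_all_fields  # assumed import path

spark = SparkSession.builder.appName("doc_example").getOrCreate()
df = spark.sql("SELECT '  John ' as name, ARRAY(STRUCT(' a ' as tag)) as tags")

def trim_strings(col: Column, data_type: DataType) -> Optional[Column]:
    # Trim every string field; leave all other fields unchanged.
    if isinstance(data_type, StringType):
        return f.trim(col)
    return None

df.transform(transform_all_fields, trim_strings).show(truncate=False)
# +----+-----+
# |name|tags |
# +----+-----+
# |John|[{a}]|
# +----+-----+
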
Source code in spark_frame/transformations_impl/transform_all_fields.py
def transform_all_fields(
    df: DataFrame,
    transformation: Callable[[Column, DataType], Optional[Column]],
) -> DataFrame:
    """Apply a transformation to all nested fields of a DataFrame.

    !!! info
        This method is compatible with any schema. It applies recursively to structs, arrays and maps,
        and is compatible with field names containing special characters.

    Args:
        df: A Spark DataFrame
        transformation: Transformation to apply to all fields of the DataFrame. The transformation must take as input
            a Column expression and the DataType of the corresponding expression.

    Returns:
        A new DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT
        ...     "John" as name,
        ...     ARRAY(STRUCT(1 as a), STRUCT(2 as a)) as s1,
        ...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s2,
        ...     ARRAY(ARRAY(STRUCT(1 as a)), ARRAY(STRUCT(2 as a))) as s3,
        ...     ARRAY(STRUCT(ARRAY(1, 2) as a), STRUCT(ARRAY(3, 4) as a)) as s4,
        ...     ARRAY(
        ...         STRUCT(ARRAY(STRUCT(STRUCT(1 as c) as b), STRUCT(STRUCT(2 as c) as b)) as a),
        ...         STRUCT(ARRAY(STRUCT(STRUCT(3 as c) as b), STRUCT(STRUCT(4 as c) as b)) as a)
        ...     ) as s5
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- name: string (nullable = false)
         |-- s1!.a: integer (nullable = false)
         |-- s2!!: integer (nullable = false)
         |-- s3!!.a: integer (nullable = false)
         |-- s4!.a!: integer (nullable = false)
         |-- s5!.a!.b.c: integer (nullable = false)
        <BLANKLINE>
        >>> df.show(truncate=False)
        +----+----------+----------------+--------------+--------------------+------------------------------------+
        |name|s1        |s2              |s3            |s4                  |s5                                  |
        +----+----------+----------------+--------------+--------------------+------------------------------------+
        |John|[{1}, {2}]|[[1, 2], [3, 4]]|[[{1}], [{2}]]|[{[1, 2]}, {[3, 4]}]|[{[{{1}}, {{2}}]}, {[{{3}}, {{4}}]}]|
        +----+----------+----------------+--------------+--------------------+------------------------------------+
        <BLANKLINE>
        >>> from pyspark.sql.types import IntegerType
        >>> def cast_int_as_double(col: Column, data_type: DataType):
        ...     if isinstance(data_type, IntegerType):
        ...         return col.cast("DOUBLE")
        >>> new_df = df.transform(transform_all_fields, cast_int_as_double)
        >>> nested.print_schema(new_df)
        root
         |-- name: string (nullable = false)
         |-- s1!.a: double (nullable = false)
         |-- s2!!: double (nullable = false)
         |-- s3!!.a: double (nullable = false)
         |-- s4!.a!: double (nullable = false)
         |-- s5!.a!.b.c: double (nullable = false)
        <BLANKLINE>
        >>> new_df.show(truncate=False)
        +----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
        |name|s1            |s2                      |s3                |s4                          |s5                                          |
        +----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
        |John|[{1.0}, {2.0}]|[[1.0, 2.0], [3.0, 4.0]]|[[{1.0}], [{2.0}]]|[{[1.0, 2.0]}, {[3.0, 4.0]}]|[{[{{1.0}}, {{2.0}}]}, {[{{3.0}}, {{4.0}}]}]|
        +----+--------------+------------------------+------------------+----------------------------+--------------------------------------------+
        <BLANKLINE>
    """  # noqa: E501
    root_transformation = build_transformation_from_schema(
        df.schema,
        column_transformation=transformation,
    )
    return df.select(*root_transformation(df))

unflatten(df: DataFrame, separator: str = '.') -> DataFrame

Reverse of the flatten operation. Nested field names will be separated from each other using the specified separator.

Parameters:

    df (DataFrame): A Spark DataFrame. Required.
    separator (str): A string used to separate the struct names from their elements. It might be useful to change the separator when some of the DataFrame's column names already contain dots. Default: '.'.

Returns:

    DataFrame: An unflattened DataFrame

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame([(1, 1, 1, 1)], "id INT, `s.a` INT, `s.b.c` INT, `s.b.d` INT")
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.a: integer (nullable = true)
 |-- s.b.c: integer (nullable = true)
 |-- s.b.d: integer (nullable = true)

>>> unflatten(df).printSchema()
root
 |-- id: integer (nullable = true)
 |-- s: struct (nullable = true)
 |    |-- a: integer (nullable = true)
 |    |-- b: struct (nullable = true)
 |    |    |-- c: integer (nullable = true)
 |    |    |-- d: integer (nullable = true)

>>> df = spark.createDataFrame([(1, 1, 1)], "id INT, `s.s1?a.a1` INT, `s.s1?b.b1` INT")
>>> df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1?a.a1: integer (nullable = true)
 |-- s.s1?b.b1: integer (nullable = true)

>>> unflatten(df, "?").printSchema()
root
 |-- id: integer (nullable = true)
 |-- s.s1: struct (nullable = true)
 |    |-- a.a1: integer (nullable = true)
 |    |-- b.b1: integer (nullable = true)
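
A typical use case, sketched below with hypothetical column names, is to group flat columns under a struct by renaming them with the separator before calling unflatten (the import path is an assumption):

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from spark_frame.transformations import unflatten  # assumed import path

spark = SparkSession.builder.appName("doc_example").getOrCreate()
# Hypothetical flat data, e.g. freshly loaded from a CSV file
df = spark.createDataFrame(
    [(1, "5 main street", "Paris")],
    "user_id INT, address_street STRING, address_city STRING",
)

# Rename the flat columns with a '.' separator so that unflatten groups them into a struct
renamed_df = df.select(
    f.col("user_id"),
    f.col("address_street").alias("address.street"),
    f.col("address_city").alias("address.city"),
)
unflatten(renamed_df).printSchema()
# Expected schema (roughly):
# root
#  |-- user_id: integer
#  |-- address: struct
#  |    |-- street: string
#  |    |-- city: string
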
Source code in spark_frame/transformations_impl/unflatten.py
def unflatten(df: DataFrame, separator: str = ".") -> DataFrame:
    """Reverse of the flatten operation
    Nested fields names will be separated from each other using the specified separator

    Args:
        df: A Spark DataFrame
        separator: A string used to separate the struct names from their elements.
                   It might be useful to change the separator when some of the DataFrame's column names already contain dots.

    Returns:
        An unflattened DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame([(1, 1, 1, 1)], "id INT, `s.a` INT, `s.b.c` INT, `s.b.d` INT")
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.a: integer (nullable = true)
         |-- s.b.c: integer (nullable = true)
         |-- s.b.d: integer (nullable = true)
        <BLANKLINE>
        >>> unflatten(df).printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s: struct (nullable = true)
         |    |-- a: integer (nullable = true)
         |    |-- b: struct (nullable = true)
         |    |    |-- c: integer (nullable = true)
         |    |    |-- d: integer (nullable = true)
        <BLANKLINE>
        >>> df = spark.createDataFrame([(1, 1, 1)], "id INT, `s.s1?a.a1` INT, `s.s1?b.b1` INT")
        >>> df.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1?a.a1: integer (nullable = true)
         |-- s.s1?b.b1: integer (nullable = true)
        <BLANKLINE>
        >>> unflatten(df, "?").printSchema()
        root
         |-- id: integer (nullable = true)
         |-- s.s1: struct (nullable = true)
         |    |-- a.a1: integer (nullable = true)
         |    |-- b.b1: integer (nullable = true)
        <BLANKLINE>
    """
    # The idea is to recursively write a "SELECT struct(a, struct(s.b.c, s.b.d)) as s" for each nested column.
    # There is a little twist as we don't want to rebuild the struct if all its fields are NULL, so we add a CASE WHEN

    def has_structs(df: DataFrame) -> bool:
        struct_fields = [field for field in df.schema if is_struct(field)]
        return len(struct_fields) > 0

    if has_structs(df):
        df = flatten(df)

    tree = _build_nested_struct_tree(df.columns, separator)
    cols = _build_struct_from_tree(tree, separator)
    return df.select(cols)

union_dataframes(*dfs: DataFrame) -> DataFrame

Returns the union of multiple DataFrames

Parameters:

    dfs (DataFrame): One or more Spark DataFrames.

Returns:

    DataFrame: A new DataFrame containing the union of all input DataFrames

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df1 = spark.sql('SELECT 1 as a')
>>> df2 = spark.sql('SELECT 2 as a')
>>> df3 = spark.sql('SELECT 3 as a')
>>> union_dataframes(df1, df2, df3).show()
+---+
|  a|
+---+
|  1|
|  2|
|  3|
+---+

>>> df1.transform(union_dataframes, df2, df3).show()
+---+
|  a|
+---+
|  1|
|  2|
|  3|
+---+
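
Because the function is variadic, a list of DataFrames can also be passed using the * unpacking operator, as in this sketch (the monthly DataFrames and the import path are assumptions):

from pyspark.sql import SparkSession

from spark_frame.transformations import union_dataframes  # assumed import path

spark = SparkSession.builder.appName("doc_example").getOrCreate()
# Hypothetical: one DataFrame per month, all sharing the same schema
monthly_dfs = [spark.sql(f"SELECT {month} as month, 100 as amount") for month in range(1, 4)]

# The '*' operator unpacks the list into separate arguments
union_dataframes(*monthly_dfs).show()
# +-----+------+
# |month|amount|
# +-----+------+
# |    1|   100|
# |    2|   100|
# |    3|   100|
# +-----+------+
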
Source code in spark_frame/transformations_impl/union_dataframes.py
def union_dataframes(*dfs: DataFrame) -> DataFrame:
    """Returns the union between multiple DataFrames

    Args:
        dfs: One or more Spark DataFrames

    Returns:
        A new DataFrame containing the union of all input DataFrames

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df1 = spark.sql('SELECT 1 as a')
        >>> df2 = spark.sql('SELECT 2 as a')
        >>> df3 = spark.sql('SELECT 3 as a')
        >>> union_dataframes(df1, df2, df3).show()
        +---+
        |  a|
        +---+
        |  1|
        |  2|
        |  3|
        +---+
        <BLANKLINE>
        >>> df1.transform(union_dataframes, df2, df3).show()
        +---+
        |  a|
        +---+
        |  1|
        |  2|
        |  3|
        +---+
        <BLANKLINE>
    """
    assert_true(len(dfs) > 0, ValueError("Input list is empty"))
    res = dfs[0]
    for df in dfs[1:]:
        res = res.union(df)
    return res

unpivot(df: DataFrame, pivot_columns: List[str], key_alias: str = 'key', value_alias: str = 'value') -> DataFrame

Unpivot the given DataFrame along the specified pivot columns. All columns that are not pivot columns should have the same type.

This is the inverse transformation of the pyspark.sql.GroupedData.pivot operation.

Parameters:

    df (DataFrame): A DataFrame. Required.
    pivot_columns (List[str]): The list of column names on which to perform the pivot. Required.
    key_alias (str): Alias given to the 'key' column. Default: 'key'.
    value_alias (str): Alias given to the 'value' column. Default: 'value'.

Returns:

    DataFrame: An unpivoted DataFrame

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame([
...    (2018, "Orange",  None, 4000, None),
...    (2018, "Beans",   None, 1500, 2000),
...    (2018, "Banana",  2000,  400, None),
...    (2018, "Carrots", 2000, 1200, None),
...    (2019, "Orange",  5000, None, 5000),
...    (2019, "Beans",   None, 1500, 2000),
...    (2019, "Banana",  None, 1400,  400),
...    (2019, "Carrots", None,  200, None),
...  ], "year INT, product STRING, Canada INT, China INT, Mexico INT"
... )
>>> df.show()
+----+-------+------+-----+------+
|year|product|Canada|China|Mexico|
+----+-------+------+-----+------+
|2018| Orange|  NULL| 4000|  NULL|
|2018|  Beans|  NULL| 1500|  2000|
|2018| Banana|  2000|  400|  NULL|
|2018|Carrots|  2000| 1200|  NULL|
|2019| Orange|  5000| NULL|  5000|
|2019|  Beans|  NULL| 1500|  2000|
|2019| Banana|  NULL| 1400|   400|
|2019|Carrots|  NULL|  200|  NULL|
+----+-------+------+-----+------+

>>> unpivot(df, ['year', 'product'], key_alias='country', value_alias='total').show(100)
+----+-------+-------+-----+
|year|product|country|total|
+----+-------+-------+-----+
|2018| Orange| Canada| NULL|
|2018| Orange|  China| 4000|
|2018| Orange| Mexico| NULL|
|2018|  Beans| Canada| NULL|
|2018|  Beans|  China| 1500|
|2018|  Beans| Mexico| 2000|
|2018| Banana| Canada| 2000|
|2018| Banana|  China|  400|
|2018| Banana| Mexico| NULL|
|2018|Carrots| Canada| 2000|
|2018|Carrots|  China| 1200|
|2018|Carrots| Mexico| NULL|
|2019| Orange| Canada| 5000|
|2019| Orange|  China| NULL|
|2019| Orange| Mexico| 5000|
|2019|  Beans| Canada| NULL|
|2019|  Beans|  China| 1500|
|2019|  Beans| Mexico| 2000|
|2019| Banana| Canada| NULL|
|2019| Banana|  China| 1400|
|2019| Banana| Mexico|  400|
|2019|Carrots| Canada| NULL|
|2019|Carrots|  China|  200|
|2019|Carrots| Mexico| NULL|
+----+-------+-------+-----+
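
Since unpivot is the inverse of pivot, the original layout can be recovered by pivoting the result back, as in the sketch below (the choice of f.first as the aggregation, the sample data and the import path are assumptions):

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from spark_frame.transformations import unpivot  # assumed import path

spark = SparkSession.builder.appName("doc_example").getOrCreate()
df = spark.createDataFrame(
    [(2018, 2000, 400), (2019, None, 1400)],
    "year INT, Canada INT, China INT",
)
unpivoted_df = unpivot(df, ["year"], key_alias="country", value_alias="total")

# Pivoting back restores the original one-column-per-country layout
repivoted_df = unpivoted_df.groupBy("year").pivot("country").agg(f.first("total"))
repivoted_df.orderBy("year").show()
# +----+------+-----+
# |year|Canada|China|
# +----+------+-----+
# |2018|  2000|  400|
# |2019|  NULL| 1400|
# +----+------+-----+
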
Source code in spark_frame/transformations_impl/unpivot.py
def unpivot(df: DataFrame, pivot_columns: List[str], key_alias: str = "key", value_alias: str = "value") -> DataFrame:
    """Unpivot the given DataFrame along the specified pivot columns.
    All columns that are not pivot columns should have the same type.

    This is the inverse transformation of the [pyspark.sql.GroupedData.pivot][] operation.

    Args:
        df: A DataFrame
        pivot_columns: The list of column names on which to perform the pivot
        key_alias: Alias given to the 'key' column
        value_alias: Alias given to the 'value' column

    Returns:
        An unpivoted DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame([
        ...    (2018, "Orange",  None, 4000, None),
        ...    (2018, "Beans",   None, 1500, 2000),
        ...    (2018, "Banana",  2000,  400, None),
        ...    (2018, "Carrots", 2000, 1200, None),
        ...    (2019, "Orange",  5000, None, 5000),
        ...    (2019, "Beans",   None, 1500, 2000),
        ...    (2019, "Banana",  None, 1400,  400),
        ...    (2019, "Carrots", None,  200, None),
        ...  ], "year INT, product STRING, Canada INT, China INT, Mexico INT"
        ... )
        >>> df.show()
        +----+-------+------+-----+------+
        |year|product|Canada|China|Mexico|
        +----+-------+------+-----+------+
        |2018| Orange|  NULL| 4000|  NULL|
        |2018|  Beans|  NULL| 1500|  2000|
        |2018| Banana|  2000|  400|  NULL|
        |2018|Carrots|  2000| 1200|  NULL|
        |2019| Orange|  5000| NULL|  5000|
        |2019|  Beans|  NULL| 1500|  2000|
        |2019| Banana|  NULL| 1400|   400|
        |2019|Carrots|  NULL|  200|  NULL|
        +----+-------+------+-----+------+
        <BLANKLINE>
        >>> unpivot(df, ['year', 'product'], key_alias='country', value_alias='total').show(100)
        +----+-------+-------+-----+
        |year|product|country|total|
        +----+-------+-------+-----+
        |2018| Orange| Canada| NULL|
        |2018| Orange|  China| 4000|
        |2018| Orange| Mexico| NULL|
        |2018|  Beans| Canada| NULL|
        |2018|  Beans|  China| 1500|
        |2018|  Beans| Mexico| 2000|
        |2018| Banana| Canada| 2000|
        |2018| Banana|  China|  400|
        |2018| Banana| Mexico| NULL|
        |2018|Carrots| Canada| 2000|
        |2018|Carrots|  China| 1200|
        |2018|Carrots| Mexico| NULL|
        |2019| Orange| Canada| 5000|
        |2019| Orange|  China| NULL|
        |2019| Orange| Mexico| 5000|
        |2019|  Beans| Canada| NULL|
        |2019|  Beans|  China| 1500|
        |2019|  Beans| Mexico| 2000|
        |2019| Banana| Canada| NULL|
        |2019| Banana|  China| 1400|
        |2019| Banana| Mexico|  400|
        |2019|Carrots| Canada| NULL|
        |2019|Carrots|  China|  200|
        |2019|Carrots| Mexico| NULL|
        +----+-------+-------+-----+
        <BLANKLINE>
    """
    pivoted_columns = [(c, t) for (c, t) in df.dtypes if c not in pivot_columns]
    cols, types = zip(*pivoted_columns)

    # Check that all columns have the same type.
    assert_true(
        len(set(types)) == 1,
        ("All pivoted columns should be of the same type:\n Pivoted columns are: %s" % pivoted_columns),
    )

    # Create and explode an array of (column_name, column_value) structs
    kvs = f.explode(
        f.array(*[f.struct(f.lit(c).alias(key_alias), f.col(quote(c)).alias(value_alias)) for c in cols]),
    ).alias("kvs")

    return df.select([f.col(c) for c in quote_columns(pivot_columns)] + [kvs]).select(
        [*quote_columns(pivot_columns), "kvs.*"],
    )

with_generic_typed_struct(df: DataFrame, col_names: List[str]) -> DataFrame

Transform the specified struct columns of a given DataFrame into generic typed struct columns with the following generic schema (based on https://spark.apache.org/docs/latest/sql-ref-datatypes.html):

STRUCT<
    key: STRING, -- (name of the field inside the struct)
    type: STRING, -- (type of the field inside the struct)
    value: STRUCT< -- (all the fields will be null except for the one with the correct type)
        date: DATE,
        timestamp: TIMESTAMP,
        int: LONG,
        float: DOUBLE,
        boolean: BOOLEAN,
        string: STRING,
        bytes: BINARY
    >
>

The following Spark types will be automatically cast into the following, more generic types:

  - tinyint, smallint, int -> bigint
  - float, decimal -> double

Parameters:

    df (DataFrame): The DataFrame to transform. Required.
    col_names (List[str]): A list of column names to transform. Required.

Returns:

    DataFrame: A DataFrame with the columns transformed into generic typed structs

Limitations

Currently, complex field types (structs, maps, arrays) are not supported. All fields of the struct columns to convert must be of basic types.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.createDataFrame(
...     [(1, {"first.name": "Jacques", "age": 25, "is.an.adult": True}),
...      (2, {"first.name": "Michel", "age": 12, "is.an.adult": False}),
...      (3, {"first.name": "Marie", "age": 36, "is.an.adult": True})],
...     "id INT, `person.struct` STRUCT<`first.name`:STRING, age:INT, `is.an.adult`:BOOLEAN>"
... )
>>> df.show(truncate=False)
+---+-------------------+
|id |person.struct      |
+---+-------------------+
|1  |{Jacques, 25, true}|
|2  |{Michel, 12, false}|
|3  |{Marie, 36, true}  |
+---+-------------------+

>>> res = with_generic_typed_struct(df, ["`person.struct`"])
>>> res.printSchema()
root
 |-- id: integer (nullable = true)
 |-- person.struct: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- key: string (nullable = false)
 |    |    |-- type: string (nullable = false)
 |    |    |-- value: struct (nullable = false)
 |    |    |    |-- boolean: boolean (nullable = true)
 |    |    |    |-- bytes: binary (nullable = true)
 |    |    |    |-- date: date (nullable = true)
 |    |    |    |-- float: double (nullable = true)
 |    |    |    |-- int: long (nullable = true)
 |    |    |    |-- string: string (nullable = true)
 |    |    |    |-- timestamp: timestamp (nullable = true)

>>> res.show(10, False)
+---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|id |person.struct                                                                                                                                                                                  |
+---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Jacques, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 25, NULL, NULL}}, {is.an.adult, boolean, {true, NULL, NULL, NULL, NULL, NULL, NULL}}]|
|2  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Michel, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 12, NULL, NULL}}, {is.an.adult, boolean, {false, NULL, NULL, NULL, NULL, NULL, NULL}}]|
|3  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Marie, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 36, NULL, NULL}}, {is.an.adult, boolean, {true, NULL, NULL, NULL, NULL, NULL, NULL}}]  |
+---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
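
Continuing from the example above (a sketch, not part of the library's documented examples), the generic entries can then be queried uniformly by exploding them. The entry alias is hypothetical, and f.expr is used instead of f.col because the column name contains a dot:

from pyspark.sql import functions as f

# One row per (id, key) pair; only the value sub-field matching 'type' is non-null
exploded_df = res.select("id", f.explode(f.expr("`person.struct`")).alias("entry"))
exploded_df.select(
    "id",
    "entry.key",
    "entry.type",
    "entry.value.string",
    "entry.value.int",
    "entry.value.boolean",
).show(truncate=False)
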
Source code in spark_frame/transformations_impl/with_generic_typed_struct.py
def with_generic_typed_struct(df: DataFrame, col_names: List[str]) -> DataFrame:
    """Transform the specified struct columns of a given [Dataframe][pyspark.sql.DataFrame] into
    generic typed struct columns with the following generic schema
    (based on [https://spark.apache.org/docs/latest/sql-ref-datatypes.html](
    https://spark.apache.org/docs/latest/sql-ref-datatypes.html)) :

        STRUCT<
            key: STRING, -- (name of the field inside the struct)
            type: STRING, -- (type of the field inside the struct)
            value: STRUCT< -- (all the fields will be null except for the one with the correct type)
                date: DATE,
                timestamp: TIMESTAMP,
                int: LONG,
                float: DOUBLE,
                boolean: BOOLEAN,
                string: STRING,
                bytes: BINARY
            >
        >

    The following Spark types will be automatically cast into the following, more generic types:

    - `tinyint`, `smallint`, `int` -> `bigint`
    - `float`, `decimal` -> `double`

    Args:
        df: The Dataframe to transform
        col_names: A list of column names to transform

    Returns:
        A Dataframe with the columns transformed into generic typed structs

    !!! warning "Limitations"
        Currently, complex field types (structs, maps, arrays) are not supported.
        All fields of the struct columns to convert must be of basic types.

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.createDataFrame(
        ...     [(1, {"first.name": "Jacques", "age": 25, "is.an.adult": True}),
        ...      (2, {"first.name": "Michel", "age": 12, "is.an.adult": False}),
        ...      (3, {"first.name": "Marie", "age": 36, "is.an.adult": True})],
        ...     "id INT, `person.struct` STRUCT<`first.name`:STRING, age:INT, `is.an.adult`:BOOLEAN>"
        ... )
        >>> df.show(truncate=False)
        +---+-------------------+
        |id |person.struct      |
        +---+-------------------+
        |1  |{Jacques, 25, true}|
        |2  |{Michel, 12, false}|
        |3  |{Marie, 36, true}  |
        +---+-------------------+
        <BLANKLINE>
        >>> res = with_generic_typed_struct(df, ["`person.struct`"])
        >>> res.printSchema()
        root
         |-- id: integer (nullable = true)
         |-- person.struct: array (nullable = false)
         |    |-- element: struct (containsNull = false)
         |    |    |-- key: string (nullable = false)
         |    |    |-- type: string (nullable = false)
         |    |    |-- value: struct (nullable = false)
         |    |    |    |-- boolean: boolean (nullable = true)
         |    |    |    |-- bytes: binary (nullable = true)
         |    |    |    |-- date: date (nullable = true)
         |    |    |    |-- float: double (nullable = true)
         |    |    |    |-- int: long (nullable = true)
         |    |    |    |-- string: string (nullable = true)
         |    |    |    |-- timestamp: timestamp (nullable = true)
        <BLANKLINE>
        >>> res.show(10, False)
        +---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
        |id |person.struct                                                                                                                                                                                  |
        +---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
        |1  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Jacques, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 25, NULL, NULL}}, {is.an.adult, boolean, {true, NULL, NULL, NULL, NULL, NULL, NULL}}]|
        |2  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Michel, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 12, NULL, NULL}}, {is.an.adult, boolean, {false, NULL, NULL, NULL, NULL, NULL, NULL}}]|
        |3  |[{first.name, string, {NULL, NULL, NULL, NULL, NULL, Marie, NULL}}, {age, int, {NULL, NULL, NULL, NULL, 36, NULL, NULL}}, {is.an.adult, boolean, {true, NULL, NULL, NULL, NULL, NULL, NULL}}]  |
        +---+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
        <BLANKLINE>
    """  # noqa: E501

    source_to_cast = {
        "date": "date",
        "timestamp": "timestamp",
        "tinyint": "bigint",
        "smallint": "bigint",
        "int": "bigint",
        "bigint": "bigint",
        "float": "double",
        "double": "double",
        "boolean": "boolean",
        "string": "string",
        "binary": "binary",
    }
    """Mapping indicating for each source Spark DataTypes the type into which it will be cast."""

    cast_to_name = {
        "binary": "bytes",
        "bigint": "int",
        "double": "float",
    }
    """Mapping indicating for each already cast Spark DataTypes the name of the corresponding field.
    When missing, the same name will be kept."""

    name_cast = {cast_to_name.get(value, value): value for value in source_to_cast.values()}
    # We make sure the types are sorted
    name_cast = dict(sorted(name_cast.items()))

    def match_regex_types(source_type: str) -> Optional[str]:
        """Matches the source types against regexes to identify more complex types (like Decimal(x, y))"""
        regex_to_cast_types = [(re.compile("decimal(.*)"), "float")]
        for regex, cast_type in regex_to_cast_types:
            if regex.match(source_type) is not None:
                return cast_type
        return None

    def field_to_col(field: StructField, column_name: str) -> Optional[Column]:
        """Transforms the specified field into a generic column"""
        source_type = field.dataType.simpleString()
        cast_type = source_to_cast.get(source_type)
        field_name = column_name + "." + quote(field.name)
        if cast_type is None:
            cast_type = match_regex_types(source_type)
        if cast_type is None:
            print(
                "WARNING: The field {field_name} is of type {source_type} which is currently unsupported. "
                "This field will be dropped.".format(
                    field_name=field_name,
                    source_type=source_type,
                ),
            )
            return None
        name_type = cast_to_name.get(cast_type, cast_type)
        return f.struct(
            f.lit(field.name).alias("key"),
            f.lit(name_type).alias("type"),
            # In the code below, we use f.expr instead of f.col because it looks like f.col
            # does not support column names with backquotes in them, but f.expr does :-p
            f.struct(
                *[
                    (f.expr(field_name) if name_type == name_t else f.lit(None)).astype(cast_t).alias(name_t)
                    for name_t, cast_t in name_cast.items()
                ],
            ).alias("value"),
        )

    for col_name in col_names:
        schema = _get_nested_col_type_from_schema(col_name, df.schema)
        assert_true(isinstance(schema, StructType))
        schema = cast(StructType, schema)
        columns = [field_to_col(field, col_name) for field in schema.fields]
        columns_2 = [col for col in columns if col is not None]
        df = df.withColumn(unquote(col_name), f.array(*columns_2).alias("values"))
    return df