spark_frame.nested

Please read this before using the spark_frame.nested module

The spark_frame.nested module contains several methods that make the manipulation of deeply nested data structures much easier. Before diving into it, it is important to make the concept of Field explicit in the context of this library.

First, let's distinguish the notion of Column and Field. Both terms are already used in Spark, but we chose here to make the following distinction:

  • A Column is a root-level column of a DataFrame.
  • A Field is any column or sub-column inside a struct of the DataFrame.

Example: let's consider the following DataFrame

>>> from spark_frame.examples.reference_nested import _get_sample_data
>>> df = _get_sample_data()
>>> df.show(truncate=False)  # noqa: E501
+---+-----------------------+---------------+
|id |name                   |types          |
+---+-----------------------+---------------+
|1  |{Bulbasaur, Bulbizarre}|[Grass, Poison]|
+---+-----------------------+---------------+

>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- name: struct (nullable = false)
 |    |-- english: string (nullable = false)
 |    |-- french: string (nullable = false)
 |-- types: array (nullable = false)
 |    |-- element: string (containsNull = false)

This DataFrame has 3 columns:

  • id
  • name
  • types

But it has 4 fields:

  • id
  • name.english
  • name.french
  • types!

This can be seen by using the method spark_frame.nested.print_schema:

>>> from spark_frame import nested
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- name.english: string (nullable = false)
 |-- name.french: string (nullable = false)
 |-- types!: string (nullable = false)
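The flattened list of field names can also be obtained programmatically with spark_frame.nested.fields, which is used in several examples further down this page. A minimal sketch, assuming the same sample DataFrame as above (the output shown is indicative):

>>> nested.fields(df)
['id', 'name.english', 'name.french', 'types!']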

As we can see, some field names contain dots . or exclamation marks !. These characters convey the following meaning:

  • A dot . represents a struct.
  • An exclamation mark ! represents an array.

While the dot syntax for structs should feel familiar to users, the exclamation mark ! should feel new. It is used as a repetition marker indicating that this field is repeated.

Tip

It is important not to forget the exclamation marks ! when mentioning a field. For instance:

  • types designates the root-level field which is of type ARRAY<STRING>
  • types! designates the elements inside this array

In particular, if a field "my_field" is of type ARRAY<ARRAY<STRING>>, the innermost elements of the arrays will be designated as "my_field!!" with two exclamation marks.
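For illustration, here is a minimal sketch of such a doubly-repeated field, assuming a SparkSession named spark is available as in the examples below (the output shown is indicative):

>>> from spark_frame import nested
>>> df2 = spark.sql('SELECT ARRAY(ARRAY("a", "b"), ARRAY("c")) as my_field')
>>> nested.print_schema(df2)
root
 |-- my_field!!: string (nullable = false)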

Limitation: Do not use dots, exclamation marks or percent signs in field names

Given the syntax used, every method defined in the spark_frame.nested module assumes that field names in DataFrames do not contain any dot ., exclamation mark ! or percent sign %. This can be worked around using the transformation spark_frame.transformations.transform_all_field_names.
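For example, a minimal sketch of that workaround, assuming transform_all_field_names takes the DataFrame and a renaming function applied to every field name (the exact signature is an assumption here):

>>> from spark_frame import transformations
>>> # Hypothetical renaming: replace every forbidden character with an underscore in all field names.
>>> clean_df = transformations.transform_all_field_names(
...     df, lambda name: name.replace(".", "_").replace("!", "_").replace("%", "_")
... )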


print_schema(df: DataFrame) -> None

Print the DataFrame's flattened schema to the standard output.

  • Structs are flattened with a . after their name.
  • Arrays are flattened with a ! character after their name.
  • Maps are flattened with %key and %value after their name.

Limitation: Dots, percents, and exclamation marks are not supported in field names

Given the syntax used, every method defined in the spark_frame.nested module assumes that all field names in DataFrames do not contain any dot ., percent % or exclamation mark !. This can be worked around using the transformation spark_frame.transformations.transform_all_field_names.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.

Examples:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT
...     1 as id,
...     ARRAY(STRUCT(2 as a, ARRAY(STRUCT(3 as c, 4 as d)) as b, ARRAY(5, 6) as e)) as s1,
...     STRUCT(7 as f) as s2,
...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s3,
...     ARRAY(ARRAY(STRUCT(1 as e, 2 as f)), ARRAY(STRUCT(3 as e, 4 as f))) as s4,
...     MAP(STRUCT(1 as a), STRUCT(2 as b)) as m1
... ''')
>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- s1: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- a: integer (nullable = false)
 |    |    |-- b: array (nullable = false)
 |    |    |    |-- element: struct (containsNull = false)
 |    |    |    |    |-- c: integer (nullable = false)
 |    |    |    |    |-- d: integer (nullable = false)
 |    |    |-- e: array (nullable = false)
 |    |    |    |-- element: integer (containsNull = false)
 |-- s2: struct (nullable = false)
 |    |-- f: integer (nullable = false)
 |-- s3: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: integer (containsNull = false)
 |-- s4: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: struct (containsNull = false)
 |    |    |    |-- e: integer (nullable = false)
 |    |    |    |-- f: integer (nullable = false)
 |-- m1: map (nullable = false)
 |    |-- key: struct
 |    |    |-- a: integer (nullable = false)
 |    |-- value: struct (valueContainsNull = false)
 |    |    |-- b: integer (nullable = false)

>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s1!.a: integer (nullable = false)
 |-- s1!.b!.c: integer (nullable = false)
 |-- s1!.b!.d: integer (nullable = false)
 |-- s1!.e!: integer (nullable = false)
 |-- s2.f: integer (nullable = false)
 |-- s3!!: integer (nullable = false)
 |-- s4!!.e: integer (nullable = false)
 |-- s4!!.f: integer (nullable = false)
 |-- m1%key.a: integer (nullable = false)
 |-- m1%value.b: integer (nullable = false)
Source code in spark_frame/nested_impl/print_schema.py
def print_schema(df: DataFrame) -> None:
    """Print the DataFrame's flattened schema to the standard output.

    - Structs are flattened with a `.` after their name.
    - Arrays are flattened with a `!` character after their name.
    - Maps are flattened with a `%key` and '%value' after their name.

    !!! warning "Limitation: Dots, percents, and exclamation marks are not supported in field names"
        Given the syntax used, every method defined in the `spark_frame.nested` module assumes that all field
        names in DataFrames do not contain any dot `.`, percent `%` or exclamation mark `!`.
        This can be worked around using the transformation
        [`spark_frame.transformations.transform_all_field_names`]
        [spark_frame.transformations_impl.transform_all_field_names.transform_all_field_names].

    Args:
        df: A Spark DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT
        ...     1 as id,
        ...     ARRAY(STRUCT(2 as a, ARRAY(STRUCT(3 as c, 4 as d)) as b, ARRAY(5, 6) as e)) as s1,
        ...     STRUCT(7 as f) as s2,
        ...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s3,
        ...     ARRAY(ARRAY(STRUCT(1 as e, 2 as f)), ARRAY(STRUCT(3 as e, 4 as f))) as s4,
        ...     MAP(STRUCT(1 as a), STRUCT(2 as b)) as m1
        ... ''')
        >>> df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- s1: array (nullable = false)
         |    |-- element: struct (containsNull = false)
         |    |    |-- a: integer (nullable = false)
         |    |    |-- b: array (nullable = false)
         |    |    |    |-- element: struct (containsNull = false)
         |    |    |    |    |-- c: integer (nullable = false)
         |    |    |    |    |-- d: integer (nullable = false)
         |    |    |-- e: array (nullable = false)
         |    |    |    |-- element: integer (containsNull = false)
         |-- s2: struct (nullable = false)
         |    |-- f: integer (nullable = false)
         |-- s3: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: integer (containsNull = false)
         |-- s4: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: struct (containsNull = false)
         |    |    |    |-- e: integer (nullable = false)
         |    |    |    |-- f: integer (nullable = false)
         |-- m1: map (nullable = false)
         |    |-- key: struct
         |    |    |-- a: integer (nullable = false)
         |    |-- value: struct (valueContainsNull = false)
         |    |    |-- b: integer (nullable = false)
        <BLANKLINE>
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s1!.a: integer (nullable = false)
         |-- s1!.b!.c: integer (nullable = false)
         |-- s1!.b!.d: integer (nullable = false)
         |-- s1!.e!: integer (nullable = false)
         |-- s2.f: integer (nullable = false)
         |-- s3!!: integer (nullable = false)
         |-- s4!!.e: integer (nullable = false)
         |-- s4!!.f: integer (nullable = false)
         |-- m1%key.a: integer (nullable = false)
         |-- m1%value.b: integer (nullable = false)
        <BLANKLINE>
    """
    print(schema_string(df))

select(df: DataFrame, fields: Mapping[str, ColumnTransformation]) -> DataFrame

Projects a set of expressions and returns a new DataFrame.

This method is similar to the DataFrame.select method, with the extra capability of working on nested and repeated fields (structs and arrays).

The syntax for field names works as follows:

  • "." is the separator for struct elements
  • "!" must be appended at the end of fields that are repeated (arrays)
  • Map keys are appended with %key
  • Map values are appended with %value

The following types of transformation are allowed:

  • String and column expressions can be used on any non-repeated field, even nested ones.
  • When working on repeated fields, transformations must be expressed as higher order functions (e.g. lambda expressions). String and column expressions can be used on repeated fields as well, but their value will be repeated multiple times.
  • When working on multiple levels of nested arrays, higher order functions may take multiple arguments, corresponding to each level of repetition (See Example 5.).
  • None can also be used to represent the identity transformation; this is useful to select a field without changing it and without having to repeat its name.

Limitation: Dots, percents, and exclamation marks are not supported in field names

Given the syntax used, every method defined in the spark_frame.nested module assumes that all field names in DataFrames do not contain any dot ., percent % or exclamation mark !. This can be worked around using the transformation spark_frame.transformations.transform_all_field_names.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • fields (Mapping[str, ColumnTransformation]): A Dict(field_name, transformation_to_apply). Required.

Returns:

  • DataFrame: A new DataFrame where only the specified fields have been selected and the corresponding transformations were applied to each of them.

Example 1: non-repeated fields

>>> from pyspark.sql import SparkSession
>>> from pyspark.sql import functions as f
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''')
>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- s: struct (nullable = false)
 |    |-- a: integer (nullable = false)
 |    |-- b: integer (nullable = false)

>>> df.show()
+---+------+
| id|     s|
+---+------+
|  1|{2, 3}|
+---+------+

Transformations on non-repeated fields may be expressed as a string representing a column name, a Column expression or None. (In this example the column "id" will be dropped because it was not selected)

>>> new_df = nested.select(df, {
...     "s.a": "s.a",                        # Column name (string)
...     "s.b": None,                         # None: use to keep a column without having to repeat its name
...     "s.c": f.col("s.a") + f.col("s.b")   # Column expression
... })
>>> new_df.printSchema()
root
 |-- s: struct (nullable = false)
 |    |-- a: integer (nullable = false)
 |    |-- b: integer (nullable = false)
 |    |-- c: integer (nullable = false)

>>> new_df.show()
+---------+
|        s|
+---------+
|{2, 3, 5}|
+---------+

Example 2: repeated fields

>>> df = spark.sql('SELECT 1 as id, ARRAY(STRUCT(1 as a, 2 as b), STRUCT(3 as a, 4 as b)) as s')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s!.a: integer (nullable = false)
 |-- s!.b: integer (nullable = false)

>>> df.show()
+---+----------------+
| id|               s|
+---+----------------+
|  1|[{1, 2}, {3, 4}]|
+---+----------------+

Transformations on repeated fields must be expressed as higher-order functions (lambda expressions or named functions). The value passed to this function will correspond to the last repeated element.

>>> df.transform(nested.select, {
...     "s!.a": lambda s: s["a"],
...     "s!.b": None,
...     "s!.c": lambda s: s["a"] + s["b"]
... }).show(truncate=False)
+----------------------+
|s                     |
+----------------------+
|[{1, 2, 3}, {3, 4, 7}]|
+----------------------+

String and column expressions can be used on repeated fields as well, but their value will be repeated multiple times.

>>> df.transform(nested.select, {
...     "id": None,
...     "s!.a": "id",
...     "s!.b": f.lit(2)
... }).show(truncate=False)
+---+----------------+
|id |s               |
+---+----------------+
|1  |[{1, 2}, {1, 2}]|
+---+----------------+

Example 3: field repeated twice

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(STRUCT(ARRAY(1, 2, 3) as e)) as s1,
...         ARRAY(STRUCT(ARRAY(4, 5, 6) as e)) as s2
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s1!.e!: integer (nullable = false)
 |-- s2!.e!: integer (nullable = false)

>>> df.show()
+---+-------------+-------------+
| id|           s1|           s2|
+---+-------------+-------------+
|  1|[{[1, 2, 3]}]|[{[4, 5, 6]}]|
+---+-------------+-------------+

Here, the lambda expression will be applied to the last repeated element e.

>>> new_df = df.transform(nested.select, {
...  "s1!.e!": None,
...  "s2!.e!": lambda e : e.cast("DOUBLE")
... })
>>> nested.print_schema(new_df)
root
 |-- s1!.e!: integer (nullable = false)
 |-- s2!.e!: double (nullable = false)

>>> new_df.show()
+-------------+-------------------+
|           s1|                 s2|
+-------------+-------------------+
|[{[1, 2, 3]}]|[{[4.0, 5.0, 6.0]}]|
+-------------+-------------------+

Example 4: Dataframe with maps

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         MAP("a", STRUCT(2 as a, 3 as b)) as m1
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- m1%key: string (nullable = false)
 |-- m1%value.a: integer (nullable = false)
 |-- m1%value.b: integer (nullable = false)

>>> df.show()
+---+-------------+
| id|           m1|
+---+-------------+
|  1|{a -> {2, 3}}|
+---+-------------+
>>> new_df = df.transform(nested.select, {
...  "id": None,
...  "m1%key": lambda key : f.upper(key),
...  "m1%value.a": lambda value : value["a"].cast("DOUBLE")
... })
>>> nested.print_schema(new_df)
root
 |-- id: integer (nullable = false)
 |-- m1%key: string (nullable = false)
 |-- m1%value.a: double (nullable = false)

>>> new_df.show()
+---+------------+
| id|          m1|
+---+------------+
|  1|{A -> {2.0}}|
+---+------------+

Example 5: Accessing multiple repetition levels

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(
...             STRUCT(2 as average, ARRAY(1, 2, 3) as values),
...             STRUCT(3 as average, ARRAY(1, 2, 3, 4, 5) as values)
...         ) as s1
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s1!.average: integer (nullable = false)
 |-- s1!.values!: integer (nullable = false)

>>> df.show(truncate=False)
+---+--------------------------------------+
|id |s1                                    |
+---+--------------------------------------+
|1  |[{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}]|
+---+--------------------------------------+

Here, the transformation applied to "s1!.values!" takes two arguments.

>>> new_df = df.transform(nested.select, {
...  "id": None,
...  "s1!.average": None,
...  "s1!.values!": lambda s1, value : value - s1["average"]
... })
>>> new_df.show(truncate=False)
+---+-----------------------------------------+
|id |s1                                       |
+---+-----------------------------------------+
|1  |[{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}]|
+---+-----------------------------------------+

Extra arguments can be added to the left for each repetition level, up to the root level.

>>> new_df = df.transform(nested.select, {
...  "id": None,
...  "s1!.average": None,
...  "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"]
... })
>>> new_df.show(truncate=False)
+---+---------------------------------------+
|id |s1                                     |
+---+---------------------------------------+
|1  |[{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}]|
+---+---------------------------------------+
Source code in spark_frame/nested_impl/select_impl.py
def select(df: DataFrame, fields: Mapping[str, ColumnTransformation]) -> DataFrame:
    """Project a set of expressions and returns a new [DataFrame][pyspark.sql.DataFrame].

    This method is similar to the [DataFrame.select][pyspark.sql.DataFrame.select] method, with the extra
    capability of working on nested and repeated fields (structs and arrays).

    The syntax for field names works as follows:

    - "." is the separator for struct elements
    - "!" must be appended at the end of fields that are repeated (arrays)
    - Map keys are appended with `%key`
    - Map values are appended with `%value`

    The following types of transformation are allowed:

    - String and column expressions can be used on any non-repeated field, even nested ones.
    - When working on repeated fields, transformations must be expressed as higher order functions
      (e.g. lambda expressions). String and column expressions can be used on repeated fields as well,
      but their value will be repeated multiple times.
    - When working on multiple levels of nested arrays, higher order functions may take multiple arguments,
      corresponding to each level of repetition (See Example 5.).
    - `None` can also be used to represent the identity transformation; this is useful to select a field without
      changing it and without having to repeat its name.

    !!! warning "Limitation: Dots, percents, and exclamation marks are not supported in field names"
        Given the syntax used, every method defined in the `spark_frame.nested` module assumes that all field
        names in DataFrames do not contain any dot `.`, percent `%` or exclamation mark `!`.
        This can be worked around using the transformation
        [`spark_frame.transformations.transform_all_field_names`]
        [spark_frame.transformations_impl.transform_all_field_names.transform_all_field_names].

    Args:
        df: A Spark DataFrame
        fields: A Dict(field_name, transformation_to_apply)

    Returns:
        A new DataFrame where only the specified fields have been selected and the corresponding
        transformations were applied to each of them.

    Examples: Example 1: non-repeated fields
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql import functions as f
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''')
        >>> df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- s: struct (nullable = false)
         |    |-- a: integer (nullable = false)
         |    |-- b: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+------+
        | id|     s|
        +---+------+
        |  1|{2, 3}|
        +---+------+
        <BLANKLINE>

        Transformations on non-repeated fields may be expressed as a string representing a column name,
        a Column expression or None.
        (In this example the column "id" will be dropped because it was not selected)
        >>> new_df = nested.select(df, {
        ...     "s.a": "s.a",                        # Column name (string)
        ...     "s.b": None,                         # None: use to keep a column without having to repeat its name
        ...     "s.c": f.col("s.a") + f.col("s.b")   # Column expression
        ... })
        >>> new_df.printSchema()
        root
         |-- s: struct (nullable = false)
         |    |-- a: integer (nullable = false)
         |    |-- b: integer (nullable = false)
         |    |-- c: integer (nullable = false)
        <BLANKLINE>
        >>> new_df.show()
        +---------+
        |        s|
        +---------+
        |{2, 3, 5}|
        +---------+
        <BLANKLINE>

    Examples: Example 2: repeated fields
        >>> df = spark.sql('SELECT 1 as id, ARRAY(STRUCT(1 as a, 2 as b), STRUCT(3 as a, 4 as b)) as s')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s!.a: integer (nullable = false)
         |-- s!.b: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+----------------+
        | id|               s|
        +---+----------------+
        |  1|[{1, 2}, {3, 4}]|
        +---+----------------+
        <BLANKLINE>

        Transformations on repeated fields must be expressed as higher-order
        functions (lambda expressions or named functions).
        The value passed to this function will correspond to the last repeated element.
        >>> df.transform(nested.select, {
        ...     "s!.a": lambda s: s["a"],
        ...     "s!.b": None,
        ...     "s!.c": lambda s: s["a"] + s["b"]
        ... }).show(truncate=False)
        +----------------------+
        |s                     |
        +----------------------+
        |[{1, 2, 3}, {3, 4, 7}]|
        +----------------------+
        <BLANKLINE>

        String and column expressions can be used on repeated fields as well,
        but their value will be repeated multiple times.
        >>> df.transform(nested.select, {
        ...     "id": None,
        ...     "s!.a": "id",
        ...     "s!.b": f.lit(2)
        ... }).show(truncate=False)
        +---+----------------+
        |id |s               |
        +---+----------------+
        |1  |[{1, 2}, {1, 2}]|
        +---+----------------+
        <BLANKLINE>

    Examples: Example 3: field repeated twice
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(STRUCT(ARRAY(1, 2, 3) as e)) as s1,
        ...         ARRAY(STRUCT(ARRAY(4, 5, 6) as e)) as s2
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s1!.e!: integer (nullable = false)
         |-- s2!.e!: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+-------------+-------------+
        | id|           s1|           s2|
        +---+-------------+-------------+
        |  1|[{[1, 2, 3]}]|[{[4, 5, 6]}]|
        +---+-------------+-------------+
        <BLANKLINE>

        Here, the lambda expression will be applied to the last repeated element `e`.
        >>> new_df = df.transform(nested.select, {
        ...  "s1!.e!": None,
        ...  "s2!.e!": lambda e : e.cast("DOUBLE")
        ... })
        >>> nested.print_schema(new_df)
        root
         |-- s1!.e!: integer (nullable = false)
         |-- s2!.e!: double (nullable = false)
        <BLANKLINE>
        >>> new_df.show()
        +-------------+-------------------+
        |           s1|                 s2|
        +-------------+-------------------+
        |[{[1, 2, 3]}]|[{[4.0, 5.0, 6.0]}]|
        +-------------+-------------------+
        <BLANKLINE>

    Examples: Example 4: Dataframe with maps
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         MAP("a", STRUCT(2 as a, 3 as b)) as m1
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- m1%key: string (nullable = false)
         |-- m1%value.a: integer (nullable = false)
         |-- m1%value.b: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+-------------+
        | id|           m1|
        +---+-------------+
        |  1|{a -> {2, 3}}|
        +---+-------------+
        <BLANKLINE>

        >>> new_df = df.transform(nested.select, {
        ...  "id": None,
        ...  "m1%key": lambda key : f.upper(key),
        ...  "m1%value.a": lambda value : value["a"].cast("DOUBLE")
        ... })
        >>> nested.print_schema(new_df)
        root
         |-- id: integer (nullable = false)
         |-- m1%key: string (nullable = false)
         |-- m1%value.a: double (nullable = false)
        <BLANKLINE>
        >>> new_df.show()
        +---+------------+
        | id|          m1|
        +---+------------+
        |  1|{A -> {2.0}}|
        +---+------------+
        <BLANKLINE>

    Examples: Example 5: Accessing multiple repetition levels
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(
        ...             STRUCT(2 as average, ARRAY(1, 2, 3) as values),
        ...             STRUCT(3 as average, ARRAY(1, 2, 3, 4, 5) as values)
        ...         ) as s1
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s1!.average: integer (nullable = false)
         |-- s1!.values!: integer (nullable = false)
        <BLANKLINE>
        >>> df.show(truncate=False)
        +---+--------------------------------------+
        |id |s1                                    |
        +---+--------------------------------------+
        |1  |[{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}]|
        +---+--------------------------------------+
        <BLANKLINE>

        Here, the transformation applied to "s1!.values!" takes two arguments.
        >>> new_df = df.transform(nested.select, {
        ...  "id": None,
        ...  "s1!.average": None,
        ...  "s1!.values!": lambda s1, value : value - s1["average"]
        ... })
        >>> new_df.show(truncate=False)
        +---+-----------------------------------------+
        |id |s1                                       |
        +---+-----------------------------------------+
        |1  |[{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}]|
        +---+-----------------------------------------+
        <BLANKLINE>

        Extra arguments can be added to the left for each repetition level, up to the root level.
        >>> new_df = df.transform(nested.select, {
        ...  "id": None,
        ...  "s1!.average": None,
        ...  "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"]
        ... })
        >>> new_df.show(truncate=False)
        +---+---------------------------------------+
        |id |s1                                     |
        +---+---------------------------------------+
        |1  |[{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}]|
        +---+---------------------------------------+
        <BLANKLINE>

    """
    return df.select(*resolve_nested_fields(fields, starting_level=df))

schema_string(df: DataFrame) -> str

Write the DataFrame's flattened schema to a string.

  • Structs are flattened with a . after their name.
  • Arrays are flattened with a ! character after their name.
  • Maps are flattened with %key and %value after their name.

Limitation: Dots, percents, and exclamation marks are not supported in field names

Given the syntax used, every method defined in the spark_frame.nested module assumes that all field names in DataFrames do not contain any dot ., percent % or exclamation mark !. This can be worked around using the transformation spark_frame.transformations.transform_all_field_names.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.

Returns:

  • str: A string representing the flattened schema.

Examples:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT
...     1 as id,
...     ARRAY(STRUCT(2 as a, ARRAY(STRUCT(3 as c, 4 as d)) as b, ARRAY(5, 6) as e)) as s1,
...     STRUCT(7 as f) as s2,
...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s3,
...     ARRAY(ARRAY(STRUCT(1 as e, 2 as f)), ARRAY(STRUCT(3 as e, 4 as f))) as s4
... ''')
>>> df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- s1: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- a: integer (nullable = false)
 |    |    |-- b: array (nullable = false)
 |    |    |    |-- element: struct (containsNull = false)
 |    |    |    |    |-- c: integer (nullable = false)
 |    |    |    |    |-- d: integer (nullable = false)
 |    |    |-- e: array (nullable = false)
 |    |    |    |-- element: integer (containsNull = false)
 |-- s2: struct (nullable = false)
 |    |-- f: integer (nullable = false)
 |-- s3: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: integer (containsNull = false)
 |-- s4: array (nullable = false)
 |    |-- element: array (containsNull = false)
 |    |    |-- element: struct (containsNull = false)
 |    |    |    |-- e: integer (nullable = false)
 |    |    |    |-- f: integer (nullable = false)

>>> print(nested.schema_string(df))
root
 |-- id: integer (nullable = false)
 |-- s1!.a: integer (nullable = false)
 |-- s1!.b!.c: integer (nullable = false)
 |-- s1!.b!.d: integer (nullable = false)
 |-- s1!.e!: integer (nullable = false)
 |-- s2.f: integer (nullable = false)
 |-- s3!!: integer (nullable = false)
 |-- s4!!.e: integer (nullable = false)
 |-- s4!!.f: integer (nullable = false)
Source code in spark_frame/nested_impl/schema_string.py
def schema_string(df: DataFrame) -> str:
    """Write the DataFrame's flattened schema to a string.

    - Structs are flattened with a `.` after their name.
    - Arrays are flattened with a `!` character after their name.
    - Maps are flattened with a `%key` and '%value' after their name.

    !!! warning "Limitation: Dots, percents, and exclamation marks are not supported in field names"
        Given the syntax used, every method defined in the `spark_frame.nested` module assumes that all field
        names in DataFrames do not contain any dot `.`, percent `%` or exclamation mark `!`.
        This can be worked around using the transformation
        [`spark_frame.transformations.transform_all_field_names`]
        [spark_frame.transformations_impl.transform_all_field_names.transform_all_field_names].

    Args:
        df: A Spark DataFrame

    Returns:
        a string representing the flattened schema

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT
        ...     1 as id,
        ...     ARRAY(STRUCT(2 as a, ARRAY(STRUCT(3 as c, 4 as d)) as b, ARRAY(5, 6) as e)) as s1,
        ...     STRUCT(7 as f) as s2,
        ...     ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as s3,
        ...     ARRAY(ARRAY(STRUCT(1 as e, 2 as f)), ARRAY(STRUCT(3 as e, 4 as f))) as s4
        ... ''')
        >>> df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- s1: array (nullable = false)
         |    |-- element: struct (containsNull = false)
         |    |    |-- a: integer (nullable = false)
         |    |    |-- b: array (nullable = false)
         |    |    |    |-- element: struct (containsNull = false)
         |    |    |    |    |-- c: integer (nullable = false)
         |    |    |    |    |-- d: integer (nullable = false)
         |    |    |-- e: array (nullable = false)
         |    |    |    |-- element: integer (containsNull = false)
         |-- s2: struct (nullable = false)
         |    |-- f: integer (nullable = false)
         |-- s3: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: integer (containsNull = false)
         |-- s4: array (nullable = false)
         |    |-- element: array (containsNull = false)
         |    |    |-- element: struct (containsNull = false)
         |    |    |    |-- e: integer (nullable = false)
         |    |    |    |-- f: integer (nullable = false)
        <BLANKLINE>
        >>> print(nested.schema_string(df))
        root
         |-- id: integer (nullable = false)
         |-- s1!.a: integer (nullable = false)
         |-- s1!.b!.c: integer (nullable = false)
         |-- s1!.b!.d: integer (nullable = false)
         |-- s1!.e!: integer (nullable = false)
         |-- s2.f: integer (nullable = false)
         |-- s3!!: integer (nullable = false)
         |-- s4!!.e: integer (nullable = false)
         |-- s4!!.f: integer (nullable = false)
        <BLANKLINE>
    """
    flat_schema = flatten_schema(df.schema, explode=True)
    return _flat_schema_to_tree_string(flat_schema.fields)

unnest_all_fields(df: DataFrame, keep_columns: Optional[List[str]] = None) -> Dict[str, DataFrame]

Given a DataFrame, return a dict of {granularity: DataFrame} where all arrays have been recursively unnested (a.k.a. exploded). This produces one DataFrame for each possible granularity.

For instance, given a DataFrame with the following flattened schema:

  • id
  • s1.a
  • s2!.b
  • s2!.c
  • s2!.s3!.d
  • s4!.e
  • s4!.f

This will produce a dict with four {granularity: DataFrame} entries:

  • '': DataFrame[id, s1.a] ('' corresponds to the root granularity)
  • 's2': DataFrame[s2!.b, s2!.c]
  • 's2!.s3': DataFrame[s2!.s3!.d]
  • 's4': DataFrame[s4!.e, s4!.f]

Limitation: Maps are not unnested

Fields of type Map are not unnested by this method. A possible workaround is to first use the transformation spark_frame.transformations.convert_all_maps_to_arrays.
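A minimal sketch of that workaround, assuming convert_all_maps_to_arrays only takes the DataFrame as argument and given a DataFrame df that contains map fields (both the signature and the exact shape of the converted fields are assumptions here):

>>> from spark_frame import nested, transformations
>>> # Assumed behaviour: each MAP<K, V> field becomes an ARRAY<STRUCT<key: K, value: V>>,
>>> # which unnest_all_fields can then explode like any other array.
>>> df_without_maps = transformations.convert_all_maps_to_arrays(df)
>>> granularity_dfs = nested.unnest_all_fields(df_without_maps, keep_columns=["id"])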

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • keep_columns (Optional[List[str]]): Names of columns that should be kept while unnesting. Default: None.

Returns:

  • Dict[str, DataFrame]: A dict of {granularity: DataFrame} entries.

Examples:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         STRUCT(2 as a) as s1,
...         ARRAY(STRUCT(3 as b, 4 as c, ARRAY(STRUCT(5 as d), STRUCT(6 as d)) as s3)) as s2,
...         ARRAY(STRUCT(7 as e, 8 as f), STRUCT(9 as e, 10 as f)) as s4
... ''')
>>> df.show(truncate=False)
+---+---+--------------------+-----------------+
|id |s1 |s2                  |s4               |
+---+---+--------------------+-----------------+
|1  |{2}|[{3, 4, [{5}, {6}]}]|[{7, 8}, {9, 10}]|
+---+---+--------------------+-----------------+

>>> nested.fields(df)
['id', 's1.a', 's2!.b', 's2!.c', 's2!.s3!.d', 's4!.e', 's4!.f']
>>> result_df_list = nested.unnest_all_fields(df, keep_columns=["id"])
>>> for cols, result_df in result_df_list.items():
...     print(cols)
...     result_df.show()

+---+----+
| id|s1.a|
+---+----+
|  1|   2|
+---+----+

s2!
+---+-----+-----+
| id|s2!.b|s2!.c|
+---+-----+-----+
|  1|    3|    4|
+---+-----+-----+

s2!.s3!
+---+---------+
| id|s2!.s3!.d|
+---+---------+
|  1|        5|
|  1|        6|
+---+---------+

s4!
+---+-----+-----+
| id|s4!.e|s4!.f|
+---+-----+-----+
|  1|    7|    8|
|  1|    9|   10|
+---+-----+-----+
Source code in spark_frame/nested_impl/unnest_all_fields.py
def unnest_all_fields(df: DataFrame, keep_columns: Optional[List[str]] = None) -> Dict[str, DataFrame]:
    """Given a DataFrame, return a dict of {granularity: DataFrame} where all arrays have been recursively
    unnested (a.k.a. exploded).
    This produces one DataFrame for each possible granularity.

    For instance, given a DataFrame with the following flattened schema:
        id
        s1.a
        s2!.b
        s2!.c
        s2!.s3!.d
        s4!.e
        s4!.f

    This will produce a dict with four granularity - DataFrames entries:
        - '': DataFrame[id, s1.a] ('' corresponds to the root granularity)
        - 's2': DataFrame[s2!.b, s2!.c]
        - 's2!.s3': DataFrame[s2!.s3!.d]
        - 's4': DataFrame[s4!.e, s4!.f]

    !!! warning "Limitation: Maps are not unnested"
        - Fields of type Maps are not unnested by this method.
        - A possible workaround is to first use the transformation
        [`spark_frame.transformations.convert_all_maps_to_arrays`]
        [spark_frame.transformations_impl.convert_all_maps_to_arrays.convert_all_maps_to_arrays]

    Args:
        df: A Spark DataFrame
        keep_columns: Names of columns that should be kept while unnesting

    Returns:
        A dict of {granularity: DataFrame}

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         STRUCT(2 as a) as s1,
        ...         ARRAY(STRUCT(3 as b, 4 as c, ARRAY(STRUCT(5 as d), STRUCT(6 as d)) as s3)) as s2,
        ...         ARRAY(STRUCT(7 as e, 8 as f), STRUCT(9 as e, 10 as f)) as s4
        ... ''')
        >>> df.show(truncate=False)
        +---+---+--------------------+-----------------+
        |id |s1 |s2                  |s4               |
        +---+---+--------------------+-----------------+
        |1  |{2}|[{3, 4, [{5}, {6}]}]|[{7, 8}, {9, 10}]|
        +---+---+--------------------+-----------------+
        <BLANKLINE>
        >>> nested.fields(df)
        ['id', 's1.a', 's2!.b', 's2!.c', 's2!.s3!.d', 's4!.e', 's4!.f']
        >>> result_df_list = nested.unnest_all_fields(df, keep_columns=["id"])
        >>> for cols, result_df in result_df_list.items():
        ...     print(cols)
        ...     result_df.show()
        <BLANKLINE>
        +---+----+
        | id|s1.a|
        +---+----+
        |  1|   2|
        +---+----+
        <BLANKLINE>
        s2!
        +---+-----+-----+
        | id|s2!.b|s2!.c|
        +---+-----+-----+
        |  1|    3|    4|
        +---+-----+-----+
        <BLANKLINE>
        s2!.s3!
        +---+---------+
        | id|s2!.s3!.d|
        +---+---------+
        |  1|        5|
        |  1|        6|
        +---+---------+
        <BLANKLINE>
        s4!
        +---+-----+-----+
        | id|s4!.e|s4!.f|
        +---+-----+-----+
        |  1|    7|    8|
        |  1|    9|   10|
        +---+-----+-----+
        <BLANKLINE>
    """
    if keep_columns is None:
        keep_columns = []
    fields_to_unnest = [field for field in nested.fields(df) if not is_sub_field_or_equal_to_any(field, keep_columns)]
    return unnest_fields(df, fields_to_unnest, keep_fields=keep_columns)

unnest_field(df: DataFrame, field_name: str, keep_columns: Optional[List[str]] = None) -> DataFrame

Given a DataFrame, return a new DataFrame where the specified column has been recursively unnested (a.k.a. exploded).

Limitation: Maps are not unnested

Fields of type Map are not unnested by this method. A possible workaround is to first use the transformation spark_frame.transformations.convert_all_maps_to_arrays.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • field_name (str): The name of a nested column to unnest. Required.
  • keep_columns (Optional[List[str]]): List of column names to keep while unnesting. Default: None.

Returns:

  • DataFrame: A new DataFrame.

Examples:

>>> from pyspark.sql import SparkSession
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as arr
... ''')
>>> df.show(truncate=False)
+---+----------------+
|id |arr             |
+---+----------------+
|1  |[[1, 2], [3, 4]]|
+---+----------------+

>>> nested.fields(df)
['id', 'arr!!']
>>> nested.unnest_field(df, 'arr!').show(truncate=False)
+------+
|arr!  |
+------+
|[1, 2]|
|[3, 4]|
+------+

>>> nested.unnest_field(df, 'arr!!').show(truncate=False)
+-----+
|arr!!|
+-----+
|1    |
|2    |
|3    |
|4    |
+-----+

>>> nested.unnest_field(df, 'arr!!', keep_columns=["id"]).show(truncate=False)
+---+-----+
|id |arr!!|
+---+-----+
|1  |1    |
|1  |2    |
|1  |3    |
|1  |4    |
+---+-----+
>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(
...             STRUCT(ARRAY(STRUCT("a1" as a, "b1" as b), STRUCT("a2" as a, "b1" as b)) as s2),
...             STRUCT(ARRAY(STRUCT("a3" as a, "b3" as b)) as s2)
...         ) as s1
... ''')
>>> df.show(truncate=False)
+---+--------------------------------------+
|id |s1                                    |
+---+--------------------------------------+
|1  |[{[{a1, b1}, {a2, b1}]}, {[{a3, b3}]}]|
+---+--------------------------------------+

>>> nested.fields(df)
['id', 's1!.s2!.a', 's1!.s2!.b']
>>> nested.unnest_field(df, 's1!.s2!').show(truncate=False)
+--------+
|s1!.s2! |
+--------+
|{a1, b1}|
|{a2, b1}|
|{a3, b3}|
+--------+
Source code in spark_frame/nested_impl/unnest_field.py
def unnest_field(df: DataFrame, field_name: str, keep_columns: Optional[List[str]] = None) -> DataFrame:
    """Given a DataFrame, return a new DataFrame where the specified column has been recursively
    unnested (a.k.a. exploded).

    !!! warning "Limitation: Maps are not unnested"
        - Fields of type Maps are not unnested by this method.
        - A possible workaround is to first use the transformation
        [`spark_frame.transformations.convert_all_maps_to_arrays`]
        [spark_frame.transformations_impl.convert_all_maps_to_arrays.convert_all_maps_to_arrays]

    Args:
        df: A Spark DataFrame
        field_name: The name of a nested column to unnest
        keep_columns: List of column names to keep while unnesting

    Returns:
        A new DataFrame

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(ARRAY(1, 2), ARRAY(3, 4)) as arr
        ... ''')
        >>> df.show(truncate=False)
        +---+----------------+
        |id |arr             |
        +---+----------------+
        |1  |[[1, 2], [3, 4]]|
        +---+----------------+
        <BLANKLINE>
        >>> nested.fields(df)
        ['id', 'arr!!']
        >>> nested.unnest_field(df, 'arr!').show(truncate=False)
        +------+
        |arr!  |
        +------+
        |[1, 2]|
        |[3, 4]|
        +------+
        <BLANKLINE>
        >>> nested.unnest_field(df, 'arr!!').show(truncate=False)
        +-----+
        |arr!!|
        +-----+
        |1    |
        |2    |
        |3    |
        |4    |
        +-----+
        <BLANKLINE>
        >>> nested.unnest_field(df, 'arr!!', keep_columns=["id"]).show(truncate=False)
        +---+-----+
        |id |arr!!|
        +---+-----+
        |1  |1    |
        |1  |2    |
        |1  |3    |
        |1  |4    |
        +---+-----+
        <BLANKLINE>

        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(
        ...             STRUCT(ARRAY(STRUCT("a1" as a, "b1" as b), STRUCT("a2" as a, "b1" as b)) as s2),
        ...             STRUCT(ARRAY(STRUCT("a3" as a, "b3" as b)) as s2)
        ...         ) as s1
        ... ''')
        >>> df.show(truncate=False)
        +---+--------------------------------------+
        |id |s1                                    |
        +---+--------------------------------------+
        |1  |[{[{a1, b1}, {a2, b1}]}, {[{a3, b3}]}]|
        +---+--------------------------------------+
        <BLANKLINE>
        >>> nested.fields(df)
        ['id', 's1!.s2!.a', 's1!.s2!.b']
        >>> nested.unnest_field(df, 's1!.s2!').show(truncate=False)
        +--------+
        |s1!.s2! |
        +--------+
        |{a1, b1}|
        |{a2, b1}|
        |{a3, b3}|
        +--------+
        <BLANKLINE>

    """
    if keep_columns is None:
        keep_columns = []
    return next(iter(unnest_fields(df, field_name, keep_fields=keep_columns).values()))

with_fields(df: DataFrame, fields: Mapping[str, AnyKindOfTransformation]) -> DataFrame

Return a new DataFrame by adding or replacing (when they already exist) columns.

This method is similar to the DataFrame.withColumn method, with the extra capability of working on nested and repeated fields (structs and arrays).

The syntax for field names works as follows:

  • "." is the separator for struct elements
  • "!" must be appended at the end of fields that are repeated (arrays)
  • Map keys are appended with %key
  • Map values are appended with %value

The following types of transformation are allowed:

  • String and column expressions can be used on any non-repeated field, even nested ones.
  • When working on repeated fields, transformations must be expressed as higher order functions (e.g. lambda expressions). String and column expressions can be used on repeated fields as well, but their value will be repeated multiple times.
  • When working on multiple levels of nested arrays, higher order functions may take multiple arguments, corresponding to each level of repetition (See Example 5.).
  • None can also be used to represent the identity transformation; this is useful to select a field without changing it and without having to repeat its name.

Limitation: Dots, percents, and exclamation marks are not supported in field names

Given the syntax used, every method defined in the spark_frame.nested module assumes that all field names in DataFrames do not contain any dot ., percent % or exclamation mark !. This can be worked around using the transformation spark_frame.transformations.transform_all_field_names.

Parameters:

  • df (DataFrame): A Spark DataFrame. Required.
  • fields (Mapping[str, AnyKindOfTransformation]): A Dict(field_name, transformation_to_apply). Required.

Returns:

  • DataFrame: A new DataFrame with the same fields as the input DataFrame, where the specified transformations have been applied to the corresponding fields. If a field name did not exist in the input DataFrame, it will be added to the output DataFrame. If it did exist, the original value will be replaced with the new one.

Example 1: non-repeated fields

>>> from pyspark.sql import SparkSession
>>> from pyspark.sql import functions as f
>>> from spark_frame import nested
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()
>>> df = spark.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s.a: integer (nullable = false)
 |-- s.b: integer (nullable = false)

>>> df.show()
+---+------+
| id|     s|
+---+------+
|  1|{2, 3}|
+---+------+

Transformations on non-repeated fields may be expressed as a string representing a column name or a Column expression.

>>> new_df = nested.with_fields(df, {
...     "s.id": "id",                                 # column name (string)
...     "s.c": f.col("s.a") + f.col("s.b")            # Column expression
... })
>>> new_df.printSchema()
root
 |-- id: integer (nullable = false)
 |-- s: struct (nullable = false)
 |    |-- a: integer (nullable = false)
 |    |-- b: integer (nullable = false)
 |    |-- id: integer (nullable = false)
 |    |-- c: integer (nullable = false)

>>> new_df.show()
+---+------------+
| id|           s|
+---+------------+
|  1|{2, 3, 1, 5}|
+---+------------+

Example 2: repeated fields

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(STRUCT(1 as a, STRUCT(2 as c) as b), STRUCT(3 as a, STRUCT(4 as c) as b)) as s
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s!.a: integer (nullable = false)
 |-- s!.b.c: integer (nullable = false)

>>> df.show()
+---+--------------------+
| id|                   s|
+---+--------------------+
|  1|[{1, {2}}, {3, {4}}]|
+---+--------------------+

Transformations on repeated fields must be expressed as higher-order functions (lambda expressions or named functions). The value passed to this function will correspond to the last repeated element.

>>> new_df = df.transform(nested.with_fields, {
...     "s!.b.d": lambda s: s["a"] + s["b"]["c"]}
... )
>>> nested.print_schema(new_df)
root
 |-- id: integer (nullable = false)
 |-- s!.a: integer (nullable = false)
 |-- s!.b.c: integer (nullable = false)
 |-- s!.b.d: integer (nullable = false)

>>> new_df.show(truncate=False)
+---+--------------------------+
|id |s                         |
+---+--------------------------+
|1  |[{1, {2, 3}}, {3, {4, 7}}]|
+---+--------------------------+

String and column expressions can be used on repeated fields as well, but their value will be repeated multiple times.

>>> df.transform(nested.with_fields, {
...     "id": None,
...     "s!.a": "id",
...     "s!.b.c": f.lit(2)
... }).show(truncate=False)
+---+--------------------+
|id |s                   |
+---+--------------------+
|1  |[{1, {2}}, {1, {2}}]|
+---+--------------------+

Example 3: field repeated twice

>>> df = spark.sql('SELECT 1 as id, ARRAY(STRUCT(ARRAY(1, 2, 3) as e)) as s')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s!.e!: integer (nullable = false)

>>> df.show()
+---+-------------+
| id|            s|
+---+-------------+
|  1|[{[1, 2, 3]}]|
+---+-------------+

Here, the lambda expression will be applied to the last repeated element e.

>>> df.transform(nested.with_fields, {"s!.e!": lambda e : e.cast("DOUBLE")}).show()
+---+-------------------+
| id|                  s|
+---+-------------------+
|  1|[{[1.0, 2.0, 3.0]}]|
+---+-------------------+

Example 4: Dataframe with maps

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         MAP("a", STRUCT(2 as a, 3 as b)) as m1
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- m1%key: string (nullable = false)
 |-- m1%value.a: integer (nullable = false)
 |-- m1%value.b: integer (nullable = false)

>>> df.show()
+---+-------------+
| id|           m1|
+---+-------------+
|  1|{a -> {2, 3}}|
+---+-------------+
>>> new_df = df.transform(nested.with_fields, {
...  "m1%key": lambda key : f.upper(key),
...  "m1%value.a": lambda value : value["a"].cast("DOUBLE")
... })
>>> nested.print_schema(new_df)
root
 |-- id: integer (nullable = false)
 |-- m1%key: string (nullable = false)
 |-- m1%value.a: double (nullable = false)
 |-- m1%value.b: integer (nullable = false)

>>> new_df.show()
+---+---------------+
| id|             m1|
+---+---------------+
|  1|{A -> {2.0, 3}}|
+---+---------------+

Example 5: Accessing multiple repetition levels

>>> df = spark.sql('''
...     SELECT
...         1 as id,
...         ARRAY(
...             STRUCT(2 as average, ARRAY(1, 2, 3) as values),
...             STRUCT(3 as average, ARRAY(1, 2, 3, 4, 5) as values)
...         ) as s1
... ''')
>>> nested.print_schema(df)
root
 |-- id: integer (nullable = false)
 |-- s1!.average: integer (nullable = false)
 |-- s1!.values!: integer (nullable = false)

>>> df.show(truncate=False)
+---+--------------------------------------+
|id |s1                                    |
+---+--------------------------------------+
|1  |[{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}]|
+---+--------------------------------------+

Here, the transformation applied to "s1!.values!" takes two arguments.

>>> new_df = df.transform(nested.with_fields, {
...  "s1!.values!": lambda s1, value : value - s1["average"]
... })
>>> new_df.show(truncate=False)
+---+-----------------------------------------+
|id |s1                                       |
+---+-----------------------------------------+
|1  |[{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}]|
+---+-----------------------------------------+

Extra arguments can be added to the left for each repetition level, up to the root level.

>>> new_df = df.transform(nested.with_fields, {
...  "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"]
... })
>>> new_df.show(truncate=False)
+---+---------------------------------------+
|id |s1                                     |
+---+---------------------------------------+
|1  |[{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}]|
+---+---------------------------------------+
Source code in spark_frame/nested_impl/with_fields.py
def with_fields(df: DataFrame, fields: Mapping[str, AnyKindOfTransformation]) -> DataFrame:
    """Return a new [DataFrame][pyspark.sql.DataFrame] by adding or replacing (when they already exist) columns.

    This method is similar to the [DataFrame.withColumn][pyspark.sql.DataFrame.withColumn] method, with the extra
    capability of working on nested and repeated fields (structs and arrays).

    The syntax for field names works as follows:

    - "." is the separator for struct elements
    - "!" must be appended at the end of fields that are repeated (arrays)
    - Map keys are appended with `%key`
    - Map values are appended with `%value`

    The following types of transformation are allowed:

    - String and column expressions can be used on any non-repeated field, even nested ones.
    - When working on repeated fields, transformations must be expressed as higher order functions
      (e.g. lambda expressions). String and column expressions can be used on repeated fields as well,
      but their value will be repeated multiple times.
    - When working on multiple levels of nested arrays, higher order functions may take multiple arguments,
      corresponding to each level of repetition (See Example 5.).
    - `None` can also be used to represent the identity transformation; this is useful to select a field without
      changing it and without having to repeat its name.

    !!! warning "Limitation: Dots, percents, and exclamation marks are not supported in field names"
        Given the syntax used, every method defined in the `spark_frame.nested` module assumes that all field
        names in DataFrames do not contain any dot `.`, percent `%` or exclamation mark `!`.
        This can be worked around using the transformation
        [`spark_frame.transformations.transform_all_field_names`]
        [spark_frame.transformations_impl.transform_all_field_names.transform_all_field_names].

    Args:
        df: A Spark DataFrame
        fields: A Dict(field_name, transformation_to_apply)

    Returns:
        A new DataFrame with the same fields as the input DataFrame, where the specified transformations have been
        applied to the corresponding fields. If a field name did not exist in the input DataFrame,
        it will be added to the output DataFrame. If it did exist, the original value will be replaced with the new one.

    Examples: Example 1: non-repeated fields
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql import functions as f
        >>> from spark_frame import nested
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()
        >>> df = spark.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s.a: integer (nullable = false)
         |-- s.b: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+------+
        | id|     s|
        +---+------+
        |  1|{2, 3}|
        +---+------+
        <BLANKLINE>

        Transformations on non-repeated fields may be expressed as a string representing a column name
        or a Column expression.
        >>> new_df = nested.with_fields(df, {
        ...     "s.id": "id",                                 # column name (string)
        ...     "s.c": f.col("s.a") + f.col("s.b")            # Column expression
        ... })
        >>> new_df.printSchema()
        root
         |-- id: integer (nullable = false)
         |-- s: struct (nullable = false)
         |    |-- a: integer (nullable = false)
         |    |-- b: integer (nullable = false)
         |    |-- id: integer (nullable = false)
         |    |-- c: integer (nullable = false)
        <BLANKLINE>
        >>> new_df.show()
        +---+------------+
        | id|           s|
        +---+------------+
        |  1|{2, 3, 1, 5}|
        +---+------------+
        <BLANKLINE>

    Examples: Example 2: repeated fields
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(STRUCT(1 as a, STRUCT(2 as c) as b), STRUCT(3 as a, STRUCT(4 as c) as b)) as s
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s!.a: integer (nullable = false)
         |-- s!.b.c: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+--------------------+
        | id|                   s|
        +---+--------------------+
        |  1|[{1, {2}}, {3, {4}}]|
        +---+--------------------+
        <BLANKLINE>

        Transformations on repeated fields must be expressed as
        higher-order functions (lambda expressions or named functions).
        The value passed to this function will correspond to the last repeated element.
        >>> new_df = df.transform(nested.with_fields, {
        ...     "s!.b.d": lambda s: s["a"] + s["b"]["c"]}
        ... )
        >>> nested.print_schema(new_df)
        root
         |-- id: integer (nullable = false)
         |-- s!.a: integer (nullable = false)
         |-- s!.b.c: integer (nullable = false)
         |-- s!.b.d: integer (nullable = false)
        <BLANKLINE>
        >>> new_df.show(truncate=False)
        +---+--------------------------+
        |id |s                         |
        +---+--------------------------+
        |1  |[{1, {2, 3}}, {3, {4, 7}}]|
        +---+--------------------------+
        <BLANKLINE>

        String and column expressions can be used on repeated fields as well,
        but their value will be repeated multiple times.
        >>> df.transform(nested.with_fields, {
        ...     "id": None,
        ...     "s!.a": "id",
        ...     "s!.b.c": f.lit(2)
        ... }).show(truncate=False)
        +---+--------------------+
        |id |s                   |
        +---+--------------------+
        |1  |[{1, {2}}, {1, {2}}]|
        +---+--------------------+
        <BLANKLINE>

    Examples: Example 3: field repeated twice
        >>> df = spark.sql('SELECT 1 as id, ARRAY(STRUCT(ARRAY(1, 2, 3) as e)) as s')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s!.e!: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+-------------+
        | id|            s|
        +---+-------------+
        |  1|[{[1, 2, 3]}]|
        +---+-------------+
        <BLANKLINE>

        Here, the lambda expression will be applied to the last repeated element `e`.
        >>> df.transform(nested.with_fields, {"s!.e!": lambda e : e.cast("DOUBLE")}).show()
        +---+-------------------+
        | id|                  s|
        +---+-------------------+
        |  1|[{[1.0, 2.0, 3.0]}]|
        +---+-------------------+
        <BLANKLINE>

    Examples: Example 4: Dataframe with maps
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         MAP("a", STRUCT(2 as a, 3 as b)) as m1
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- m1%key: string (nullable = false)
         |-- m1%value.a: integer (nullable = false)
         |-- m1%value.b: integer (nullable = false)
        <BLANKLINE>
        >>> df.show()
        +---+-------------+
        | id|           m1|
        +---+-------------+
        |  1|{a -> {2, 3}}|
        +---+-------------+
        <BLANKLINE>

        >>> new_df = df.transform(nested.with_fields, {
        ...  "m1%key": lambda key : f.upper(key),
        ...  "m1%value.a": lambda value : value["a"].cast("DOUBLE")
        ... })
        >>> nested.print_schema(new_df)
        root
         |-- id: integer (nullable = false)
         |-- m1%key: string (nullable = false)
         |-- m1%value.a: double (nullable = false)
         |-- m1%value.b: integer (nullable = false)
        <BLANKLINE>
        >>> new_df.show()
        +---+---------------+
        | id|             m1|
        +---+---------------+
        |  1|{A -> {2.0, 3}}|
        +---+---------------+
        <BLANKLINE>

    Examples: Example 5: Accessing multiple repetition levels
        >>> df = spark.sql('''
        ...     SELECT
        ...         1 as id,
        ...         ARRAY(
        ...             STRUCT(2 as average, ARRAY(1, 2, 3) as values),
        ...             STRUCT(3 as average, ARRAY(1, 2, 3, 4, 5) as values)
        ...         ) as s1
        ... ''')
        >>> nested.print_schema(df)
        root
         |-- id: integer (nullable = false)
         |-- s1!.average: integer (nullable = false)
         |-- s1!.values!: integer (nullable = false)
        <BLANKLINE>
        >>> df.show(truncate=False)
        +---+--------------------------------------+
        |id |s1                                    |
        +---+--------------------------------------+
        |1  |[{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}]|
        +---+--------------------------------------+
        <BLANKLINE>

        Here, the transformation applied to "s1!.values!" takes two arguments.
        >>> new_df = df.transform(nested.with_fields, {
        ...  "s1!.values!": lambda s1, value : value - s1["average"]
        ... })
        >>> new_df.show(truncate=False)
        +---+-----------------------------------------+
        |id |s1                                       |
        +---+-----------------------------------------+
        |1  |[{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}]|
        +---+-----------------------------------------+
        <BLANKLINE>

        Extra arguments can be added to the left for each repetition level, up to the root level.
        >>> new_df = df.transform(nested.with_fields, {
        ...  "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"]
        ... })
        >>> new_df.show(truncate=False)
        +---+---------------------------------------+
        |id |s1                                     |
        +---+---------------------------------------+
        |1  |[{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}]|
        +---+---------------------------------------+
        <BLANKLINE>

    """
    default_columns = {field: None for field in nested.fields(df)}
    fields = {**default_columns, **fields}
    return df.select(*resolve_nested_fields(fields, starting_level=df))