

This module contains implementations of graph algorithms and related methods.

ascending_forest_traversal(input_df: DataFrame, node_id: str, parent_id: str, keep_labels: bool = False) -> DataFrame

Given a DataFrame representing a labeled forest with columns id, parent_id and other label columns, performs a graph traversal that returns, for each node, the id of its furthest ancestor (and, with keep_labels, the labels of both the node and that ancestor).

In the input DataFrame, a node is considered to have no parent if its parent_id is null or equal to its node_id. In the output DataFrame, a node that has no parent will have its parent_id equal to its node_id. Cycle protection: If the graph contains any cycle, the nodes in that cycle will have a NULL parent_id.
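These parent rules can be illustrated without Spark. The following is a minimal pure-Python sketch, not spark_frame's implementation: the hypothetical `furthest_ancestors` function walks each node's parent chain in a dict and returns None wherever the Spark output would hold NULL. Two choices in it are assumptions that the text above does not cover: a parent id that never appears as a node is treated as a root, and a node whose chain merely leads into a cycle is also mapped to None.

```python
from typing import Dict, Hashable, Optional


def furthest_ancestors(
    parents: Dict[Hashable, Optional[Hashable]],
) -> Dict[Hashable, Optional[Hashable]]:
    """Map every node to its furthest ancestor (hypothetical helper, not spark_frame API)."""
    result: Dict[Hashable, Optional[Hashable]] = {}
    for start in parents:
        seen = set()
        node = start
        while True:
            parent = parents.get(node)
            # No parent (None, or an id absent from the dict) or a self-loop: node is a root.
            if parent is None or parent == node:
                result[start] = node
                break
            # Coming back to an already visited node means the chain loops forever.
            if node in seen:
                result[start] = None
                break
            seen.add(node)
            node = parent
    return result


# The pokemon evolution chain used in the examples: 4 -> 5 -> 6.
print(furthest_ancestors({4: 5, 5: 6, 6: None}))  # {4: 6, 5: 6, 6: 6}
# The three-node cycle used in the cycle-protection example.
print(furthest_ancestors({1: 2, 2: 3, 3: 1}))  # {1: None, 2: None, 3: None}
```

This walk costs one pass per node over its whole chain, which is fine for illustration but is exactly what a distributed implementation has to avoid.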

The traversal is protected against dependency cycles, but it has no safeguard against a combinatorial explosion when some nodes have more than one parent.
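To see why several parents per node can blow up, assume (this is not stated above) that the traversal keeps one row per distinct upward path, as a join that follows every parent edge independently would. The hypothetical sketch below builds a "ladder" graph with two nodes per level, where each node has both nodes of the level above as parents, and counts the upward paths from the bottom node: the graph has only 2 * levels nodes, but the number of paths is 2 ** (levels - 1).

```python
from typing import Dict, List


def ladder(levels: int) -> Dict[str, List[str]]:
    """Two nodes per level; each node has both nodes of the level above as parents."""
    parents: Dict[str, List[str]] = {"a0": [], "b0": []}  # level 0 holds the roots
    for i in range(1, levels):
        above = [f"a{i - 1}", f"b{i - 1}"]
        parents[f"a{i}"] = list(above)
        parents[f"b{i}"] = list(above)
    return parents


def count_upward_paths(parents: Dict[str, List[str]], node: str) -> int:
    """Number of distinct parent chains from `node` up to a root (a node with no parents)."""
    if not parents.get(node):
        return 1
    return sum(count_upward_paths(parents, p) for p in parents[node])


for levels in (2, 5, 11):
    graph = ladder(levels)
    bottom = f"a{levels - 1}"
    print(levels, len(graph), count_upward_paths(graph, bottom))
# 2 4 2
# 5 10 16
# 11 22 1024
```

Every node in this graph still has only two furthest ancestors (a0 and b0); it is the number of intermediate paths, and therefore of intermediate rows under the stated assumption, that grows exponentially.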


Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| input_df | DataFrame | A Spark DataFrame | required |
| node_id | str | Name of the column that represents the node ids | required |
| parent_id | str | Name of the column that represents the parent node ids | required |
| keep_labels | bool | If set to True, adds two struct columns called "node" and "furthest_ancestor", containing the rows of the input DataFrame for each node and for its furthest ancestor | False |



Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame with two columns, named according to node_id and parent_id, that gives for each node the id of its furthest ancestor (in the parent_id column). If the keep_labels option is used, two extra columns of type STRUCT are added to the output DataFrame; they contain the rows of the input DataFrame corresponding to the node and to its furthest ancestor, respectively. |


>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()

Given a DataFrame with pokemon attributes and evolution links

>>> input_df = spark.sql('''
...     SELECT
...       col1 as ``,
...       col2 as `pokemon.evolve_to_id`,
...       col3 as ``,
...       col4 as `pokemon.types`
...     FROM VALUES
...       (4, 5, 'Charmander', ARRAY('Fire')),
...       (5, 6, 'Charmeleon', ARRAY('Fire')),
...       (6, NULL, 'Charizard', ARRAY('Fire', 'Flying'))
... ''')
>>> input_df.show()
||pokemon.evolve_to_id|| pokemon.types|
|         4|                   5|  Charmander|        [Fire]|
|         5|                   6|  Charmeleon|        [Fire]|
|         6|                NULL|   Charizard|[Fire, Flying]|

We compute a DataFrame that gives, for each pokemon, the id of its highest level of evolution:

>>> ascending_forest_traversal(input_df, "", "pokemon.evolve_to_id").orderBy("``").show()
|         4|                   6|
|         5|                   6|
|         6|                   6|

With the keep_labels option, extra joins are performed at the end of the algorithm to add two struct columns containing the corresponding rows for the original node and for its furthest ancestor.

>>> ascending_forest_traversal(input_df, "", "pokemon.evolve_to_id", keep_labels=True
...     ).orderBy("``").show(10, False)
||pokemon.evolve_to_id|node                                |furthest_ancestor                   |
|4         |6                   |{4, 5, Charmander, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
|5         |6                   |{5, 6, Charmeleon, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
|6         |6                   |{6, NULL, Charizard, [Fire, Flying]}|{6, NULL, Charizard, [Fire, Flying]}|
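The two joins that keep_labels adds can be mimicked on plain Python data. In this sketch, the hypothetical `attach_labels` takes (node, furthest ancestor) pairs plus the input rows keyed by node id, then looks up the row of the node and the row of its ancestor. It uses inner-join semantics; that is an assumption based on the source listing below, where neither `join` call passes a `how` argument and PySpark then defaults to an inner join. Under that reading, a node whose furthest ancestor is NULL finds no match in the second join and is dropped.

```python
from typing import Any, Dict, List, Optional, Tuple

Row = Tuple[Any, ...]


def attach_labels(
    pairs: List[Tuple[Any, Optional[Any]]],
    rows_by_id: Dict[Any, Row],
) -> List[Tuple[Any, Any, Row, Row]]:
    """Hypothetical stand-in for the two keep_labels joins (inner-join semantics)."""
    out = []
    for node, ancestor in pairs:
        node_row = rows_by_id.get(node)  # first join: the "node" struct
        ancestor_row = rows_by_id.get(ancestor)  # second join: the "furthest_ancestor" struct
        if node_row is None or ancestor_row is None:
            continue  # an inner join keeps only pairs matched on both sides
        out.append((node, ancestor, node_row, ancestor_row))
    return out


rows = {
    4: (4, 5, "Charmander", ["Fire"]),
    5: (5, 6, "Charmeleon", ["Fire"]),
    6: (6, None, "Charizard", ["Fire", "Flying"]),
}
for record in attach_labels([(4, 6), (5, 6), (6, 6)], rows):
    print(record)
```

A left join in the second step would instead keep such nodes with an empty ancestor struct; which behavior is preferable depends on how callers treat nodes caught in cycles.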

Cycle Protection: to prevent the algorithm from looping indefinitely, cycles are detected, and the nodes that are part of cycles will end up with a NULL value as their furthest ancestor

>>> input_df = spark.sql('''
...     SELECT
...       col1 as `node_id`,
...       col2 as `parent_id`
...     FROM VALUES (1, 2), (2, 3), (3, 1)
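... ''')
>>> input_df.show()
+-------+---------+
|node_id|parent_id|
+-------+---------+
|      1|        2|
|      2|        3|
|      3|        1|
+-------+---------+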

>>> ascending_forest_traversal(input_df, "node_id", "parent_id").orderBy("node_id").show()
+-------+---------+
|node_id|parent_id|
+-------+---------+
|      1|     NULL|
|      2|     NULL|
|      3|     NULL|
+-------+---------+
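Detecting cycles with joins requires a bounded number of iterations. One standard technique that fits this setting, shown here as a pure-Python sketch, is pointer jumping: each round replaces every node's pointer with its pointer's pointer (one self-join per round, in Spark terms), so after k rounds a node points to its 2 ** k-th ancestor, or to its root if that is closer. Once 2 ** k exceeds the number of nodes, a node that does not yet point at a root can never reach one and is given None. Whether spark_frame's internal `_ascending_forest_traversal` works this way is not stated here; `pointer_jumping` is a hypothetical name, and this sketch also maps nodes that only lead into a cycle to None.

```python
from typing import Dict, Hashable, Optional


def pointer_jumping(
    parents: Dict[Hashable, Optional[Hashable]],
) -> Dict[Hashable, Optional[Hashable]]:
    """Furthest ancestor of every node, computed by repeated pointer doubling."""
    # A node whose parent is missing (None) or itself is a root: it points to itself.
    ptr = {n: (n if p is None else p) for n, p in parents.items()}
    # In this sketch, parent ids that never appear as nodes are treated as roots too.
    for p in list(ptr.values()):
        ptr.setdefault(p, p)
    # Roots are fixed before jumping: inside a cycle of length 2 ** j a node can
    # point back to itself after jumping, which must not make it a root.
    roots = {n for n, p in ptr.items() if n == p}
    # 2 ** rounds exceeds the number of nodes, hence the length of any acyclic chain.
    for _ in range(len(ptr).bit_length()):
        ptr = {n: ptr[a] for n, a in ptr.items()}  # one "self-join" per round
    # Nodes still pointing at a non-root never reached a root: NULL in Spark terms.
    return {n: (ptr[n] if ptr[n] in roots else None) for n in parents}


print(pointer_jumping({1: 2, 2: 3, 3: 1}))  # {1: None, 2: None, 3: None}
print(pointer_jumping({4: 5, 5: 6, 6: None}))  # {4: 6, 5: 6, 6: 6}
```

The appeal of this design for Spark is the round count: it grows with the logarithm of the longest chain rather than with its length, so deep forests need few joins.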
Source code in spark_frame/graph_impl/
def ascending_forest_traversal(
    input_df: DataFrame,
    node_id: str,
    parent_id: str,
    keep_labels: bool = False,
) -> DataFrame:
    """Given a DataFrame representing a labeled forest with columns `id`, `parent_id` and other label columns,
    performs a graph traversal that returns, for each node, the id of its furthest ancestor
    (and, with `keep_labels`, the labels of both the node and that ancestor).

    In the input DataFrame, a node is considered to have no parent if its parent_id is null or equal to its node_id.
    In the output DataFrame, a node that has no parent will have its parent_id equal to its node_id.
    Cycle protection: If the graph contains any cycle, the nodes in that cycle will have a NULL parent_id.

    The traversal is protected against dependency cycles, but it has no safeguard against
    a combinatorial explosion when some nodes have more than one parent.

    Args:
        input_df: A Spark DataFrame
        node_id: Name of the column that represents the node ids
        parent_id: Name of the column that represents the parent node ids
        keep_labels: If set to True, adds two struct columns called "node" and "furthest_ancestor" containing
            the rows of the input DataFrame for each node and for its furthest ancestor

    Returns:
        A DataFrame with two columns named according to `node_id` and `parent_id` that gives, for each node,
        the id of its furthest ancestor (in the `parent_id` column).
        If the option `keep_labels` is used, two extra columns of type STRUCT are added to the output DataFrame;
        they contain the rows of the input DataFrame corresponding to the node and to its furthest
        ancestor, respectively.

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()

        Given a DataFrame with pokemon attributes and evolution links

        >>> input_df = spark.sql('''
        ...     SELECT
        ...       col1 as ``,
        ...       col2 as `pokemon.evolve_to_id`,
        ...       col3 as ``,
        ...       col4 as `pokemon.types`
        ...     FROM VALUES
        ...       (4, 5, 'Charmander', ARRAY('Fire')),
        ...       (5, 6, 'Charmeleon', ARRAY('Fire')),
        ...       (6, NULL, 'Charizard', ARRAY('Fire', 'Flying'))
        ... ''')
        >>> input_df.show()
        ||pokemon.evolve_to_id|| pokemon.types|
        |         4|                   5|  Charmander|        [Fire]|
        |         5|                   6|  Charmeleon|        [Fire]|
        |         6|                NULL|   Charizard|[Fire, Flying]|

        We compute a DataFrame that gives, for each pokemon, the id of its highest level of evolution:

        >>> ascending_forest_traversal(input_df, "", "pokemon.evolve_to_id").orderBy("``").show()
        |         4|                   6|
        |         5|                   6|
        |         6|                   6|

        With the `keep_labels` option, extra joins are performed at the end of the algorithm to add two struct
        columns containing the corresponding rows for the original node and for its furthest ancestor.

        >>> ascending_forest_traversal(input_df, "", "pokemon.evolve_to_id", keep_labels=True
        ...     ).orderBy("``").show(10, False)
        ||pokemon.evolve_to_id|node                                |furthest_ancestor                   |
        |4         |6                   |{4, 5, Charmander, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
        |5         |6                   |{5, 6, Charmeleon, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
        |6         |6                   |{6, NULL, Charizard, [Fire, Flying]}|{6, NULL, Charizard, [Fire, Flying]}|

        *Cycle Protection:* to prevent the algorithm from looping indefinitely, cycles are detected, and the nodes
        that are part of cycles will end up with a NULL value as their furthest ancestor

        >>> input_df = spark.sql('''
        ...     SELECT
        ...       col1 as `node_id`,
        ...       col2 as `parent_id`
        ...     FROM VALUES (1, 2), (2, 3), (3, 1)
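        ... ''')
        >>> input_df.show()
        +-------+---------+
        |node_id|parent_id|
        +-------+---------+
        |      1|        2|
        |      2|        3|
        |      3|        1|
        +-------+---------+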
        >>> ascending_forest_traversal(input_df, "node_id", "parent_id").orderBy("node_id").show()
        +-------+---------+
        |node_id|parent_id|
        +-------+---------+
        |      1|     NULL|
        |      2|     NULL|
        |      3|     NULL|
        +-------+---------+
    """
    assert node_id in input_df.columns, (
        "Could not find column %s in Dataframe's columns: %s" % (node_id, input_df.columns)
    )
    assert parent_id in input_df.columns, (
        "Could not find column %s in Dataframe's columns: %s" % (parent_id, input_df.columns)
    )
    node_id_col_name = "node_id"
    parent_id_col_name = "parent_id"
    status_col_name = "status"
    df =

    res_df = _ascending_forest_traversal(
    res_df =

    if keep_labels:
        res_df = res_df.join(input_df, node_id).select(
            f.struct(*[input_df[quote(col)] for col in input_df.columns]).alias("node"),
        res_df = res_df.join(input_df, res_df[quote(parent_id)] == input_df[quote(node_id)]).select(
            f.struct(*[input_df[quote(col)] for col in input_df.columns]).alias("furthest_ancestor"),

    return res_df
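The listing passes every column name through a `quote` helper before indexing `input_df`. This matters for names such as `pokemon.evolve_to_id`: unquoted, Spark reads the dot as access to a field inside a struct column, so the name has to be wrapped in backticks. A minimal sketch of such a helper follows; `quote_column` is a hypothetical name, and spark_frame's own `quote` may handle more cases. It escapes embedded backticks by doubling them, which is how Spark SQL writes a backtick inside a quoted identifier.

```python
def quote_column(name: str) -> str:
    """Backtick-quote a Spark column name so that dots are not read as struct-field access.

    Hypothetical helper: backticks inside the name are escaped by doubling them.
    """
    return "`" + name.replace("`", "``") + "`"


print(quote_column("pokemon.evolve_to_id"))  # `pokemon.evolve_to_id`
print(quote_column("odd`name"))  # `odd``name`
```

With such a helper, `input_df[quote_column("pokemon.evolve_to_id")]` refers to the top-level column literally named `pokemon.evolve_to_id`, whereas `input_df["pokemon.evolve_to_id"]` would look for a field `evolve_to_id` inside a struct column named `pokemon`.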