
spark_frame.graph

This module contains implementations of graph algorithms and related methods.


ascending_forest_traversal(input_df: DataFrame, node_id: str, parent_id: str, keep_labels: bool = False) -> DataFrame

Given a DataFrame representing a labeled forest with columns id, parent_id and other label columns, performs a graph traversal that returns a DataFrame giving, for each node, the id of its furthest ancestor (and, with keep_labels=True, the labels of both the node and that ancestor).

In the input DataFrame, a node is considered to have no parent if its parent_id is null or equal to its node_id. In the output DataFrame, a node that has no parent will have its parent_id equal to its node_id. Cycle protection: If the graph contains any cycle, the nodes in that cycle will have a NULL parent_id.
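These rules can be restated as a small reference function that computes the expected output on an in-memory mapping. It is a plain-Python sketch for illustration only, not the Spark implementation, and the name `furthest_ancestors` is hypothetical. Two behaviours are assumptions, because the text above only specifies the nodes that sit on a cycle: nodes whose chain merely leads into a cycle are also mapped to None, and a parent id that never appears as a node is treated as a root.

```python
from typing import Dict, Hashable, List, Optional, Set


def furthest_ancestors(parents: Dict[Hashable, Optional[Hashable]]) -> Dict[Hashable, Optional[Hashable]]:
    """Hypothetical reference: map each node id to its furthest ancestor id (None on cycles)."""
    resolved: Dict[Hashable, Optional[Hashable]] = {}
    for start in parents:
        path: List[Hashable] = []   # nodes visited during this upward walk, in order
        on_path: Set[Hashable] = set()
        node = start
        while True:
            if node in resolved:            # reuse a result computed on an earlier walk
                answer = resolved[node]
                break
            if node in on_path:             # came back to a node of this walk: a cycle
                answer = None               # assumption: nodes leading into it get None too
                break
            path.append(node)
            on_path.add(node)
            parent = parents.get(node)      # assumption: unknown parent ids count as roots
            if parent is None or parent == node:
                answer = node               # a null or self parent marks a root
                break
            node = parent
        for visited in path:
            resolved[visited] = answer
    # Only report the nodes that were present in the input.
    return {n: resolved[n] for n in parents}


print(furthest_ancestors({4: 5, 5: 6, 6: None}))  # {4: 6, 5: 6, 6: 6}
print(furthest_ancestors({1: 2, 2: 3, 3: 1}))     # {1: None, 2: None, 3: None}
```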

It is protected against dependency cycles, but has no safeguard against a combinatorial explosion if some nodes have more than one parent.
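To see why multiple parents are risky: a traversal that follows every parent edge ends up with one result per node-to-root path, and the number of such paths can grow exponentially while the input only grows linearly. The plain-Python sketch below counts those paths on a hypothetical "ladder" graph in which every node has two parents. It does not call spark_frame; it only assumes that each distinct path would produce its own row.

```python
from functools import lru_cache
from typing import Dict, List, Tuple

Edge = Tuple[str, str]  # (child id, parent id), i.e. one input row


def ladder(levels: int) -> List[Edge]:
    """Hypothetical input: nodes a{k} and b{k} both have a{k-1} and b{k-1} as parents."""
    edges: List[Edge] = []
    for k in range(1, levels + 1):
        for child in (f"a{k}", f"b{k}"):
            for parent in (f"a{k - 1}", f"b{k - 1}"):
                edges.append((child, parent))
    return edges


def count_paths_to_roots(edges: List[Edge], start: str) -> int:
    """Number of distinct upward paths from `start` to a node that has no parent row."""
    parents: Dict[str, List[str]] = {}
    for child, parent in edges:
        parents.setdefault(child, []).append(parent)

    @lru_cache(maxsize=None)
    def count(node: str) -> int:
        if node not in parents:  # no parent row: this node is a root
            return 1
        return sum(count(p) for p in parents[node])

    return count(start)


for levels in (1, 5, 10, 20):
    # input rows grow as 4 * levels, paths from the top node grow as 2 ** levels
    print(levels, len(ladder(levels)), count_paths_to_roots(ladder(levels), f"a{levels}"))
```

With 20 levels the input has only 80 rows, but the top node already has 1,048,576 distinct paths to its roots.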

Parameters:

input_df (DataFrame, required): A Spark DataFrame.

node_id (str, required): Name of the column that represents the node ids.

parent_id (str, required): Name of the column that represents the parent node ids.

keep_labels (bool, default False): If set to True, adds two struct columns called "node" and "furthest_ancestor" containing the content of the input DataFrame's row for each node and for its furthest ancestor, respectively.

Returns:

DataFrame: A DataFrame with two columns, named according to node_id and parent_id, that gives for each node the id of its furthest ancestor (in the parent_id column). If the keep_labels option is used, two extra columns of type STRUCT are added to the output DataFrame; they represent the content of the rows in the input DataFrame corresponding to the node and its furthest ancestor, respectively.

Examples:

>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.appName("doctest").getOrCreate()

Given a DataFrame with pokemon attributes and evolution links

>>> input_df = spark.sql('''
...     SELECT
...       col1 as `pokemon.id`,
...       col2 as `pokemon.evolve_to_id`,
...       col3 as `pokemon.name`,
...       col4 as `pokemon.types`
...     FROM VALUES
...       (4, 5, 'Charmander', ARRAY('Fire')),
...       (5, 6, 'Charmeleon', ARRAY('Fire')),
...       (6, NULL, 'Charizard', ARRAY('Fire', 'Flying'))
... ''')
>>> input_df.show()
+----------+--------------------+------------+--------------+
|pokemon.id|pokemon.evolve_to_id|pokemon.name| pokemon.types|
+----------+--------------------+------------+--------------+
|         4|                   5|  Charmander|        [Fire]|
|         5|                   6|  Charmeleon|        [Fire]|
|         6|                NULL|   Charizard|[Fire, Flying]|
+----------+--------------------+------------+--------------+

We compute a DataFrame that gives, for each pokemon.id, the id of its highest level of evolution:

>>> ascending_forest_traversal(input_df, "pokemon.id", "pokemon.evolve_to_id").orderBy("`pokemon.id`").show()
+----------+--------------------+
|pokemon.id|pokemon.evolve_to_id|
+----------+--------------------+
|         4|                   6|
|         5|                   6|
|         6|                   6|
+----------+--------------------+

With the keep_labels option, extra joins are performed at the end of the algorithm to add two struct columns containing the corresponding rows for the original node and for its furthest ancestor.

>>> ascending_forest_traversal(input_df, "pokemon.id", "pokemon.evolve_to_id", keep_labels=True
...     ).orderBy("`pokemon.id`").show(10, False)
+----------+--------------------+------------------------------------+------------------------------------+
|pokemon.id|pokemon.evolve_to_id|node                                |furthest_ancestor                   |
+----------+--------------------+------------------------------------+------------------------------------+
|4         |6                   |{4, 5, Charmander, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
|5         |6                   |{5, 6, Charmeleon, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
|6         |6                   |{6, NULL, Charizard, [Fire, Flying]}|{6, NULL, Charizard, [Fire, Flying]}|
+----------+--------------------+------------------------------------+------------------------------------+
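The two extra joins can be mimicked in plain Python to show what ends up in the struct columns. In the source code below, both joins use the default `how` of `DataFrame.join`, which is an inner join, so a row whose furthest ancestor is NULL (for instance a node on a cycle) appears to be dropped rather than kept with empty structs when keep_labels=True. The function name `attach_labels` is hypothetical, and node ids are assumed to be unique.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def attach_labels(result_rows: List[Row], input_rows: List[Row], node_id: str, parent_id: str) -> List[Row]:
    """Hypothetical plain-Python version of the two inner joins performed for keep_labels=True."""
    by_id = {row[node_id]: row for row in input_rows}  # assumes unique node ids
    labelled: List[Row] = []
    for row in result_rows:
        node, ancestor = row[node_id], row[parent_id]
        # Inner-join semantics: a row with no match on either side is dropped.
        # A None ancestor never matches, just as NULL = x is never true in Spark.
        if node not in by_id or ancestor not in by_id:
            continue
        labelled.append({
            node_id: node,
            parent_id: ancestor,
            "node": by_id[node],                    # full input row of the node
            "furthest_ancestor": by_id[ancestor],   # full input row of its furthest ancestor
        })
    return labelled


pokemon = [
    {"pokemon.id": 4, "pokemon.evolve_to_id": 5, "pokemon.name": "Charmander"},
    {"pokemon.id": 5, "pokemon.evolve_to_id": 6, "pokemon.name": "Charmeleon"},
    {"pokemon.id": 6, "pokemon.evolve_to_id": None, "pokemon.name": "Charizard"},
]
traversal = [{"pokemon.id": i, "pokemon.evolve_to_id": 6} for i in (4, 5, 6)]
for row in attach_labels(traversal, pokemon, "pokemon.id", "pokemon.evolve_to_id"):
    print(row["pokemon.id"], row["node"]["pokemon.name"], row["furthest_ancestor"]["pokemon.name"])
```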

Cycle Protection: to prevent the algorithm from looping indefinitely, cycles are detected, and the nodes that are part of cycles end up with a NULL value as their furthest ancestor.

>>> input_df = spark.sql('''
...     SELECT
...       col1 as `node_id`,
...       col2 as `parent_id`
...     FROM VALUES (1, 2), (2, 3), (3, 1)
... ''')
>>> input_df.show()
+-------+---------+
|node_id|parent_id|
+-------+---------+
|      1|        2|
|      2|        3|
|      3|        1|
+-------+---------+

>>> ascending_forest_traversal(input_df, "node_id", "parent_id").orderBy("node_id").show()
+-------+---------+
|node_id|parent_id|
+-------+---------+
|      1|     NULL|
|      2|     NULL|
|      3|     NULL|
+-------+---------+
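Cycle detection matters most for iterative traversals that stop once no row changes, since a cycle never stabilizes. The plain-Python sketch below shows one such scheme, pointer jumping, where each round replaces a node's current ancestor by that ancestor's own current ancestor (roughly one self-join per round in a DataFrame setting). It is an illustration under that assumption, not necessarily how `_ascending_forest_traversal` works internally, and `pointer_jumping` is a hypothetical name. The number of rounds is capped, and any pointer that has not landed on a genuine root afterwards is reported as None.

```python
import math
from typing import Dict, Hashable, Optional


def pointer_jumping(parents: Dict[Hashable, Optional[Hashable]]) -> Dict[Hashable, Optional[Hashable]]:
    """Hypothetical sketch: furthest ancestors by repeated pointer doubling, None on cycles."""
    # Genuine roots of the input: a null parent or a self parent.
    roots = {n for n, p in parents.items() if p is None or p == n}
    # Each node starts by pointing at its parent; roots point at themselves.
    ptr = {n: (n if p is None else p) for n, p in parents.items()}
    # In a forest of n nodes every chain has at most n - 1 edges, and each round doubles
    # the number of edges a pointer skips, so this many rounds always reaches the roots.
    max_rounds = math.ceil(math.log2(max(len(ptr), 2))) + 1
    for _ in range(max_rounds):
        # One step: replace each pointer by the pointer of its current target.
        new_ptr = {n: ptr.get(a, a) for n, a in ptr.items()}
        if new_ptr == ptr:  # fixed point reached, stop early
            break
        ptr = new_ptr
    # A pointer that did not land on a genuine root (or on an id absent from the input,
    # treated here as a root) sits on, or leads into, a cycle.
    return {n: (a if a in roots or a not in parents else None) for n, a in ptr.items()}


print(pointer_jumping({4: 5, 5: 6, 6: None}))  # {4: 6, 5: 6, 6: 6}
print(pointer_jumping({1: 2, 2: 3, 3: 1}))     # {1: None, 2: None, 3: None}
```

Checking against the original roots, rather than testing whether a pointer points at itself, is deliberate: on a cycle whose length is a power of two, such as 1 -> 2 -> 1, doubling makes every pointer point at its own node, which would otherwise be mistaken for a root.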
Source code in spark_frame/graph_impl/ascending_forest_traversal.py
def ascending_forest_traversal(
    input_df: DataFrame,
    node_id: str,
    parent_id: str,
    keep_labels: bool = False,
) -> DataFrame:
    """Given a DataFrame representing a labeled forest with columns `id`, `parent_id` and other label columns,
    performs a graph traversal that returns a DataFrame giving, for each node, the id of its furthest ancestor
    (and, with `keep_labels=True`, the labels of both the node and that ancestor).

    In the input DataFrame, a node is considered to have no parent if its parent_id is null or equal to its node_id.
    In the output DataFrame, a node that has no parent will have its parent_id equal to its node_id.
    Cycle protection: If the graph contains any cycle, the nodes in that cycle will have a NULL parent_id.

    It is protected against dependency cycles, but has no safeguard against
    a combinatorial explosion if some nodes have more than one parent.

    Args:
        input_df: A Spark DataFrame
        node_id: Name of the column that represents the node ids
        parent_id: Name of the column that represents the parent node ids
        keep_labels: If set to true, add two struct columns called "node" and "furthest_ancestor" containing
            the content of the row from the input DataFrame for the corresponding nodes and their furthest ancestor

    Returns:
        A DataFrame with two columns named according to `node_id` and `parent_id` that gives for each node
        the id of its furthest ancestor (in the `parent_id` column).
        If the option `keep_labels` is used, two extra columns of type STRUCT are added to the output DataFrame;
        they represent the content of the rows in the input DataFrame corresponding to the node and its furthest
        ancestor, respectively.

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("doctest").getOrCreate()

        Given a DataFrame with pokemon attributes and evolution links

        >>> input_df = spark.sql('''
        ...     SELECT
        ...       col1 as `pokemon.id`,
        ...       col2 as `pokemon.evolve_to_id`,
        ...       col3 as `pokemon.name`,
        ...       col4 as `pokemon.types`
        ...     FROM VALUES
        ...       (4, 5, 'Charmander', ARRAY('Fire')),
        ...       (5, 6, 'Charmeleon', ARRAY('Fire')),
        ...       (6, NULL, 'Charizard', ARRAY('Fire', 'Flying'))
        ... ''')
        >>> input_df.show()
        +----------+--------------------+------------+--------------+
        |pokemon.id|pokemon.evolve_to_id|pokemon.name| pokemon.types|
        +----------+--------------------+------------+--------------+
        |         4|                   5|  Charmander|        [Fire]|
        |         5|                   6|  Charmeleon|        [Fire]|
        |         6|                NULL|   Charizard|[Fire, Flying]|
        +----------+--------------------+------------+--------------+
        <BLANKLINE>

        We compute a DataFrame that gives, for each pokemon.id, the id of its highest level of evolution

        >>> ascending_forest_traversal(input_df, "pokemon.id", "pokemon.evolve_to_id").orderBy("`pokemon.id`").show()
        +----------+--------------------+
        |pokemon.id|pokemon.evolve_to_id|
        +----------+--------------------+
        |         4|                   6|
        |         5|                   6|
        |         6|                   6|
        +----------+--------------------+
        <BLANKLINE>

        With the `keep_labels` option, extra joins are performed at the end of the algorithm to add two struct
        columns containing the corresponding rows for the original node and for its furthest ancestor.

        >>> ascending_forest_traversal(input_df, "pokemon.id", "pokemon.evolve_to_id", keep_labels=True
        ...     ).orderBy("`pokemon.id`").show(10, False)
        +----------+--------------------+------------------------------------+------------------------------------+
        |pokemon.id|pokemon.evolve_to_id|node                                |furthest_ancestor                   |
        +----------+--------------------+------------------------------------+------------------------------------+
        |4         |6                   |{4, 5, Charmander, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
        |5         |6                   |{5, 6, Charmeleon, [Fire]}          |{6, NULL, Charizard, [Fire, Flying]}|
        |6         |6                   |{6, NULL, Charizard, [Fire, Flying]}|{6, NULL, Charizard, [Fire, Flying]}|
        +----------+--------------------+------------------------------------+------------------------------------+
        <BLANKLINE>

        *Cycle Protection:* to prevent the algorithm from looping indefinitely, cycles are detected, and the nodes
        that are part of cycles end up with a NULL value as their furthest ancestor.

        >>> input_df = spark.sql('''
        ...     SELECT
        ...       col1 as `node_id`,
        ...       col2 as `parent_id`
        ...     FROM VALUES (1, 2), (2, 3), (3, 1)
        ... ''')
        >>> input_df.show()
        +-------+---------+
        |node_id|parent_id|
        +-------+---------+
        |      1|        2|
        |      2|        3|
        |      3|        1|
        +-------+---------+
        <BLANKLINE>
        >>> ascending_forest_traversal(input_df, "node_id", "parent_id").orderBy("node_id").show()
        +-------+---------+
        |node_id|parent_id|
        +-------+---------+
        |      1|     NULL|
        |      2|     NULL|
        |      3|     NULL|
        +-------+---------+
        <BLANKLINE>
    """
    assert_true(
        node_id in input_df.columns,
        "Could not find column %s in Dataframe's columns: %s" % (node_id, input_df.columns),
    )
    assert_true(
        parent_id in input_df.columns,
        "Could not find column %s in Dataframe's columns: %s" % (parent_id, input_df.columns),
    )
    node_id_col_name = "node_id"
    parent_id_col_name = "parent_id"
    status_col_name = "status"
    df = input_df.select(
        f.col(quote(node_id)).alias(node_id_col_name),
        f.col(quote(parent_id)).alias(parent_id_col_name),
    )

    res_df = _ascending_forest_traversal(
        df,
        node_id_col_name=node_id_col_name,
        parent_id_col_name=parent_id_col_name,
        status_col_name=status_col_name,
    )
    res_df = res_df.select(
        f.col(node_id_col_name).alias(node_id),
        f.col(parent_id_col_name).alias(parent_id),
    )

    if keep_labels:
        res_df = res_df.join(input_df, node_id).select(
            res_df["*"],
            f.struct(*[input_df[quote(col)] for col in input_df.columns]).alias("node"),
        )
        res_df = res_df.join(input_df, res_df[quote(parent_id)] == input_df[quote(node_id)]).select(
            res_df[quote(node_id)],
            res_df[quote(parent_id)],
            res_df["node"],
            f.struct(*[input_df[quote(col)] for col in input_df.columns]).alias("furthest_ancestor"),
        )

    return res_df