This is part of a series of posts about associative grouping.

In the first two parts of this series we looked at how we could use recursive CTEs and SQL Server’s graph functionality to find overlapping groups in two columns of a table, so that we could put them into a new super group of associated groups.

Since then I’ve been doing most of my work in Azure, and quite a lot with Databricks/Spark. I recently responded to a tweet about trying to do what we had done in the first two posts. It transpired that they were using Synapse SQL pools, so neither of the previous methods was supported.

This got me thinking: if you’re working with data lakes and/or Azure Synapse SQL pools, how would you go about doing this? One option is the GraphFrames library in Databricks, which can process the data and find the groups for us.

Installing the Library

First we need to download the GraphFrames library. I’m using a Databricks cluster with Spark 3.1 / Scala 2.12, so I grabbed the latest jar file built for that combination.

Once you have the jar file downloaded, we need to add it to our Databricks cluster. In Databricks, navigate to the Compute tab and select the cluster you want to install the library onto. Make sure it’s up and running too.

Go to the Libraries tab, and click add. Drag your downloaded jar file into the popup window and click Install.

Python Notebook Code

Now the library is installed, we can write some code. We’ll create a DataFrame from an array to give us the same data as in the previous posts. As a refresher, our data looks like this:

Id | Supplier Name | Tax Number | Bank Sort Code | Bank Account Number | Required Output
---|---|---|---|---|---
1 | AdventureWorks Ltd. | 12345678 | 02-11-33 | 12345678 | 1
3 | ADVENTURE WORKS | 23344556 | 02-55-44 | 15161718 | 1
4 | ADVENTURE WORKS LTD. | 23344556 | 02-77-66 | 99887766 | 1
5 | AW Bike Co | 23344556 | 02-88-00 | 11991199 | 1
6 | Big Bike Discounts | 55556666 | 02-88-00 | 11991199 | 1
7 | Contoso Corp | 90009000 | 02-99-02 | 12341234 | 2

This produces the same simple disconnected graphs as last time. We have one graph with 6 nodes, and another graph with a single node.

The nodes are numbered using the Ids of the rows.

And here’s the Python to create a DataFrame with the data above, and then show us the schema and the data. Add this to the first cell in your notebook.

from pyspark.sql.functions import col, greatest, least, lit, monotonically_increasing_id, row_number
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.window import Window

# Create the data as an array
data = [
  ("AdventureWorks Ltd.",      12345678, "02-11-33", 12345678),
  ("AdventureWorks",           12345678, "02-55-44", 15161718),
  ("ADVENTURE WORKS",          23344556, "02-55-44", 15161718),
  #("ADVENTURE WORKS BIKES",    23344556, "02-55-44", 15161718),
  ("ADVENTURE WORKS LTD.",     23344556, "02-77-66", 99887766),
  ("AW Bike Co",               23344556, "02-88-00", 11991199),
  ("Big Bike Discounts (AW)",  55556666, "02-88-00", 11991199),
  ("Contoso Corp",             90001000, "02-99-02", 12341234),
]

# Create the schema
schema = StructType([
    StructField("SupplierName", StringType(), True),
    StructField("TaxNumber", IntegerType(), True),
    StructField("BankSortCode", StringType(), True),
    StructField("BankAccountNumber", IntegerType(), True),
])

# Use the data array and schema to create a dataframe
df = spark.createDataFrame(data=data, schema=schema)

# We need a row id added to make building a graph simpler, so add that using a window and the row_number function.
# Ordering by monotonically_increasing_id() keeps the ids in the original row order (ordering by a constant does not guarantee that)
windowSpec = Window.partitionBy(lit("A")).orderBy(monotonically_increasing_id())
df = df.withColumn("id", row_number().over(windowSpec))

# Finally show the schema and the results
df.printSchema()
df.show(truncate=False)

So what have we done there? We created an array and a schema, and then used those to create a DataFrame. We added a new column to the DataFrame with a sequential number (using the window function row_number()) so we have a nice simple node id for building a graph. The name of the id column must be id in lowercase to work with the GraphFrame later.

Add another cell to the notebook and add the following code to build the edges DataFrame.

# Create a DataFrame for the edges by joining the data to itself. We do this once for the TaxNumber column
# and then again for the BankSortCode and BankAccountNumber columns, and union the results

# The graph is undirected (we only care whether nodes are joined), so we keep only rows where the left id > right id
# to remove duplicate pairs and rows joining to themselves
taxEdgesDF = df.alias("df1") \
    .join(df.alias("df2"), \
        (col("df1.TaxNumber") == col("df2.TaxNumber")) \
        & (col("df1.id") > col("df2.id")), \
        "inner" \
    ) \
    .select(least("df1.id", "df2.id").alias("src"), \
        greatest("df1.id", "df2.id").alias("dst") \
    ) \
    .withColumn("Relationship", lit("TaxNumber"))

bankEdgesDF = df.alias("df1") \
    .join(df.alias("df2"), \
        (col("df1.BankSortCode") == col("df2.BankSortCode")) \
        & (col("df1.BankAccountNumber") == col("df2.BankAccountNumber")) \
        & (col("df1.id") > col("df2.id")), \
        "inner" \
    ) \
    .select(least("df1.id", "df2.id").alias("src"), \
        greatest("df1.id", "df2.id").alias("dst") \
    ) \
    .withColumn("Relationship", lit("BankSortCodeAndAccountNumber"))

# Both DataFrames have the same columns (src, dst, Relationship), so a positional union is safe
allUniqueEdgesDF = taxEdgesDF.union(bankEdgesDF)

This is just joining the nodes to the other nodes in the same way as the join in the CTE method did. We join nodes based on the tax number, or the composite bank details columns. We limit the join to nodes where the id of the first node is larger than the id of the second node, as we don’t want nodes to join to themselves, and the edges are not directional.
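To make the join logic concrete outside Spark, here is a plain-Python sketch of the same edge construction. It is an illustration, not the notebook code: the rows are the sample data with ids 1 to 7 assumed to follow the original row order, and `build_edges` is a hypothetical helper. Grouping rows by the join key and pairing ids within each group mirrors the equi-join, and emitting each pair once with src < dst plays the role of the df1.id > df2.id filter.

```python
from collections import defaultdict
from itertools import combinations

# Sample rows mirroring the notebook data: (id, name, tax number, sort code, account number).
# The ids are assumed to match the order produced by the row_number() step.
ROWS = [
    (1, "AdventureWorks Ltd.", 12345678, "02-11-33", 12345678),
    (2, "AdventureWorks", 12345678, "02-55-44", 15161718),
    (3, "ADVENTURE WORKS", 23344556, "02-55-44", 15161718),
    (4, "ADVENTURE WORKS LTD.", 23344556, "02-77-66", 99887766),
    (5, "AW Bike Co", 23344556, "02-88-00", 11991199),
    (6, "Big Bike Discounts (AW)", 55556666, "02-88-00", 11991199),
    (7, "Contoso Corp", 90001000, "02-99-02", 12341234),
]


def build_edges(rows):
    """Hypothetical helper: return (src, dst, relationship) tuples with src < dst."""
    # One key per relationship type, like the two self-joins in the notebook.
    key_funcs = {
        "TaxNumber": lambda r: r[2],
        "BankSortCodeAndAccountNumber": lambda r: (r[3], r[4]),
    }
    edges = []
    for relationship, key in key_funcs.items():
        # Bucket row ids by join key, as an equi-join would.
        groups = defaultdict(list)
        for row in rows:
            groups[key(row)].append(row[0])
        # combinations() over sorted ids yields each unordered pair once, src < dst,
        # and never pairs an id with itself.
        for ids in groups.values():
            for src, dst in combinations(sorted(ids), 2):
                edges.append((src, dst, relationship))
    return edges


for edge in build_edges(ROWS):
    print(edge)
```

On the sample rows this gives six edges: (1, 2), (3, 4), (3, 5) and (4, 5) from the tax number, and (2, 3) and (5, 6) from the bank details.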

The output from creating the DataFrame

Now we can add another cell to create the GraphFrame, and print some simple information about node and edge counts. Note that we call the cache() method on the GraphFrame before accessing the nodes and edges to count them, so they are not recomputed for each count.

from graphframes import GraphFrame

# The vertices DataFrame must have an "id" column; the edges DataFrame needs "src" and "dst" columns
connectionsGF = GraphFrame(df, allUniqueEdgesDF)
connectionsGF.cache()

print("Total Number of Suppliers: " + str(connectionsGF.vertices.count()))
print("Total Number of Relationships/Types: " + str(connectionsGF.edges.count()))
print("Unique Number of Relationships: " + str(connectionsGF.edges.groupBy("src", "dst").count().count()))

The DataFrame for the edges shows the start and end nodes, and which join produced that relationship.
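The difference between the last two counts shows up when a pair of nodes is linked by both relationship types: that pair produces two edge rows but only one unique (src, dst) pair. A minimal plain-Python sketch of the three counts, using hypothetical vertices and edges rather than the sample data (which has no such double link); `graph_counts` is a hypothetical helper.

```python
# Hypothetical vertex ids and edges: pair (1, 2) is linked twice, once per relationship type.
vertex_ids = [1, 2, 3]
edges = [
    (1, 2, "TaxNumber"),
    (1, 2, "BankSortCodeAndAccountNumber"),
    (2, 3, "TaxNumber"),
]


def graph_counts(vertex_ids, edges):
    """Return (vertices, edge rows, unique src/dst pairs)."""
    total_vertices = len(vertex_ids)          # like vertices.count()
    total_edges = len(edges)                  # like edges.count()
    # like edges.groupBy("src", "dst").count().count(): the number of distinct pairs
    unique_pairs = len({(src, dst) for src, dst, _ in edges})
    return total_vertices, total_edges, unique_pairs


print(graph_counts(vertex_ids, edges))  # (3, 3, 2)
```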

OK, we have the GraphFrame, so how do we get the groups? Add the following code to another cell.


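# connectedComponents() needs a checkpoint directory. This path is a hypothetical example;
# use any location your cluster can write to
spark.sparkContext.setCheckpointDir("dbfs:/tmp/graphframes-checkpoints")

connectedComponents = connectionsGF.connectedComponents()
connectedComponents.where("component != 0").show()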

We use the GraphFrame method connectedComponents() to get the groupings. We have to set the Spark context checkpoint directory before we use the connected components function. The component column holds the group id assigned by the graph algorithm.
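connectedComponents() runs a distributed algorithm inside Spark. To show what a connected component is on a single machine, here is a union-find sketch in plain Python; it is a swapped-in technique for illustration, not what GraphFrames does internally. The edge list is hard-coded from the sample data's joins, and labelling each component with its smallest id is a choice made in this sketch; the component values GraphFrames returns may differ, and only the grouping matters.

```python
def connected_components(vertex_ids, edges):
    """Map each vertex id to a component label: the smallest id in its component."""
    parent = {v: v for v in vertex_ids}

    def find(v):
        # Walk up to the root, halving the path as we go to keep the trees shallow.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    for src, dst in edges:
        root_a, root_b = find(src), find(dst)
        if root_a != root_b:
            # Hang the larger root under the smaller one, so every root stays
            # the minimum id of its component.
            parent[max(root_a, root_b)] = min(root_a, root_b)

    return {v: find(v) for v in vertex_ids}


# Edges from the sample data (src < dst): tax number links, then bank detail links.
edges = [(1, 2), (3, 4), (3, 5), (4, 5), (2, 3), (5, 6)]
print(connected_components(range(1, 8), edges))
# {1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 7}
```

Rows 1 to 6 end up in one component and Contoso Corp (row 7) in its own, matching the two groups in the Required Output column.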

The groups have come back the same as for the T-SQL methods.
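The component values are only labels, so they will not necessarily be 1 and 2 like the Required Output column. A plain-Python sketch of renumbering them densely, assuming a hypothetical mapping of row id to component label in which rows 1 to 6 share one label and row 7 has its own; `dense_rank_groups` is a hypothetical helper. In Spark, a similar renumbering could be done with dense_rank() over a window ordered by the component column.

```python
def dense_rank_groups(components):
    """Renumber component labels as 1, 2, 3... in ascending label order (like DENSE_RANK)."""
    ranks = {label: rank for rank, label in enumerate(sorted(set(components.values())), start=1)}
    return {vertex: ranks[label] for vertex, label in components.items()}


# Hypothetical row id -> component label mapping.
components = {1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 7}
print(dense_rank_groups(components))
# {1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 2}
```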

I didn’t write the results out to storage here, but the results are a normal Spark DataFrame, so it would be easy enough to write them out to a Parquet file in the lake.


We can see from this that if we are working with our data in Spark and Synapse SQL pools, we still have an option for associative grouping. The functions did not seem particularly quick, even with only a small amount of data, and I haven’t tested how the performance scales, especially with multiple workers. Will we get lots of data shuffling? I’m not sure, but that sounds like something interesting to explore. If you can, it’s probably a good idea to do this kind of grouping once up front rather than calculating it on demand, as the processing is likely to be quite intensive.

I did try to get the code running in a Synapse Spark pool notebook, but I could not get Python to see the library once I had added the jar file to the workspace. It seems like Synapse is happy to use Python wheels, but not jar files (Scala/Java libraries). Converting the code to Scala did work though, so maybe I’m missing something there. If you do manage to get that working, let me know in the comments.

