DataFrame equality in Apache Spark

Scala · Apache Spark · DataFrame · Apache Spark SQL · RDD

Scala Problem Overview


Assume df1 and df2 are two DataFrames in Apache Spark, computed using two different mechanisms, e.g., Spark SQL vs. the Scala/Java/Python API.

Is there an idiomatic way to determine whether the two data frames are equivalent (equal, isomorphic), where equivalence is determined by the data (column names and column values for each row) being identical save for the ordering of rows & columns?

The motivation for the question is that there are often many ways to compute some big data result, each with its own trade-offs. As one explores these trade-offs, it is important to maintain correctness and hence the need to check for the equivalence/equality on a meaningful test data set.

Scala Solutions


Solution 1 - Scala

Scala (see below for PySpark)

The spark-fast-tests library has two methods for making DataFrame comparisons (I'm the creator of the library):

The assertSmallDataFrameEquality method collects DataFrames on the driver node and makes the comparison:

def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}

The assertLargeDataFrameEquality method compares DataFrames spread across multiple machines (the code is basically copied from spark-testing-base):

def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    actualDF.rdd.cache
    expectedDF.rdd.cache

    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }

    val actualIndexValue = zipWithIndex(actualDF.rdd)
    val expectedIndexValue = zipWithIndex(expectedDF.rdd)

    val unequalRDD = actualIndexValue
      .join(expectedIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }

    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))

  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}

assertSmallDataFrameEquality is faster for small DataFrame comparisons and I've found it sufficient for my test suites.
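A quick usage sketch (hypothetical data; it assumes a SparkSession named spark and that the assertion above is in scope, e.g. mixed in from the library):

import spark.implicits._

// Two DataFrames built by different mechanisms that should hold the same data.
val actualDF = spark.range(3).toDF("id")
val expectedDF = Seq(0L, 1L, 2L).toDF("id")

// Throws a descriptive exception on schema or content mismatch.
assertSmallDataFrameEquality(actualDF, expectedDF)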

PySpark

Here's a simple function that returns True if the DataFrames are equal:

def are_dfs_equal(df1, df2):
    if df1.schema != df2.schema:
        return False
    if df1.collect() != df2.collect():
        return False
    return True

You'll typically perform DataFrame equality comparisons in a test suite and will want a descriptive error message when the comparisons fail (a True / False return value doesn't help much when debugging).

For test suite workflows, use the chispa library's assert_df_equality method, which produces descriptive error messages when a comparison fails.

Solution 2 - Scala

There are some standard ways in the Apache Spark test suites; however, most of these involve collecting the data locally, so if you want to do equality testing on large DataFrames that is likely not a suitable solution.

Check the schemas first; then you could intersect the two DataFrames into df3 and verify that the counts of df1, df2 and df3 are all equal (however, this only works if there are no duplicate rows; if the two have different duplicate rows, this method could still return true).
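
A minimal sketch of that check (assuming the schema comparison has already passed; intersect deduplicates, hence the duplicate-row caveat above):

val df3 = df1.intersect(df2)
val isEqual = df1.count == df2.count && df2.count == df3.count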

Another option would be to get the underlying RDDs of both DataFrames, map each row to (Row, 1), do a reduceByKey to count the occurrences of each Row, cogroup the two resulting RDDs, and then do a regular aggregate, returning false if any of the iterators are not equal.
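
A rough sketch of that approach (my own illustration of the description above, using only standard RDD API methods; the schema check is omitted, and nothing is collected to the driver):

import org.apache.spark.sql.DataFrame

def sameData(df1: DataFrame, df2: DataFrame): Boolean = {
  // Count the occurrences of each Row on both sides.
  val counts1 = df1.rdd.map(row => (row, 1)).reduceByKey(_ + _)
  val counts2 = df2.rdd.map(row => (row, 1)).reduceByKey(_ + _)
  // After cogrouping, each Row must carry the same count on both sides;
  // a Row present on only one side shows up as an empty iterator on the other.
  counts1.cogroup(counts2).filter {
    case (_, (c1, c2)) => c1.toSeq != c2.toSeq
  }.isEmpty()
}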

Solution 3 - Scala

I don't know about idiomatic, but I think the following gives you a robust way to compare DataFrames in the sense you describe. (I'm using PySpark for illustration, but the approach carries across languages.)

a = spark.range(5)
b = spark.range(5)

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0

This approach correctly handles cases where the DataFrames may have duplicate rows, rows in different orders, and/or columns in different orders.

For example:

a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0

This approach is quite expensive, but most of the expense is unavoidable given the need to perform a full diff. And this should scale fine as it doesn't require collecting anything locally. If you relax the constraint that the comparison should account for duplicate rows, then you can drop the groupBy() and just do the subtract(), which would probably speed things up notably.

Solution 4 - Scala

Java:

assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();

This works because union does not deduplicate (it behaves like UNION ALL), so distinct() yields the set of rows appearing in either Dataset; that set is the same size as the intersection exactly when both Datasets contain the same distinct rows. Note that duplicate rows are ignored by this check.

Solution 5 - Scala

Try doing the following:

df1.except(df2).isEmpty
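
Note that except deduplicates and is one-directional: df1.except(df2) can be empty while df2 still contains rows that df1 lacks. A symmetric variant (still ignoring duplicate counts) would be:

df1.except(df2).isEmpty && df2.except(df1).isEmpty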

Solution 6 - Scala

A scalable and easy way is to diff the two DataFrames and count the non-matching rows:

df1.diff(df2).where($"diff" =!= "N").count

If that number is not zero, then the two DataFrames are not equivalent.

The diff transformation is provided by spark-extension (in Scala it comes into scope via import uk.co.gresearch.spark.diff._).

It identifies Inserted, Changed, Deleted and uN-changed rows, marking each row's diff column with "I", "C", "D" or "N"; that is why the filter above keeps every code other than "N".

Solution 7 - Scala

There are four options, depending on whether you have duplicate rows or not.

Let's say we have two DataFrames, z1 and z2. Options 1 and 2 are good for rows without duplicates. You can try these in spark-shell.
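
For instance, hypothetical test frames could be (spark-shell provides the needed implicits automatically; the "index" column is the unique key that Option 2 joins on):

val z1 = Seq((1, "a"), (2, "b")).toDF("index", "value")
val z2 = Seq((2, "b"), (1, "a")).toDF("index", "value")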

  • Option 1: do except directly
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column

def isEqual(left: DataFrame, right: DataFrame): Boolean = {
   if(left.columns.length != right.columns.length) return false // column counts don't match
   if(left.count != right.count) return false // record counts don't match
   return left.except(right).isEmpty && right.except(left).isEmpty
}
  • Option 2: generate row hash by columns
import org.apache.spark.sql.functions.{col, concat_ws, md5}

def createHashColumn(df: DataFrame): Column = {
   val colArr = df.columns
   md5(concat_ws("", colArr.map(col(_)): _*))
}

// Assumes z1 and z2 share a unique "index" key column to join on.
val z1SigDF = z1.select(col("index"), createHashColumn(z1).as("signature_z1"))
val z2SigDF = z2.select(col("index"), createHashColumn(z2).as("signature_z2"))
val joinDF = z1SigDF.join(z2SigDF, z1SigDF("index") === z2SigDF("index")).where($"signature_z1" =!= $"signature_z2").cache
// should be 0
joinDF.count
  • Option 3: use groupBy (for DataFrames with duplicate rows)
val z1Grouped = z1.groupBy(z1.columns.map(c => z1(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")
val z2Grouped = z2.groupBy(z2.columns.map(c => z2(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")

val inZ1NotInZ2 = z1Grouped.except(z2Grouped).toDF()
val inZ2NotInZ1 = z2Grouped.except(z1Grouped).toDF()
// both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show
  • Option 4: use exceptAll, which also works for data with duplicate rows
// Source Code: https://github.com/apache/spark/blob/50538600ec/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2029
val inZ1NotInZ2 = z1.exceptAll(z2).toDF()
val inZ2NotInZ1 = z2.exceptAll(z1).toDF()
// same here, both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show

Solution 8 - Scala

You can do this using a little bit of deduplication in combination with a full outer join. The advantage of this approach is that it does not require you to collect results to the driver, and that it avoids running multiple jobs.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// Generate some random data.
def random(n: Int, s: Long) = {
  spark.range(n).select(
    (rand(s) * 10000).cast("int").as("a"),
    (rand(s + 5) * 1000).cast("int").as("b"))
}
val df1 = random(10000000, 34)
val df2 = random(10000000, 17)

// Move all the keys into a struct (to make handling nulls easy), deduplicate the given dataset
// and count the rows per key.
def dedup(df: Dataset[Row]): Dataset[Row] = {
  df.select(struct(df.columns.map(col): _*).as("key"))
    .groupBy($"key")
    .agg(count(lit(1)).as("row_count"))
}

// Deduplicate the inputs and join them using a full outer join. The result can contain
// the following things:
// 1. Both keys are not null (and thus equal), and the row counts are the same. The dataset
//    is the same for the given key.
// 2. Both keys are not null (and thus equal), and the row counts are not the same. The dataset
//    contains the same keys.
// 3. Only the right key is not null.
// 4. Only the left key is not null.
val joined = dedup(df1).as("l").join(dedup(df2).as("r"), $"l.key" === $"r.key", "full")

// Summarize the differences.
val summary = joined.select(
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" === $"l.row_count", 1)).as("left_right_same_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" =!= $"l.row_count", 1)).as("left_right_different_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNull, 1)).as("left_only"),
  count(when($"l.key".isNull && $"r.key".isNotNull, 1)).as("right_only"))
summary.show()

The inputs are equivalent exactly when left_right_different_rc, left_only and right_only are all zero (the two randomly generated inputs above will, of course, differ).

Solution 9 - Scala

try {
  return ds1.union(ds2)
          .groupBy(columns(ds1, ds1.columns()))
          .count()
          .filter("count % 2 > 0")
          .count()
      == 0;
} catch (Exception e) {
  return false;
}

Column[] columns(Dataset<Row> ds, String... columnNames) {
  List<Column> l = new ArrayList<>();
  for (String cn : columnNames) {
    l.add(ds.col(cn));
  }
  return l.stream().toArray(Column[]::new);
}

The columns method is just a helper and can be replaced by any method that returns the grouping columns (in Scala, anything that yields a Seq[Column]).

Logic:

  1. Union both Datasets; if the columns don't match, it will throw an exception, and hence we return false.
  2. If the columns match, groupBy on all columns and add a count column. Now, if the Datasets are equal, every row's count is a multiple of 2 (this holds even for duplicate rows).
  3. Check whether any row has a count not divisible by 2; those are the extra rows. (A row occurring an even number of times in only one Dataset would still slip through, so pairing this check with a total count comparison is a sensible extra guard.)
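
For reference, a rough Scala rendering of the same parity check (my translation, not from the original answer; it assumes the schemas already match and spark.implicits._ is in scope):

import org.apache.spark.sql.functions.col

ds1.union(ds2)
  .groupBy(ds1.columns.map(col): _*)
  .count()
  .filter($"count" % 2 =!= 0)
  .isEmpty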

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Question: Sim
Solution 1 - Scala: Powers
Solution 2 - Scala: Holden
Solution 3 - Scala: Nick Chammas
Solution 4 - Scala: user1442346
Solution 5 - Scala: Asher A
Solution 6 - Scala: EnricoM
Solution 7 - Scala: LeOn - Han Li
Solution 8 - Scala: Herman van Hovell
Solution 9 - Scala: J. P