How to aggregate values into collection after groupBy?

Scala, Apache Spark, Apache Spark-Sql

Scala Problem Overview


I have a dataframe with schema as such:

[visitorId: string, trackingIds: array<string>, emailIds: array<string>]

I'm looking for a way to group (or maybe roll up?) this dataframe by visitorId so that the trackingIds and emailIds columns are appended together. For example, if my initial df looks like:

+---------+------------+--------+
|visitorId| trackingIds|emailIds|
+---------+------------+--------+
|     a158|      [666b]|    [12]|
|     7g21|      [c0b5]|    [45]|
|     7g21|      [c0b4]|    [87]|
|     a158|[666b, 777c]|      []|
+---------+------------+--------+

I would like my output df to look like this:

+---------+------------------+--------+
|visitorId|       trackingIds|emailIds|
+---------+------------------+--------+
|     a158|[666b, 666b, 777c]|[12, '']|
|     7g21|      [c0b5, c0b4]|[45, 87]|
+---------+------------------+--------+

I've been attempting to use the groupBy and agg operators but haven't had much luck.
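(For context, a plain collect_list aggregation, sketched here as an assumed reconstruction of that attempt, returns nested array<array<string>> columns such as [[666b], [666b, 777c]] rather than the flat arrays shown above, which is what the answers below work around.)

// Assumed sketch of the naive attempt (not from the question): collect_list
// alone yields array<array<string>> columns, i.e. nested rather than flat arrays.
import org.apache.spark.sql.functions.collect_list

df.groupBy($"visitorId")
  .agg(
    collect_list($"trackingIds").alias("trackingIds"),
    collect_list($"emailIds").alias("emailIds"))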

Scala Solutions


Solution 1 - Scala

Spark >= 2.4

You can replace the flatten udf with the built-in flatten function:

import org.apache.spark.sql.functions.flatten

leaving the rest as-is.
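For illustration, a minimal sketch of what the complete Spark >= 2.4 aggregation might look like, assuming the same df and column names defined in the next section (flatten here is the built-in function, not the udf):

import org.apache.spark.sql.functions.{collect_list, flatten}

// collect_list gathers the per-row arrays into an array of arrays,
// and the built-in flatten collapses them into a single array
df
  .groupBy($"visitorId")
  .agg(
    flatten(collect_list($"trackingIds")).alias("trackingIds"),
    flatten(collect_list($"emailIds")).alias("emailIds"))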

Spark >= 2.0, < 2.4

It is possible but quite expensive. Using the data you've provided:

case class Record(
    visitorId: String, trackingIds: Array[String], emailIds: Array[String])

val df = Seq(
  Record("a158", Array("666b"), Array("12")),
  Record("7g21", Array("c0b5"), Array("45")),
  Record("7g21", Array("c0b4"), Array("87")),
  Record("a158", Array("666b",  "777c"), Array.empty[String])).toDF

and a helper function:

import org.apache.spark.sql.functions.udf

val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten)

we can fill the blanks with placeholders:

import org.apache.spark.sql.functions.{array, lit, when}

val dfWithPlaceholders = df.withColumn(
  "emailIds", 
  when(size($"emailIds") === 0, array(lit(""))).otherwise($"emailIds"))

Then collect_list and flatten:

import org.apache.spark.sql.functions.{array, collect_list}

val emailIds = flatten(collect_list($"emailIds")).alias("emailIds")
val trackingIds = flatten(collect_list($"trackingIds")).alias("trackingIds")

dfWithPlaceholders
  .groupBy($"visitorId")
  .agg(trackingIds, emailIds)

// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     a158|[666b, 666b, 777c]|  [12, ]|
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// +---------+------------------+--------+

With statically typed Dataset:

dfWithPlaceholders.as[Record]
  .groupByKey(_.visitorId)
  .mapGroups { case (key, vs) =>
    vs.map(v => (v.trackingIds, v.emailIds)).toArray.unzip match {
      case (trackingIds, emailIds) =>
        Record(key, trackingIds.flatten, emailIds.flatten)
    }
  }

// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     a158|[666b, 666b, 777c]|  [12, ]|
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// +---------+------------------+--------+

Spark 1.x

You can convert to an RDD and group:

import org.apache.spark.sql.Row

dfWithPlaceholders.rdd
  .map {
    case Row(
      id: String,
      trcks: Seq[String @unchecked],
      emails: Seq[String @unchecked]) => (id, (trcks, emails))
  }
  .groupByKey
  .map { case (key, vs) =>
    vs.toArray.unzip match {
      case (trackingIds, emailIds) =>
        Record(key, trackingIds.flatten, emailIds.flatten)
    }
  }
  .toDF
  
// +---------+------------------+--------+
// |visitorId|       trackingIds|emailIds|
// +---------+------------------+--------+
// |     7g21|      [c0b5, c0b4]|[45, 87]|
// |     a158|[666b, 666b, 777c]|  [12, ]|
// +---------+------------------+--------+



Solution 2 - Scala

@zero323's answer is pretty much complete, but Spark gives us even more flexibility. How about the following solution?

import org.apache.spark.sql.functions._
inventory
  .select($"*", explode($"trackingIds") as "tracking_id")
  .select($"*", explode($"emailIds") as "email_id")
  .groupBy("visitorId")
  .agg(
    collect_list("tracking_id") as "trackingIds",
    collect_list("email_id") as "emailIds")

That, however, leaves out all rows with empty collections (so there's some room for improvement :))
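One possible refinement (a sketch, not part of the original answer), assuming Spark 2.2+ where explode_outer is available: empty arrays then produce a row with a null instead of being dropped, and collect_list simply skips the nulls.

import org.apache.spark.sql.functions.{collect_list, explode_outer}

// explode_outer keeps rows whose array is empty (emitting null),
// so the other column's values for that visitor are not lost
inventory
  .select($"*", explode_outer($"trackingIds") as "tracking_id")
  .select($"*", explode_outer($"emailIds") as "email_id")
  .groupBy("visitorId")
  .agg(
    collect_list("tracking_id") as "trackingIds",
    collect_list("email_id") as "emailIds")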

Solution 3 - Scala

You can use a user-defined aggregate function (UDAF).

  1. Create a custom UDAF using a Scala class, here called CustomAggregation.

    package com.package.name

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    import scala.collection.JavaConverters._

    class CustomAggregation() extends UserDefinedAggregateFunction {

      // Input Data Type Schema
      def inputSchema: StructType =
        StructType(Array(StructField("col5", ArrayType(StringType))))

      // Intermediate Schema
      def bufferSchema: StructType =
        StructType(Array(StructField("col5_collapsed", ArrayType(StringType))))

      // Returned Data Type
      def dataType: DataType = ArrayType(StringType)

      // Self-explaining
      def deterministic = true

      // This function is called whenever key changes
      def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = Array.empty[String] // initialize array
      }

      // Iterate over each entry of a group
      def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) =
          if (!input.isNullAt(0)) buffer.getList[String](0).asScala.toArray ++ input.getList[String](0).asScala.toArray
          else buffer.getList[String](0).asScala.toArray
      }

      // Merge two partial aggregates
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1(0) = buffer1.getList[String](0).asScala.toArray ++ buffer2.getList[String](0).asScala.toArray
      }

      // Called after all the entries are exhausted
      def evaluate(buffer: Row) = buffer.getList[String](0).asScala.toList.distinct
    }

  2. Then use the UDAF in your code:

    // define UDAF
    val CustomAggregation = new CustomAggregation()

    DataFrame
      .groupBy(col1, col2, col3)
      .agg(CustomAggregation(DataFrame(col5))).show()
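For completeness, a hypothetical application to the question's schema (not part of the original answer; df, visitorId, trackingIds and emailIds are the names from Solution 1) might look like this. Note that evaluate calls distinct, so repeated ids would be collapsed:

// Hypothetical usage against the question's columns (assumed names).
// Note: evaluate() deduplicates, so repeated ids collapse.
val collectFlat = new CustomAggregation()

df.groupBy($"visitorId")
  .agg(
    collectFlat($"trackingIds").alias("trackingIds"),
    collectFlat($"emailIds").alias("emailIds"))
  .show()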

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type       | Original Author | Original Content on Stackoverflow
-------------------|-----------------|----------------------------------
Question           | Eric Patterson  | View Question on Stackoverflow
Solution 1 - Scala | zero323         | View Answer on Stackoverflow
Solution 2 - Scala | Jacek Laskowski | View Answer on Stackoverflow
Solution 3 - Scala | gourav sb       | View Answer on Stackoverflow