Spark SQL: apply aggregate functions to a list of columns

Tags: Apache Spark, Dataframe, Apache Spark-Sql, Aggregate Functions

Apache Spark Problem Overview


Is there a way to apply an aggregate function to all (or a list of) columns of a dataframe, when doing a groupBy? In other words, is there a way to avoid doing this for every column:

df.groupBy("col1")
  .agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...)

Apache Spark Solutions


Solution 1 - Apache Spark

There are multiple ways of applying aggregate functions to multiple columns.

The GroupedData class provides a number of methods for the most common aggregate functions, including count, max, min, mean and sum, which can be used directly as follows:

  • Python:

      df = sqlContext.createDataFrame(
          [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)],
          ("col1", "col2", "col3"))
    
      df.groupBy("col1").sum()
    
      ## +----+---------+-----------------+---------+
      ## |col1|sum(col1)|        sum(col2)|sum(col3)|
      ## +----+---------+-----------------+---------+
      ## | 1.0|      2.0|              0.8|      1.0|
      ## |-1.0|     -2.0|6.199999999999999|      0.7|
      ## +----+---------+-----------------+---------+
    
  • Scala

      val df = sc.parallelize(Seq(
        (1.0, 0.3, 1.0), (1.0, 0.5, 0.0),
        (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2))
      ).toDF("col1", "col2", "col3")
    
      df.groupBy($"col1").min().show
    
      // +----+---------+---------+---------+
      // |col1|min(col1)|min(col2)|min(col3)|
      // +----+---------+---------+---------+
      // | 1.0|      1.0|      0.3|      0.0|
      // |-1.0|     -1.0|      0.6|      0.2|
      // +----+---------+---------+---------+
    

Optionally, you can pass a list of the columns that should be aggregated:

df.groupBy("col1").sum("col2", "col3")

You can also pass a dictionary / map with columns as the keys and functions as the values:

  • Python

      exprs = {x: "sum" for x in df.columns}
      df.groupBy("col1").agg(exprs).show()
    
      ## +----+---------+-----------------+---------+
      ## |col1|sum(col1)|        sum(col2)|sum(col3)|
      ## +----+---------+-----------------+---------+
      ## | 1.0|      2.0|              0.8|      1.0|
      ## |-1.0|     -2.0|6.199999999999999|      0.7|
      ## +----+---------+-----------------+---------+
    
  • Scala

      val exprs = df.columns.map((_ -> "mean")).toMap
      df.groupBy($"col1").agg(exprs).show()
    
      // +----+---------+------------------+---------+
      // |col1|avg(col1)|         avg(col2)|avg(col3)|
      // +----+---------+------------------+---------+
      // | 1.0|      1.0|               0.4|      0.5|
      // |-1.0|     -1.0|3.0999999999999996|     0.35|
      // +----+---------+------------------+---------+
      
    

Finally, you can use varargs:

  • Python

      from pyspark.sql.functions import min
    
      exprs = [min(x) for x in df.columns]
      df.groupBy("col1").agg(*exprs).show()
    
  • Scala

      import org.apache.spark.sql.functions.sum
    
      val exprs = df.columns.map(sum(_))
      df.groupBy($"col1").agg(exprs.head, exprs.tail: _*)
    

There are some other ways to achieve a similar effect, but these should be more than enough most of the time.
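
Since the question also asked about alias, here is a hedged Scala sketch (not part of the original answer; toSum and sumExprs are names introduced here) that combines the varargs approach with aliasing, so each aggregated column keeps its original name:

    import org.apache.spark.sql.functions.{col, sum}

    // Sketch: sum every non-grouping column and alias it back to its original name
    val toSum = df.columns.filter(_ != "col1")
    val sumExprs = toSum.map(c => sum(col(c)).alias(c))
    df.groupBy($"col1").agg(sumExprs.head, sumExprs.tail: _*).show()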


Solution 2 - Apache Spark

Another example of the same concept, but say you have two different columns and you want to apply a different aggregate function to each of them, i.e.

f.groupBy("col1").agg(sum("col2").alias("col2"), avg("col3").alias("col3"), ...)

Here is a way to achieve it, though I do not yet know how to add the alias in this case (one possible approach is sketched after the example).

See the example below, using Maps:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Schema and sample data for a small claims dataset
val Claim1 = StructType(Seq(StructField("pid", StringType, true), StructField("diag1", StringType, true), StructField("diag2", StringType, true), StructField("allowed", IntegerType, true), StructField("allowed1", IntegerType, true)))
val claimsData1 = Seq(("PID1", "diag1", "diag2", 100, 200), ("PID1", "diag2", "diag3", 300, 600), ("PID1", "diag1", "diag5", 340, 680), ("PID2", "diag3", "diag4", 245, 490), ("PID2", "diag2", "diag1", 124, 248))

val claimRDD1 = sc.parallelize(claimsData1)
val claimRDDRow1 = claimRDD1.map(p => Row(p._1, p._2, p._3, p._4, p._5))
val claimRDD2DF1 = sqlContext.createDataFrame(claimRDDRow1, Claim1)

// Same aggregate function applied to every column in the list
val l = List("allowed", "allowed1")
val sumExprs = l.map(_ -> "sum").toMap
claimRDD2DF1.groupBy("pid").agg(sumExprs).show(false)

// A different aggregate function per column
val mixedExprs = Map("allowed" -> "sum", "allowed1" -> "avg")
claimRDD2DF1.groupBy("pid").agg(mixedExprs).show(false)
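
One way to get the aliases as well (a sketch, not from the original answer; aliasedExprs is a name introduced here) is to build typed Column expressions with the functions from org.apache.spark.sql.functions instead of a Map[String, String], since Column expressions accept .alias:

    import org.apache.spark.sql.functions.{avg, sum}

    // Sketch: per-column functions with aliases, reusing the claimRDD2DF1 dataframe above
    val aliasedExprs = Seq(sum("allowed").alias("allowed"), avg("allowed1").alias("allowed1"))
    claimRDD2DF1.groupBy("pid").agg(aliasedExprs.head, aliasedExprs.tail: _*).show(false)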

Solution 3 - Apache Spark

Current answers are perfectly correct on how to create the aggregations, but none actually address the column alias/renaming that is also requested in the question.

Typically, this is how I handle this case:

import org.apache.spark.sql.functions.col

val dimensions = List("col1")
val metrics = List("col2", "col3", "col4")
val columnOfInterests = dimensions ++ metrics

val df = spark.read.table("some_table")
    .select(columnOfInterests.map(c => col(c)): _*)
    .groupBy(dimensions.map(d => col(d)): _*)
    .agg(metrics.map(m => m -> "sum").toMap)
    .toDF(columnOfInterests: _*)    // that's the interesting part

The last line renames every column of the aggregated dataframe back to the original field names, essentially changing sum(col2) and sum(col3) to simply col2 and col3.

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type              | Original Author | Original Content on Stackoverflow
Question                  | lilloraffa      | View Question on Stackoverflow
Solution 1 - Apache Spark | zero323         | View Answer on Stackoverflow
Solution 2 - Apache Spark | Sumit Pal       | View Answer on Stackoverflow
Solution 3 - Apache Spark | Philippe Oger   | View Answer on Stackoverflow