Querying Spark SQL DataFrame with complex types

Tags: Sql, Scala, Apache Spark, Dataframe, Apache Spark-Sql

Sql Problem Overview


How can I query an RDD with complex types such as maps/arrays? For example, when I was writing this test code:

case class Test(name: String, map: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))

I thought the syntax would be something like:

sqlContext.sql("SELECT * FROM rdd WHERE map.hello = world")

or

sqlContext.sql("SELECT * FROM rdd WHERE map[hello] = world")

but I get

> Can't access nested field in type MapType(StringType,StringType,true)

and

> org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Unresolved attributes

respectively.

Sql Solutions


Solution 1 - Sql

It depends on the type of the column. Let's start with some dummy data:

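import org.apache.spark.sql.functions.{udf, lit, explode}
import scala.util.Try
// Enables toDF and the $"..." column syntax (spark-shell imports this automatically;
// on Spark 2.0+ outside the shell use `import spark.implicits._` instead)
import sqlContext.implicits._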

case class SubRecord(x: Int)
case class ArrayElement(foo: String, bar: Int, vals: Array[Double])
case class Record(
  an_array: Array[Int], a_map: Map[String, String], 
  a_struct: SubRecord, an_array_of_structs: Array[ArrayElement])


val df = sc.parallelize(Seq(
  Record(Array(1, 2, 3), Map("foo" -> "bar"), SubRecord(1),
         Array(
           ArrayElement("foo", 1, Array(1.0, 2.0, 2.0)),
           ArrayElement("bar", 2, Array(3.0, 4.0, 5.0)))),
  Record(Array(4, 5, 6), Map("foz" -> "baz"), SubRecord(2),
         Array(ArrayElement("foz", 3, Array(5.0, 6.0)), 
               ArrayElement("baz", 4, Array(7.0, 8.0))))
)).toDF

df.registerTempTable("df")  // Spark 2.0+: df.createOrReplaceTempView("df")
df.printSchema

// root
// |-- an_array: array (nullable = true)
// |    |-- element: integer (containsNull = false)
// |-- a_map: map (nullable = true)
// |    |-- key: string
// |    |-- value: string (valueContainsNull = true)
// |-- a_struct: struct (nullable = true)
// |    |-- x: integer (nullable = false)
// |-- an_array_of_structs: array (nullable = true)
// |    |-- element: struct (containsNull = true)
// |    |    |-- foo: string (nullable = true)
// |    |    |-- bar: integer (nullable = false)
// |    |    |-- vals: array (nullable = true)
// |    |    |    |-- element: double (containsNull = false)
  • array (ArrayType) columns:

    • Column.getItem method

          df.select($"an_array".getItem(1)).show
      
          // +-----------+
          // |an_array[1]|
          // +-----------+
          // |          2|
          // |          5|
          // +-----------+
      
    • Hive bracket syntax:

          sqlContext.sql("SELECT an_array[1] FROM df").show
      
          // +---+
          // |_c0|
          // +---+
          // |  2|
          // |  5|
          // +---+
      
    • a UDF:

          val get_ith = udf((xs: Seq[Int], i: Int) => Try(xs(i)).toOption)
      
          df.select(get_ith($"an_array", lit(1))).show
      
          // +---------------+
          // |UDF(an_array,1)|
          // +---------------+
          // |              2|
          // |              5|
          // +---------------+
      
    • In addition to the methods listed above, Spark supports a growing list of built-in functions operating on complex types. Notable examples include higher-order functions like transform (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

          df.selectExpr("transform(an_array, x -> x + 1) an_array_inc").show
          // +------------+
          // |an_array_inc|
          // +------------+
          // |   [2, 3, 4]|
          // |   [5, 6, 7]|
          // +------------+
      
          import org.apache.spark.sql.functions.transform
      
          df.select(transform($"an_array", x => x + 1) as "an_array_inc").show
          // +------------+
          // |an_array_inc|
          // +------------+
          // |   [2, 3, 4]|
          // |   [5, 6, 7]|
          // +------------+
      
    • filter (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

         df.selectExpr("filter(an_array, x -> x % 2 == 0) an_array_even").show
         // +-------------+
         // |an_array_even|
         // +-------------+
         // |          [2]|
         // |       [4, 6]|
         // +-------------+
      
         import org.apache.spark.sql.functions.filter
      
         df.select(filter($"an_array", x => x % 2 === 0) as "an_array_even").show
         // +-------------+
         // |an_array_even|
         // +-------------+
         // |          [2]|
         // |       [4, 6]|
         // +-------------+
      
    • aggregate (SQL 2.4+, Scala API 3.0+, PySpark / SparkR 3.1+):

         df.selectExpr("aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc) an_array_sum").show
         // +------------+
         // |an_array_sum|
         // +------------+
         // |           6|
         // |          15|
         // +------------+
      
         import org.apache.spark.sql.functions.aggregate
      
         df.select(aggregate($"an_array", lit(0), (x, y) => x + y) as "an_array_sum").show
         // +------------+
         // |an_array_sum|
         // +------------+
         // |           6|
         // |          15|
         // +------------+
      
    • array processing functions (array_*) like array_distinct (2.4+):

         import org.apache.spark.sql.functions.array_distinct
      
         df.select(array_distinct($"an_array_of_structs.vals"(0))).show
         // +-------------------------------------------+
         // |array_distinct(an_array_of_structs.vals[0])|
         // +-------------------------------------------+
         // |                                 [1.0, 2.0]|
         // |                                 [5.0, 6.0]|
         // +-------------------------------------------+
      
    • array_max (and array_min, 2.4+):

         import org.apache.spark.sql.functions.array_max
      
         df.select(array_max($"an_array")).show
         // +-------------------+
         // |array_max(an_array)|
         // +-------------------+
         // |                  3|
         // |                  6|
         // +-------------------+
      
    • flatten (2.4+)

         import org.apache.spark.sql.functions.flatten
      
         df.select(flatten($"an_array_of_structs.vals")).show
         // +---------------------------------+
         // |flatten(an_array_of_structs.vals)|
         // +---------------------------------+
         // |             [1.0, 2.0, 2.0, 3...|
         // |             [5.0, 6.0, 7.0, 8.0]|
         // +---------------------------------+
      
    • arrays_zip (2.4+):

         import org.apache.spark.sql.functions.arrays_zip
      
         df.select(arrays_zip($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show(false)
         // +--------------------------------------------------------------------+
         // |arrays_zip(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
         // +--------------------------------------------------------------------+
         // |[[1.0, 3.0], [2.0, 4.0], [2.0, 5.0]]                                |
         // |[[5.0, 7.0], [6.0, 8.0]]                                            |
         // +--------------------------------------------------------------------+
      
    • array_union (2.4+):

         import org.apache.spark.sql.functions.array_union
      
         df.select(array_union($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show
         // +---------------------------------------------------------------------+
         // |array_union(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
         // +---------------------------------------------------------------------+
         // |                                                 [1.0, 2.0, 3.0, 4...|
         // |                                                 [5.0, 6.0, 7.0, 8.0]|
         // +---------------------------------------------------------------------+
      
    • slice (2.4+):

        import org.apache.spark.sql.functions.slice
      
        df.select(slice($"an_array", 2, 2)).show
        // +---------------------+
        // |slice(an_array, 2, 2)|
        // +---------------------+
        // |               [2, 3]|
        // |               [5, 6]|
        // +---------------------+
      
  • map (MapType) columns:

    • using the Column.getField method:

          df.select($"a_map".getField("foo")).show
      
          // +----------+
          // |a_map[foo]|
          // +----------+
          // |       bar|
          // |      null|
          // +----------+
      
    • using the Hive bracket syntax:

          sqlContext.sql("SELECT a_map['foz'] FROM df").show
      
          // +----+
          // | _c0|
          // +----+
          // |null|
          // | baz|
          // +----+
      
    • using a full path with dot syntax:

          df.select($"a_map.foo").show
      
          // +----+
          // | foo|
          // +----+
          // | bar|
          // |null|
          // +----+
      
    • using a UDF:

          val get_field = udf((kvs: Map[String, String], k: String) => kvs.get(k))
      
          df.select(get_field($"a_map", lit("foo"))).show
      
          // +--------------+
          // |UDF(a_map,foo)|
          // +--------------+
          // |           bar|
          // |          null|
          // +--------------+
      
    • a growing number of map_* functions, like map_keys (2.3+):

         import org.apache.spark.sql.functions.map_keys
      
         df.select(map_keys($"a_map")).show
         // +---------------+
         // |map_keys(a_map)|
         // +---------------+
         // |          [foo]|
         // |          [foz]|
         // +---------------+
      
    • or map_values (2.3+)

         import org.apache.spark.sql.functions.map_values
      
         df.select(map_values($"a_map")).show
         // +-----------------+
         // |map_values(a_map)|
         // +-----------------+
         // |            [bar]|
         // |            [baz]|
         // +-----------------+
      

    Please check SPARK-23899 for a detailed list.

  • struct (StructType) columns, using the full path with dot syntax:

    • with DataFrame API

          df.select($"a_struct.x").show
      
          // +---+
          // |  x|
          // +---+
          // |  1|
          // |  2|
          // +---+
      
    • with raw SQL

          sqlContext.sql("SELECT a_struct.x FROM df").show
      
          // +---+
          // |  x|
          // +---+
          // |  1|
          // |  2|
          // +---+
      
  • Fields inside an array of structs can be accessed using dot syntax, field names, and standard Column methods:

      df.select($"an_array_of_structs.foo").show
    
      // +----------+
      // |       foo|
      // +----------+
      // |[foo, bar]|
      // |[foz, baz]|
      // +----------+
    
      sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show
    
      // +---+
      // |_c0|
      // +---+
      // |foo|
      // |foz|
      // +---+
    
      df.select($"an_array_of_structs.vals".getItem(1).getItem(1)).show
    
      // +------------------------------+
      // |an_array_of_structs.vals[1][1]|
      // +------------------------------+
      // |                           4.0|
      // |                           8.0|
      // +------------------------------+
    
  • Fields of user-defined types (UDTs) can be accessed using UDFs. See https://stackoverflow.com/q/33747851/1560062 for details.

Notes:

  • Depending on the Spark version, some of these methods may be available only with HiveContext. UDFs should work regardless of version with both the standard SQLContext and HiveContext.

  • Generally speaking, nested values are second-class citizens: not all typical operations are supported on nested fields. Depending on the context, it may be better to flatten the schema and/or explode collections:

      df.select(explode($"an_array_of_structs")).show
    
      // +--------------------+
      // |                 col|
      // +--------------------+
      // |[foo,1,WrappedArr...|
      // |[bar,2,WrappedArr...|
      // |[foz,3,WrappedArr...|
      // |[baz,4,WrappedArr...|
      // +--------------------+
    
  • Dot syntax can be combined with the wildcard character (*) to select (possibly multiple) fields without specifying their names explicitly:

      df.select($"a_struct.*").show
      // +---+
      // |  x|
      // +---+
      // |  1|
      // |  2|
      // +---+
    
  • JSON columns can be queried using the get_json_object and from_json functions. See https://stackoverflow.com/q/34069282/ for details.
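The last notes (flattening with explode, the .* wildcard, and querying JSON columns) fit together in one small sketch. It assumes Spark 2.1+ (for from_json) with spark-sql on the classpath and a local master. The Item/NestedRec case classes, the items/json columns and the helper names are hypothetical and were chosen only for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, explode, from_json, get_json_object}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical sample types: an array of structs plus a raw JSON string column
case class Item(foo: String, bar: Int)
case class NestedRec(id: Int, items: Seq[Item], json: String)

object NestedNotesSketch {
  def sampleData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      NestedRec(1, Seq(Item("foo", 1), Item("bar", 2)), """{"a": 1, "b": "x"}"""),
      NestedRec(2, Seq(Item("foz", 3)), """{"a": 2, "b": "y"}""")
    ).toDF()
  }

  // explode produces one row per array element; "s.*" then expands
  // the exploded struct into top-level columns (id, foo, bar)
  def flattenItems(df: DataFrame): DataFrame =
    df.select(col("id"), explode(col("items")).as("s"))
      .select(col("id"), col("s.*"))

  // get_json_object extracts a single value (as a string) by JSON path;
  // from_json parses the whole document into a struct with the given schema
  def queryJson(df: DataFrame): DataFrame = {
    val schema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))
    df.select(
      col("id"),
      get_json_object(col("json"), "$.b").as("b"),
      from_json(col("json"), schema).getField("a").as("a"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-notes-sketch")
      .getOrCreate()
    val df = sampleData(spark)
    flattenItems(df).show()
    queryJson(df).show()
    spark.stop()
  }
}
```

Exploding first and then expanding with .* turns the nested array into an ordinary flat table. The rest of the SQL and DataFrame API then applies without special syntax for nested values.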

Solution 2 - Sql

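Once you convert it to a DataFrame, you can simply fetch the data as follows (this assumes rdd holds (name, map) pairs and sqc is your SQLContext):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

// Turn each (name, map) pair into a Row
val rddRow = rdd.map { case (k, v) => Row(k, v) }

// Explicit schema: a nullable string column and a map<string,string> column
val myFld1 = StructField("name", StringType, true)
val myFld2 = StructField("map", MapType(StringType, StringType), true)
val schema = StructType(Array(myFld1, myFld2))

val rowrddDF = sqc.createDataFrame(rddRow, schema)
rowrddDF.registerTempTable("rowtbl")

// Value stored under the "one" key (null where the key is missing)
val rowrddDFFinal = rowrddDF.select(rowrddDF("map.one"))
// or, equivalently
val rowrddDFFinal2 = rowrddDF.select("map.one")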

Solution 3 - Sql

Here is what I did, and it worked:

case class Test(name: String, m: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))
val rdddf = rdd.toDF
rdddf.registerTempTable("mytable")
sqlContext.sql("select m.hello from mytable").show

Results

+------+
| hello|
+------+
| world|
|people|
+------+
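The question's WHERE clause can be written with the same map lookups. The sketch below is a standalone version of this answer's data. It assumes Spark 2.x/3.x with spark-sql on the classpath, and MapFilterSketch and run are hypothetical names. The quotes matter: in m['hello'] = 'world' both the key and the compared value are string literals. Left unquoted, hello and world are parsed as column names, which cannot be resolved.

```scala
import org.apache.spark.sql.SparkSession

// Same shape as the case class used in the question and in this answer
case class Test(name: String, m: Map[String, String])

// Hypothetical standalone driver; assumes spark-sql (2.x/3.x) is on the classpath
object MapFilterSketch {
  // Returns the names matched by the SQL query and by the DataFrame API query
  def run(spark: SparkSession): (Seq[String], Seq[String]) = {
    import spark.implicits._

    val df = Seq(
      Test("first", Map("hello" -> "world", "hey" -> "there")),
      Test("second", Map("hello" -> "people", "hey" -> "you"))
    ).toDF()
    df.createOrReplaceTempView("mytable")

    // SQL: bracket lookup, with the key and the compared value quoted as string literals
    val viaSql = spark
      .sql("SELECT name FROM mytable WHERE m['hello'] = 'world'")
      .as[String]
      .collect()
      .toSeq

    // DataFrame API: the same lookup via Column.getItem
    val viaApi = df
      .where($"m".getItem("hello") === "world")
      .select($"name")
      .as[String]
      .collect()
      .toSeq

    (viaSql, viaApi)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("map-filter-sketch")
      .getOrCreate()
    val (viaSql, viaApi) = run(spark)
    println(s"SQL: $viaSql, API: $viaApi")
    spark.stop()
  }
}
```

Both queries keep only the "first" row, because that is the only row whose "hello" key maps to "world".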

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type     | Original Author | Original Content on Stackoverflow
Question         | dvir            | View Question on Stackoverflow
Solution 1 - Sql | zero323         | View Answer on Stackoverflow
Solution 2 - Sql | sshroff         | View Answer on Stackoverflow
Solution 3 - Sql | Sumit Pal       | View Answer on Stackoverflow