Automatically and Elegantly flatten DataFrame in Spark SQL

Scala · Apache Spark · Apache Spark-Sql

Scala Problem Overview


All,

Is there an elegant and accepted way to flatten a Spark SQL table (Parquet) with columns that are of nested StructType?

For example, if my schema is:

foo
 |_bar
 |_baz
x
y
z

How do I select it into a flattened tabular form without resorting to manually running

df.select("foo.bar","foo.baz","x","y","z")

In other words, how do I obtain the result of the above code programmatically, given just a StructType and a DataFrame?

Scala Solutions


Solution 1 - Scala

The short answer is, there's no "accepted" way to do this, but you can do it very elegantly with a recursive function that generates your select(...) statement by walking through the DataFrame.schema.

The recursive function should return an Array[Column]. Every time the function hits a StructType, it calls itself and appends the returned Array[Column] to its own Array[Column].

Something like:

import org.apache.spark.sql.Column
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions.col

def flattenSchema(schema: StructType, prefix: String = null) : Array[Column] = {
  schema.fields.flatMap(f => {
    // Build the fully qualified column name, e.g. "foo.bar"
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)

    f.dataType match {
      // Nested struct: recurse with the current path as the new prefix
      case st: StructType => flattenSchema(st, colName)
      // Leaf field: emit a single Column
      case _ => Array(col(colName))
    }
  })
}

You would then use it like this:

df.select(flattenSchema(df.schema):_*)
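
For the schema in the question, this generates the equivalent of the manual select:

df.select(col("foo.bar"), col("foo.baz"), col("x"), col("y"), col("z"))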

Solution 2 - Scala

Just wanted to share my solution for PySpark; it's more or less a translation of @David Griffin's solution, so it supports any level of nested objects.

from pyspark.sql.types import StructType, ArrayType  
  
def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType

        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)

    return fields

    
df.select(flatten(df.schema)).show()

Solution 3 - Scala

I am improving my previous answer and offering a solution to my own problem stated in the comments of the accepted answer.

The accepted solution creates an array of Column objects and uses it to select these columns. In Spark, if you have a nested DataFrame, you can select a child column like this: df.select("Parent.Child"), and this returns a DataFrame with the values of the child column, named Child. But if you have identical attribute names in different parent structures, you lose the information about the parent, may end up with identical column names, and can no longer access them by name, since they are now ambiguous.

This was my problem.
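
For illustration, a minimal sketch of the collision (hypothetical data, assuming a spark-shell session with spark in scope):

val df = spark.sql(
  """SELECT named_struct('a', 'from foo') AS foo,
            named_struct('a', 'from bar') AS bar""")

val flat = df.select("foo.a", "bar.a")
flat.printSchema()
// root
//  |-- a: string (...)
//  |-- a: string (...)

// flat.select("a") now fails with an AnalysisException,
// because the reference 'a' is ambiguous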

I found a solution to my problem, maybe it can help someone else as well. I called the flattenSchema separately:

val flattenedSchema = flattenSchema(df.schema)

and this returned an Array of Column objects. Instead of using it directly in the select(), which would return a DataFrame with columns named only by the last-level child, I aliased each column with its full name as a string, so that after selecting a Parent.Child column it is named Parent.Child instead of just Child (I also replaced dots with underscores for my convenience):

val renamedCols = flattenedSchema.map(name => col(name.toString()).as(name.toString().replace(".","_")))

And then you can use the select function as shown in the original answer:

val newDf = df.select(renamedCols:_*)

Solution 4 - Scala

========== edit ====

There's some additional handling for more complex schemas here: https://medium.com/@lvhuyen/working-with-spark-dataframe-having-a-complex-schema-a3bce8c3f44

==================

PySpark, building on @Evan V's answer, for when your field names contain special characters like a dot '.', a hyphen '-', and so on:

from pyspark.sql.functions import col
from pyspark.sql.types import StructType, ArrayType

def normalise_field(raw):
    return raw.strip().lower() \
            .replace('`', '') \
            .replace('-', '_') \
            .replace(' ', '_') \
            .strip('_')

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = "%s.`%s`" % (prefix, field.name) if prefix else "`%s`" % field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(col(name).alias(normalise_field(name)))

    return fields

df.select(flatten(df.schema)).show()

Solution 5 - Scala

You could also use SQL to select the columns in flattened form:

  1. Get the original DataFrame schema.
  2. Generate the SQL string by walking the schema.
  3. Query your original DataFrame.

I did an implementation in Java: https://gist.github.com/ebuildy/3de0e2855498e5358e4eed1a4f72ea48

(It uses a recursive method as well. I prefer the SQL way, since you can test it easily via spark-shell.)
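
A minimal Scala sketch of the same idea (the helper selectFields and the view name nested are illustrative, not the linked Java code):

import org.apache.spark.sql.types.StructType

// Walk the schema and emit one "parent.child AS parent_child" fragment per leaf
def selectFields(schema: StructType, prefix: String = ""): Seq[String] =
  schema.fields.toSeq.flatMap { f =>
    val path = if (prefix.isEmpty) f.name else s"$prefix.${f.name}"
    f.dataType match {
      case st: StructType => selectFields(st, path)
      case _              => Seq(s"$path AS ${path.replace(".", "_")}")
    }
  }

df.createOrReplaceTempView("nested")
val sql = s"SELECT ${selectFields(df.schema).mkString(", ")} FROM nested"
spark.sql(sql).show()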

Solution 6 - Scala

I added a DataFrame#flattenSchema method to the open source spark-daria project.

Here's how you can use the function with your code.

import com.github.mrpowers.spark.daria.sql.DataFrameExt._
df.flattenSchema().show()

+-------+-------+---------+----+---+
|foo.bar|foo.baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+

You can also specify different column name delimiters with the flattenSchema() method.

df.flattenSchema(delimiter = "_").show()
+-------+-------+---------+----+---+
|foo_bar|foo_baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+

This delimiter parameter is surprisingly important. If you're flattening your schema to load the table in Redshift, you won't be able to use periods as the delimiter.

Here's the full code snippet to generate this output.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val data = Seq(
  Row(Row("this", "is"), "something", "cool", ";)")
)

val schema = StructType(
  Seq(
    StructField(
      "foo",
      StructType(
        Seq(
          StructField("bar", StringType, true),
          StructField("baz", StringType, true)
        )
      ),
      true
    ),
    StructField("x", StringType, true),
    StructField("y", StringType, true),
    StructField("z", StringType, true)
  )
)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(data),
  schema
)

df.flattenSchema().show()

The underlying code is similar to David Griffin's code (in case you don't want to add the spark-daria dependency to your project).

object StructTypeHelpers {

  def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = {
    schema.fields.flatMap(structField => {
      // codeColName uses dots so Spark can resolve the nested field;
      // colName uses the chosen delimiter and becomes the output alias
      val codeColName = if (prefix == null) structField.name else prefix + "." + structField.name
      val colName = if (prefix == null) structField.name else prefix + delimiter + structField.name

      structField.dataType match {
        case st: StructType => flattenSchema(schema = st, delimiter = delimiter, prefix = colName)
        case _ => Array(col(codeColName).alias(colName))
      }
    })
  }

}

object DataFrameExt {

  implicit class DataFrameMethods(df: DataFrame) {

    def flattenSchema(delimiter: String = ".", prefix: String = null): DataFrame = {
      df.select(
        StructTypeHelpers.flattenSchema(df.schema, delimiter, prefix): _*
      )
    }

  }

}

Solution 7 - Scala

Here is a function that does what you want and that can deal with multiple nested columns containing columns with the same name, by prefixing them with the parent name:

from pyspark.sql import functions as F

def flatten_df(nested_df):
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']

    flat_df = nested_df.select(flat_cols +
                               [F.col(nc+'.'+c).alias(nc+'_'+c)
                                for nc in nested_cols
                                for c in nested_df.select(nc+'.*').columns])
    return flat_df

Before:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- bar: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)

After:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo_a: float (nullable = true)
 |-- foo_b: float (nullable = true)
 |-- foo_c: integer (nullable = true)
 |-- bar_a: float (nullable = true)
 |-- bar_b: float (nullable = true)
 |-- bar_c: integer (nullable = true)

Solution 8 - Scala

To combine David Griffin's and V. Samma's answers, you could just do this to flatten while avoiding duplicate column names:

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

def flattenSchema(schema: StructType, prefix: String = null) : Array[Column] = {
  schema.fields.flatMap(f => {
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)
    f.dataType match {
      case st: StructType => flattenSchema(st, colName)
      case _ => Array(col(colName).as(colName.replace(".","_")))
    }
  })
}

def flattenDataFrame(df:DataFrame): DataFrame = {
    df.select(flattenSchema(df.schema):_*)
}

val my_flattened_json_table = flattenDataFrame(my_json_table)

Solution 9 - Scala

A little addition to the code above, for when you are working with nested structs and arrays.

import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

def flattenSchema(schema: StructType, prefix: String = null) : Array[Column] = {
  schema.fields.flatMap(f => {
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)

    f match {
      // Nested struct: recurse
      case StructField(_, struct: StructType, _, _) => flattenSchema(struct, colName)
      // Array of structs: recurse into the element type (note that selecting
      // "arr.child" yields an array of child values, not one row per element)
      case StructField(_, ArrayType(x: StructType, _), _, _) => flattenSchema(x, colName)
      // Array of primitives: keep the column as-is
      case StructField(_, ArrayType(_, _), _, _) => Array(col(colName))
      case _ => Array(col(colName))
    }
  })
}

Solution 10 - Scala

I have been using one-liners which result in a flattened schema with the five columns bar, baz, x, y, z:

df.select("foo.*", "x", "y", "z")

As for explode: I typically reserve explode for flattening a list. For example if you have a column idList that is a list of Strings, you could do:

df.withColumn("flattenedId", functions.explode(col("idList")))
  .drop("idList")

That will result in a new DataFrame with a column named flattenedId (no longer a list).
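
For example, a minimal sketch with hypothetical data (assuming a spark-shell session with spark.implicits._ in scope):

import org.apache.spark.sql.functions.{col, explode}

// One row whose idList holds two elements
val df = Seq((1, Seq("a", "b"))).toDF("id", "idList")

df.withColumn("flattenedId", explode(col("idList")))
  .drop("idList")
  .show()
// +---+-----------+
// | id|flattenedId|
// +---+-----------+
// |  1|          a|
// |  1|          b|
// +---+-----------+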

Solution 11 - Scala

This is a modification of the accepted solution, rewritten to be tail-recursive (note the @tailrec annotation):


  import scala.annotation.tailrec
  import org.apache.spark.sql.{Column, DataFrame}
  import org.apache.spark.sql.functions.col
  import org.apache.spark.sql.types.{StructField, StructType}

  @tailrec
  def flattenSchema(
      splitter: String,
      fields: List[(StructField, String)],
      acc: Seq[Column]): Seq[Column] = {
    fields match {
      case (field, prefix) :: tail if field.dataType.isInstanceOf[StructType] =>
        val newPrefix = s"$prefix${field.name}."
        val newFields = field.dataType.asInstanceOf[StructType].fields.map((_, newPrefix)).toList
        flattenSchema(splitter, tail ++ newFields, acc)

      case (field, prefix) :: tail =>
        val colName = s"$prefix${field.name}"
        val newCol  = col(colName).as(colName.replace(".", splitter))
        flattenSchema(splitter, tail, acc :+ newCol)

      case _ => acc
    }
  }
  def flattenDataFrame(df: DataFrame): DataFrame = {
    val fields = df.schema.fields.map((_, ""))
    df.select(flattenSchema("__", fields.toList, Seq.empty): _*)
  }

Solution 12 - Scala

This is based on @Evan V's solution, to deal with more heavily nested JSON files. For me, the problem with the original solution was that when an ArrayType is nested right inside another ArrayType, I got an error.

For example, if the JSON looks like:

{"e":[{"f":[{"g":"h"}]}]}

I will get an error:

"cannot resolve '`e`.`f`['g']' due to data type mismatch: argument 2 requires integral type

To solve this I modified the code a bit. I agree this looks clumsy, but I'm just posting it here in case someone comes up with a nicer solution.

from pyspark.sql import functions as F
from pyspark.sql import types as T

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, T.StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)

    return fields


def explodeDF(df):
    for (name, dtype) in df.dtypes:
        if "array" in dtype:
            df = df.withColumn(name, F.explode(name))
    
    return df

def df_is_flat(df):
    for (_, dtype) in df.dtypes:
        if ("array" in dtype) or ("struct" in dtype):
            return False
    
    return True

def flatJson(jdf):
    keepGoing = True
    while(keepGoing):
        fields = flatten(jdf.schema)
        new_fields = [item.replace(".", "_") for item in fields]
        jdf = jdf.select(fields).toDF(*new_fields)
        jdf = explodeDF(jdf)
        if df_is_flat(jdf):
            keepGoing = False
    
    return jdf

Usage:

df = spark.read.json(path_to_json)
flat_df = flatJson(df)

flat_df.show()
+---+---+-----+
|  a|e_c|e_f_g|
+---+---+-----+
|  b|  d|    h|
+---+---+-----+

Solution 13 - Scala

import org.apache.spark.sql.types.StructType
import scala.collection.mutable.ListBuffer

val columns = new ListBuffer[String]()

// Collects the fully qualified names of all leaf columns into `columns`
def flattenSchema(schema: StructType, prefix: String = null): Unit = {
  for (f <- schema.fields) {
    f.dataType match {
      case st: StructType =>
        // Carry the accumulated prefix down so deeper levels keep the full path
        val columnPrefix = if (prefix == null) f.name + "." else prefix + f.name + "."
        flattenSchema(st, columnPrefix)
      case _ =>
        if (prefix == null) columns += f.name
        else columns += (prefix + f.name)
    }
  }
}

Solution 14 - Scala

Combining Evan V's, Averell's, and Steco's ideas. I am also providing complete SQL syntax while handling query fields that contain special characters, using '`' in PySpark.

The solution below gives the following:

  1. Handles nested JSON schemas.
  2. Handles identical column names across nested columns (we alias with the entire hierarchy, separated by underscores).
  3. Handles special characters (we replace them with '_' and collapse consecutive underscores with a 're.sub' replacement).
  4. Gives us SQL syntax.
  5. Query fields are enclosed within '`'.

The code snippet is below:

df=spark.read.json('<JSON FOLDER / FILE PATH>')
df.printSchema()
import re
from pyspark.sql.types import StructType, ArrayType

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            alias_name = re.sub(r'_{2,}', '_',
                                name.replace('.', '_').replace(' ', '_').replace('(', '')
                                    .replace(')', '').replace('-', '_').replace('&', '_'))
            name=name.replace('.','`.`')
            field_name = "`" + name + "`" + " AS " + alias_name
            fields.append(field_name)
    return fields

df.createOrReplaceTempView("to_flatten_df")
query_fields=flatten(df.schema)

spark.sql("SELECT " + ", ".join(query_fields) + " FROM to_flatten_df").show()

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type         Original Author      Original Content on Stackoverflow
Question             echen                View Question on Stackoverflow
Solution 1 - Scala   David Griffin        View Answer on Stackoverflow
Solution 2 - Scala   Evan V               View Answer on Stackoverflow
Solution 3 - Scala   V. Samma             View Answer on Stackoverflow
Solution 4 - Scala   Averell              View Answer on Stackoverflow
Solution 5 - Scala   Thomas Decaux        View Answer on Stackoverflow
Solution 6 - Scala   Powers               View Answer on Stackoverflow
Solution 7 - Scala   steco                View Answer on Stackoverflow
Solution 8 - Scala   swdev                View Answer on Stackoverflow
Solution 9 - Scala   Babatunde Adekunle   View Answer on Stackoverflow
Solution 10 - Scala  Kei-ven              View Answer on Stackoverflow
Solution 11 - Scala  fhuertas             View Answer on Stackoverflow
Solution 12 - Scala  WEIHANG LIU          View Answer on Stackoverflow
Solution 13 - Scala  Ishan Kumar          View Answer on Stackoverflow
Solution 14 - Scala  viswanathanr         View Answer on Stackoverflow