Automatically and Elegantly flatten DataFrame in Spark SQL
Tags: Scala, Apache Spark, Apache Spark-SQL

Problem Overview
All,
Is there an elegant and accepted way to flatten a Spark SQL table (Parquet) with columns that are of nested StructType?
For example
If my schema is:
foo
 |_bar
 |_baz
x
y
z
How do I select it into a flattened tabular form without resorting to manually running
df.select("foo.bar","foo.baz","x","y","z")
In other words, how do I obtain the result of the above code programmatically, given just a StructType and a DataFrame?
Scala Solutions
Solution 1 - Scala
The short answer is, there's no "accepted" way to do this, but you can do it very elegantly with a recursive function that generates your select(...) statement by walking through the DataFrame.schema.
The recursive function should return an Array[Column]. Every time the function hits a StructType, it calls itself and appends the returned Array[Column] to its own Array[Column].
Something like:
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions.col
def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)
    f.dataType match {
      case st: StructType => flattenSchema(st, colName)
      case _ => Array(col(colName))
    }
  })
}
You would then use it like this:
df.select(flattenSchema(df.schema):_*)
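For the schema in the question, the generated statement is equivalent to writing the select out by hand, i.e.:

df.select(col("foo.bar"), col("foo.baz"), col("x"), col("y"), col("z"))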
Solution 2 - Scala
Just wanted to share my solution for PySpark - it's more or less a translation of @David Griffin's solution, so it supports any level of nested objects.
from pyspark.sql.types import StructType, ArrayType
def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)
    return fields
df.select(flatten(df.schema)).show()
Solution 3 - Scala
I am improving my previous answer and offering a solution to my own problem stated in the comments of the accepted answer.
The accepted solution creates an array of Column objects and uses it to select these columns. In Spark, if you have a nested DataFrame, you can select a child column like this: df.select("Parent.Child"), and this returns a DataFrame with the values of the child column, named just Child. But if attributes of different parent structures have identical names, you lose the information about the parent, may end up with identical column names, and can no longer access them by name, as they are ambiguous.
This was my problem.
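For example, a sketch of the ambiguity with two parent structs foo and bar that both have a child field a:

// Both result columns are named "a"; referring to "a" afterwards is ambiguous
val flatDf = df.select(col("foo.a"), col("bar.a"))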
I found a solution to my problem; maybe it can help someone else as well. I called flattenSchema separately:
val flattenedSchema = flattenSchema(df.schema)
and this returned an Array of Column objects. Instead of using this in the select(), which would return a DataFrame with columns named by the child of the last level, I mapped the original column names to themselves as strings; then, after selecting a Parent.Child column, it is renamed as Parent.Child instead of just Child (I also replaced dots with underscores for my convenience):
val renamedCols = flattenedSchema.map(name => col(name.toString()).as(name.toString().replace(".","_")))
And then you can use the select function as shown in the original answer:
var newDf = df.select(renamedCols:_*)
Solution 4 - Scala
========== edit ====
There's some additional handling for more complex schemas here: https://medium.com/@lvhuyen/working-with-spark-dataframe-having-a-complex-schema-a3bce8c3f44
==================
This is PySpark, adding to @Evan V's answer, for when your field names contain special characters such as a dot '.', a hyphen '-', etc.:
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, ArrayType

def normalise_field(raw):
    return raw.strip().lower() \
        .replace('`', '') \
        .replace('-', '_') \
        .replace(' ', '_') \
        .strip('_')

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = "%s.`%s`" % (prefix, field.name) if prefix else "`%s`" % field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(col(name).alias(normalise_field(name)))
    return fields
df.select(flatten(df.schema)).show()
Solution 5 - Scala
You could also use SQL to select the columns in flattened form:
- Get the original DataFrame schema
- Generate a SQL string by walking the schema
- Query the original DataFrame
I did an implementation in Java: https://gist.github.com/ebuildy/3de0e2855498e5358e4eed1a4f72ea48
(it uses a recursive method as well; I prefer the SQL way, since you can test it easily via spark-shell). A minimal Scala sketch of the same approach is shown below.
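This sketch walks the schema to build aliased SELECT expressions, registers a temp view, and queries it. The names flattenExprs, flattenViaSql, and nested_df are illustrative, not from the gist:

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType

// Build "parent.child AS parent_child" expressions by walking the schema
def flattenExprs(schema: StructType, prefix: String = null): Seq[String] =
  schema.fields.toSeq.flatMap { f =>
    val colName = if (prefix == null) f.name else s"$prefix.${f.name}"
    f.dataType match {
      case st: StructType => flattenExprs(st, colName)
      case _              => Seq(s"$colName AS ${colName.replace(".", "_")}")
    }
  }

def flattenViaSql(df: DataFrame, spark: SparkSession): DataFrame = {
  df.createOrReplaceTempView("nested_df")
  spark.sql(s"SELECT ${flattenExprs(df.schema).mkString(", ")} FROM nested_df")
}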
Solution 6 - Scala
I added a DataFrame#flattenSchema method to the open source spark-daria project.
Here's how you can use the function with your code.
import com.github.mrpowers.spark.daria.sql.DataFrameExt._
df.flattenSchema().show()
+-------+-------+---------+----+---+
|foo.bar|foo.baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+
You can also specify different column name delimiters with the flattenSchema() method.
df.flattenSchema(delimiter = "_").show()
+-------+-------+---------+----+---+
|foo_bar|foo_baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+
This delimiter parameter is surprisingly important. If you're flattening your schema to load the table in Redshift, you won't be able to use periods as the delimiter.
Here's the full code snippet to generate this output.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val data = Seq(
  Row(Row("this", "is"), "something", "cool", ";)")
)

val schema = StructType(
  Seq(
    StructField(
      "foo",
      StructType(
        Seq(
          StructField("bar", StringType, true),
          StructField("baz", StringType, true)
        )
      ),
      true
    ),
    StructField("x", StringType, true),
    StructField("y", StringType, true),
    StructField("z", StringType, true)
  )
)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(data),
  schema
)
df.flattenSchema().show()
The underlying code is similar to David Griffin's code (in case you don't want to add the spark-daria dependency to your project).
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

object StructTypeHelpers {
  def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = {
    schema.fields.flatMap(structField => {
      val codeColName = if (prefix == null) structField.name else prefix + "." + structField.name
      val colName = if (prefix == null) structField.name else prefix + delimiter + structField.name
      structField.dataType match {
        case st: StructType => flattenSchema(schema = st, delimiter = delimiter, prefix = colName)
        case _ => Array(col(codeColName).alias(colName))
      }
    })
  }
}

object DataFrameExt {
  implicit class DataFrameMethods(df: DataFrame) {
    def flattenSchema(delimiter: String = ".", prefix: String = null): DataFrame = {
      df.select(
        StructTypeHelpers.flattenSchema(df.schema, delimiter, prefix): _*
      )
    }
  }
}
Solution 7 - Scala
Here is a function that does what you want and can deal with multiple nested columns containing columns with the same name, using a prefix (note it flattens one level of struct per call, so apply it repeatedly for more deeply nested schemas):
from pyspark.sql import functions as F

def flatten_df(nested_df):
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
    flat_df = nested_df.select(flat_cols +
                               [F.col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols
                                for c in nested_df.select(nc + '.*').columns])
    return flat_df
Before:
root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- bar: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
After:
root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo_a: float (nullable = true)
 |-- foo_b: float (nullable = true)
 |-- foo_c: integer (nullable = true)
 |-- bar_a: float (nullable = true)
 |-- bar_b: float (nullable = true)
 |-- bar_c: integer (nullable = true)
Solution 8 - Scala
To combine the answers from David Griffin and V. Samma, you could just do this to flatten while avoiding duplicate column names:
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)
    f.dataType match {
      case st: StructType => flattenSchema(st, colName)
      case _ => Array(col(colName).as(colName.replace(".", "_")))
    }
  })
}

def flattenDataFrame(df: DataFrame): DataFrame = {
  df.select(flattenSchema(df.schema): _*)
}
var my_flattened_json_table = flattenDataFrame(my_json_table)
Solution 9 - Scala
A little addition to the code above, for when you are working with nested structs and arrays:
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)
    f match {
      case StructField(_, struct: StructType, _, _) => flattenSchema(struct, colName)
      case StructField(_, ArrayType(x: StructType, _), _, _) => flattenSchema(x, colName)
      case StructField(_, ArrayType(_, _), _, _) => Array(col(colName))
      case _ => Array(col(colName))
    }
  })
}
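Usage is the same as in Solution 1; note that for an array of structs, selecting arrayCol.field yields an array of that field's values (it does not explode rows):

df.select(flattenSchema(df.schema): _*)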
Solution 10 - Scala
I have been using a one-liner that results in a flattened schema with the 5 columns bar, baz, x, y, z:
df.select("foo.*", "x", "y", "z")
As for explode: I typically reserve explode for flattening a list. For example, if you have a column idList that is a list of Strings, you could do:
df.withColumn("flattenedId", functions.explode(col("idList")))
  .drop("idList")
That will result in a new DataFrame with a column named flattenedId (no longer a list).
Solution 11 - Scala
This is a modification of the accepted solution, but it uses tail recursion (@tailrec). Because nested fields are appended to the end of the work list, the resulting column order may differ from the original schema order.
import scala.annotation.tailrec
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StructField, StructType}

@tailrec
def flattenSchema(
    splitter: String,
    fields: List[(StructField, String)],
    acc: Seq[Column]): Seq[Column] = {
  fields match {
    case (field, prefix) :: tail if field.dataType.isInstanceOf[StructType] =>
      val newPrefix = s"$prefix${field.name}."
      val newFields = field.dataType.asInstanceOf[StructType].fields.map((_, newPrefix)).toList
      flattenSchema(splitter, tail ++ newFields, acc)
    case (field, prefix) :: tail =>
      val colName = s"$prefix${field.name}"
      val newCol = col(colName).as(colName.replace(".", splitter))
      flattenSchema(splitter, tail, acc :+ newCol)
    case _ => acc
  }
}

def flattenDataFrame(df: DataFrame): DataFrame = {
  val fields = df.schema.fields.map((_, ""))
  df.select(flattenSchema("__", fields.toList, Seq.empty): _*)
}
Solution 12 - Scala
This is based on @Evan V's solution, to deal with more heavily nested JSON files. For me, the problem with the original solution is that when an ArrayType is nested right inside another ArrayType, I got an error.

For example, if the JSON looks like:
{"e":[{"f":[{"g":"h"}]}]}
I will get an error:
"cannot resolve '`e`.`f`['g']' due to data type mismatch: argument 2 requires integral type
To solve this I modified the code a bit; I agree this looks super stupid, but I'm just posting it here so that someone may come up with a nicer solution.
from pyspark.sql import functions as F
from pyspark.sql import types as T

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, T.StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)
    return fields

def explodeDF(df):
    # explode every array column, one pass at a time
    for (name, dtype) in df.dtypes:
        if "array" in dtype:
            df = df.withColumn(name, F.explode(name))
    return df

def df_is_flat(df):
    for (_, dtype) in df.dtypes:
        if ("array" in dtype) or ("struct" in dtype):
            return False
    return True

def flatJson(jdf):
    keepGoing = True
    while keepGoing:
        fields = flatten(jdf.schema)
        new_fields = [item.replace(".", "_") for item in fields]
        jdf = jdf.select(fields).toDF(*new_fields)
        jdf = explodeDF(jdf)
        if df_is_flat(jdf):
            keepGoing = False
    return jdf
Usage:
df = spark.read.json(path_to_json)
flat_df = flatJson(df)
flat_df.show()
+---+---+-----+
|  a|e_c|e_f_g|
+---+---+-----+
|  b|  d|    h|
+---+---+-----+
Solution 13 - Scala
import org.apache.spark.sql.types.StructType
import scala.collection.mutable.ListBuffer

val columns = new ListBuffer[String]()

def flattenSchema(schema: StructType, prefix: String = null): Unit = {
  for (i <- schema.fields) {
    if (i.dataType.isInstanceOf[StructType]) {
      // carry the existing prefix along so deeper levels keep the full path
      val columnPrefix = if (prefix == null) i.name + "." else prefix + i.name + "."
      flattenSchema(i.dataType.asInstanceOf[StructType], columnPrefix)
    } else {
      if (prefix == null)
        columns += i.name
      else
        columns += (prefix + i.name)
    }
  }
}
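This collects the flattened column names into the columns buffer as a side effect; a usage sketch (flatDf is an illustrative name, and col comes from org.apache.spark.sql.functions):

flattenSchema(df.schema)
val flatDf = df.select(columns.map(name => col(name)): _*)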
Solution 14 - Scala
Combining Evan V's, Avrell's, and Steco's ideas. I am also providing complete SQL syntax while handling query fields with special characters using '`' in PySpark.
The solution below gives the following:
- Handles nested JSON schemas.
- Handles the same column names across nested columns (each column is aliased with its entire hierarchy, separated by underscores).
- Handles special characters. (They are replaced with '_'; consecutive occurrences of '_' are then collapsed with a re.sub replacement.)
- Gives us SQL syntax.
- Query fields are enclosed within '`'.
The code snippet is below:
df = spark.read.json('<JSON FOLDER / FILE PATH>')
df.printSchema()

from pyspark.sql.types import StructType, ArrayType
import re

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            alias_name = (name.replace('.', '_').replace(' ', '_')
                              .replace('(', '').replace(')', '')
                              .replace('-', '_').replace('&', '_'))
            # str.replace does not understand regexes, so collapse any
            # consecutive underscores left by the replacements with re.sub
            alias_name = re.sub(r'_{2,}', '_', alias_name)
            name = name.replace('.', '`.`')
            field_name = "`" + name + "`" + " AS " + alias_name
            fields.append(field_name)
    return fields
df.createOrReplaceTempView("to_flatten_df")
query_fields = flatten(df.schema)

def listToString(s):
    # initialize an empty string
    str1 = ""
    # traverse the list
    for ele in s:
        str1 = str1 + ele + ','
    # return string
    return str1

spark.sql("SELECT " + listToString(query_fields)[:-1] + " FROM to_flatten_df").show()