filter spark dataframe with row field that is an array of strings

Using Spark 1.5 and Scala 2.10.6

I'm trying to filter a dataframe via a field "tags" that is an array of strings. Looking for all rows that have the tag 'private'.

val report = df.select("*")
  .where(df("tags").contains("private"))

I'm getting:

Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'Contains(tags, private)' due to data type mismatch: argument 1 requires string type, however, 'tags' is of array type.;

Is the filter method better suited?

UPDATED:

The data is coming from the Cassandra adapter, but here's a minimal example that shows what I'm trying to do and reproduces the error above:

  def testData (sc: SparkContext): DataFrame = {
    val stringRDD = sc.parallelize(Seq("""
      { "name": "ed",
        "tags": ["red", "private"]
      }""",
      """{ "name": "fred",
        "tags": ["public", "blue"]
      }""")
    )
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._
    sqlContext.read.json(stringRDD)
  }
  def run(sc: SparkContext) {
    val df1 = testData(sc)
    df1.show()
    val report = df1.select("*")
      .where(df1("tags").contains("private"))
    report.show()
  }

UPDATED: the tags array can be any length and the 'private' tag can be in any position

UPDATED: one solution that works: UDF

import scala.collection.mutable
import org.apache.spark.sql.functions.udf

val filterPriv = udf { (tags: mutable.WrappedArray[String]) => tags.contains("private") }
val report = df1.filter(filterPriv(df1("tags")))
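For context on why the UDF works: Spark hands an array&lt;string&gt; column value to a Scala UDF as a mutable.WrappedArray[String], which is a Seq, so the predicate is just ordinary element membership. A minimal sketch of that predicate on its own (no Spark required):

```scala
// Spark passes an array<string> column value to a Scala UDF as a
// mutable.WrappedArray[String], which is a Seq, so plain Seq#contains
// (element membership) is all the predicate needs.
def hasPrivate(tags: Seq[String]): Boolean = tags.contains("private")

println(hasPrivate(Seq("red", "private")))  // true
println(hasPrivate(Seq("public", "blue"))) // false
```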
navicore asked Jan 17 '16


1 Answer

I think if you use where(array_contains(...)) it will work. Here's my result:

scala> import org.apache.spark.SparkContext
import org.apache.spark.SparkContext

scala> import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.DataFrame

scala> def testData (sc: SparkContext): DataFrame = {
     |     val stringRDD = sc.parallelize(Seq
     |      ("""{ "name": "ned", "tags": ["blue", "big", "private"] }""",
     |       """{ "name": "albert", "tags": ["private", "lumpy"] }""",
     |       """{ "name": "zed", "tags": ["big", "private", "square"] }""",
     |       """{ "name": "jed", "tags": ["green", "small", "round"] }""",
     |       """{ "name": "ed", "tags": ["red", "private"] }""",
     |       """{ "name": "fred", "tags": ["public", "blue"] }"""))
     |     val sqlContext = new org.apache.spark.sql.SQLContext(sc)
     |     import sqlContext.implicits._
     |     sqlContext.read.json(stringRDD)
     |   }
testData: (sc: org.apache.spark.SparkContext)org.apache.spark.sql.DataFrame

scala>   
     | val df = testData (sc)
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> val report = df.select ("*").where (array_contains (df("tags"), "private"))
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> report.show
+------+--------------------+
|  name|                tags|
+------+--------------------+
|   ned|[blue, big, private]|
|albert|    [private, lumpy]|
|   zed|[big, private, sq...|
|    ed|      [red, private]|
+------+--------------------+

Note that it works if you write where(array_contains(df("tags"), "private")), but if you write where(df("tags").array_contains("private")) (more directly analogous to what you wrote originally) it fails with array_contains is not a member of org.apache.spark.sql.Column. (In a standalone application you'll also need import org.apache.spark.sql.functions.array_contains.) Looking at the source code for Column, I see it handles contains (constructing a Contains instance) but not array_contains. Maybe that's an oversight.
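The type mismatch in the original error comes down to which `contains` is meant: Column.contains compiles to a string Contains expression (substring match), while array_contains is element membership. The distinction in plain Scala (a sketch, no Spark required):

```scala
// String#contains = substring match; Seq#contains = element membership.
val asString = "red,private"         // a string column could use Contains
val asArray  = Seq("red", "private") // what an array<string> column holds

println(asString.contains("private")) // true: substring match
println(asArray.contains("private"))  // true: element membership
println(asString.contains("priv"))    // true: substring still matches
println(asArray.contains("priv"))     // false: no element equals "priv"
```

This is why Spark refuses to apply Contains to an array column: the substring semantics simply don't apply, and array_contains is the membership counterpart.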

Robert Dodier answered Sep 29 '22