Filtering pyspark dataframe if text column includes words in specified list

I've seen questions posted here that are similar to mine, but I still get errors when I try some of the accepted answers. I have a dataframe with three columns: created_at, text, and words (which is just a tokenized version of text).

Now, I have a list of companies, ['Starbucks', 'Nvidia', 'IBM', 'Dell'], and I only want to keep the rows where the text includes any of those words.

I've tried a few things, but with no success:

small_DF.filter(lambda x: any(word in x.text for word in test_list))

Returns: TypeError: condition should be string or Column

I tried creating a function and using foreach():

def filters(line):
   return(any(word in line for word in test_list))
df = df.foreach(filters)

That turns df into None (a NoneType).

And the last one I tried:

df = df.filter(col("text").isin(test_list))

This returns an empty dataframe, which is nice as I get no error, but obviously not what I want.

2 Answers

Your .filter returns an error because it is the SQL filter function on DataFrames, which expects a BooleanType() column or a SQL expression string, not the filter function on RDDs, which takes a Python function. If you want to use the RDD one, just add .rdd:

small_DF.rdd.filter(lambda x: any(word in x.text for word in test_list))
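As a minimal sketch of that RDD route: the helper names mentions_any and filter_via_rdd are hypothetical, and the case-insensitive substring match is an assumption about the behavior you want. The predicate is plain Python, and the filtered RDD is turned back into a DataFrame with the original schema:

```python
def mentions_any(text, words):
    """Return True if any word occurs in text as a case-insensitive substring."""
    if text is None:  # null text values never match
        return False
    lowered = text.lower()
    return any(word.lower() in lowered for word in words)


def filter_via_rdd(spark, df, words):
    """Hypothetical helper: filter on the RDD, then rebuild a DataFrame.

    Assumes `spark` is an active SparkSession and `df` has a string
    column named `text`.
    """
    kept = df.rdd.filter(lambda row: mentions_any(row.text, words))
    return spark.createDataFrame(kept, schema=df.schema)
```

With the names from the question this would be called as filter_via_rdd(spark, small_DF, test_list). Keep in mind that the lambda runs in Python for every row, so this route is usually slower than a Column expression.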

You don't have to use a UDF; you can use regular expressions in PySpark with .rlike on your "text" column:

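from pyspark.sql import HiveContext
import pyspark.sql.functions as psf

hc = HiveContext(sc)  # sc is the existing SparkContext

words = [x.lower() for x in ['starbucks', 'Nvidia', 'IBM', 'Dell']]
data = [['i love Starbucks'],['dell laptops rocks'],['help me I am stuck!']]
df = hc.createDataFrame(data).toDF('text')

# keep rows whose lowercased text matches any of the words
df.filter(psf.lower(df.text).rlike('|'.join(words))).show()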
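A slightly more defensive sketch of the .rlike idea, under a few assumptions: build_pattern and filter_by_regex are hypothetical helper names, whole-word matching is assumed to be what you want (so "Dell" would not match "Dellwood"), and the pattern only uses syntax that is valid in both Python's re module and the Java regex engine that rlike relies on:

```python
import re


def build_pattern(words):
    """Build a case-insensitive, whole-word regex that matches any of the words."""
    # re.escape guards against regex metacharacters in the names; for plain
    # alphanumeric names like these it leaves the text unchanged.
    alternatives = "|".join(re.escape(word) for word in words)
    return r"(?i)\b(?:" + alternatives + r")\b"


def filter_by_regex(df, words, column="text"):
    """Hypothetical helper: keep rows whose column matches any of the words.

    Assumes `df` is a PySpark DataFrame with a string column `column`.
    """
    return df.filter(df[column].rlike(build_pattern(words)))
```

Because the (?i) flag makes the match case-insensitive, there is no need to lowercase either the word list or the column first. The filter stays a Column expression, so it runs inside Spark rather than in a Python lambda.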

I think filter isn't working because it expects a boolean Column rather than a lambda function, and isin only compares the whole column value against the entries in the list, so a full sentence never equals a single company name. Here is something I tried that may give you some direction:

# prepare some test data
words = [x.lower() for x in ['starbucks', 'Nvidia', 'IBM', 'Dell']]
data = [['i love Starbucks'],['dell laptops rocks'],['help me I am stuck!']]
df = spark.createDataFrame(data).toDF('text')

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

def intersect(row):
    # lowercase each word in the text, then check for overlap with the word list
    row = [x.lower() for x in row.split()]
    return bool(set(row).intersection(words))

filterUDF = udf(intersect, BooleanType())

# keep only the rows where the UDF returns True
df.where(filterUDF(df.text)).show()

Output:

+------------------+
|              text|
+------------------+
|  i love Starbucks|
|dell laptops rocks|
+------------------+
