
Why does Apache PySpark top() fail when the RDD contains a user-defined class?

I'm prototyping some code using Apache Spark's PySpark on my local machine, via an IPython Notebook. I've written some code that seems to work fine, but when I make a simple change to it, it breaks.

The first code block below works. The second block fails with the given error. I'd really appreciate any help. I suspect the error has something to do with serializing Python objects. The error says it can't pickle TestClass, but I can't find information on how to make my class picklable. The documentation says "Generally you can pickle any object if you can pickle every attribute of that object. Classes, functions, and methods cannot be pickled -- if you pickle an object, the object's class is not pickled, just a string that identifies what class it belongs to. This works fine for most pickles (but note the discussion about long-term storage of pickles).". I don't understand this, as I've tried replacing my TestClass with a datetime class and things seem to work just fine.
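To make the failure reproducible outside Spark, here's a minimal standalone sketch with plain cPickle (the class name NotInMain is purely illustrative): pickling an instance whose class can't be looked up as an attribute of its module raises what looks like the same PicklingError.

import cPickle

# pickle stores an instance by *reference* to its class (e.g. '__main__.NotInMain'),
# so the class must be findable under that name at pickling time.
Orphan = type('NotInMain', (object,), {})   # deliberately bound to a different name

try:
    cPickle.dumps(Orphan(), 2)              # the same call the Spark worker makes
except cPickle.PicklingError as e:
    print e  # something like: Can't pickle <class '__main__.NotInMain'>: attribute lookup __main__.NotInMain failed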

Anyway, the code:

# ----------- This code works -----------------------------
class TestClass(object):
    def __init__(self):
        self.teststr = 'Hello'
    def __str__(self):
        return self.teststr
    def __repr__(self):
        return self.teststr
    def test(self):
        return 'test: {0}'.format(self.teststr)

#load multiple text files into list of RDDs, concatenate them, then remove headers
trip_rdds = [sc.textFile(f) for f in trip_files]
trip_rdd  = trip_rdds[0]
for rdd in trip_rdds[1:]:
    trip_rdd = trip_rdd.union(rdd)

#filter out header rows from result
trip_rdd = trip_rdd.filter(lambda r: r != header)

#split the line, then convert each element to a dictionary
trip_rdd = trip_rdd.map(lambda r: r.split(','))
trip_rdd = trip_rdd.map(lambda r, k = header_keys: dict(zip(k, r)))
trip_rdd = trip_rdd.map(convert_trip_dict)
#trip_rdd = trip_rdd.map(lambda d, ps = g_nyproj_str: Trip(d, ps))

#originally I map the given dictionaries to a 'Trip' class I defined with various bells and whistles. 
#I've simplified to using TestClass above and still seem to get the same error

trip_rdd = trip_rdd.map(lambda t: TestClass())
trip_rdd = trip_rdd.map(lambda t: t.test()) #(1) Watch this row

print trip_rdd.count()
temp = trip_rdd.top(3)
print temp
print '...done'

The above code returns the following:

347098

['test: Hello', 'test: Hello', 'test: Hello']

...done

But when I delete the row marked "(1) Watch this row" - the last map line - and re-run, I get the following error instead. It's long, so I'll wrap up my question here before posting the output. Again, I'd really appreciate help with this.

Thanks in advance!

# ----------- This code FAILS -----------------------------
class TestClass(object):
    def __init__(self):
        self.teststr = 'Hello'
    def __str__(self):
        return self.teststr
    def __repr__(self):
        return self.teststr
    def test(self):
        return 'test: {0}'.format(self.teststr)

#load multiple text files into list of RDDs, concatenate them, then remove headers
trip_rdds = [sc.textFile(f) for f in trip_files]
trip_rdd  = trip_rdds[0]
for rdd in trip_rdds[1:]:
    trip_rdd = trip_rdd.union(rdd)

#filter out header rows from result
trip_rdd = trip_rdd.filter(lambda r: r != header)

#split the line, then convert each element to a dictionary
trip_rdd = trip_rdd.map(lambda r: r.split(','))
trip_rdd = trip_rdd.map(lambda r, k = header_keys: dict(zip(k, r)))
trip_rdd = trip_rdd.map(convert_trip_dict)
#trip_rdd = trip_rdd.map(lambda d, ps = g_nyproj_str: Trip(d, ps))

#originally I map the given dictionaries to a 'Trip' class I defined with various bells and whistles. 
#I've simplified to using TestClass above and still seem to get the same error

trip_rdd = trip_rdd.map(lambda t: TestClass())
#trip_rdd = trip_rdd.map(lambda t: t.test()) #(1) this row deleted - the code now fails

print trip_rdd.count()
temp = trip_rdd.top(3)
print temp
print '...done'

Output: 347098

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-76-6550318a5d5b> in <module>()
     29 #count them
     30 print trip_rdd.count()
---> 31 temp = trip_rdd.top(3)
     32 print temp
     33 print '...done'

C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\rdd.pyc in top(self, num, key)
   1043             return heapq.nlargest(num, a + b, key=key)
   1044 
-> 1045         return self.mapPartitions(topIterator).reduce(merge)
   1046 
   1047     def takeOrdered(self, num, key=None):

C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\rdd.pyc in reduce(self, f)
    713             yield reduce(f, iterator, initial)
    714 
--> 715         vals = self.mapPartitions(func).collect()
    716         if vals:
    717             return reduce(f, vals)

C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\rdd.pyc in collect(self)
    674         """
    675         with SCCallSiteSync(self.context) as css:
--> 676             bytesInJava = self._jrdd.collect().iterator()
    677         return list(self._collect_iterator_through_file(bytesInJava))
    678 

C:\Programs\Coding\Languages\Python\Anaconda_32bit\Conda\lib\site-packages\py4j-0.8.2.1-py2.7.egg\py4j\java_gateway.pyc in __call__(self, *args)
    536         answer = self.gateway_client.send_command(command)
    537         return_value = get_return_value(answer, self.gateway_client,
--> 538                 self.target_id, self.name)
    539 
    540         for temp_arg in temp_args:

C:\Programs\Coding\Languages\Python\Anaconda_32bit\Conda\lib\site-packages\py4j-0.8.2.1-py2.7.egg\py4j\protocol.pyc in get_return_value(answer, gateway_client, target_id, name)
    298                 raise Py4JJavaError(
    299                     'An error occurred while calling {0}{1}{2}.\n'.
--> 300                     format(target_id, '.', name), value)
    301             else:
    302                 raise Py4JError(

Py4JJavaError: An error occurred while calling o463.collect.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 49.0 failed 1 times, most recent failure: Lost task 1.0 in stage 49.0 (TID 99, localhost): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\worker.py", line 107, in main
    process()
  File "C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\worker.py", line 98, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\serializers.py", line 231, in dump_stream
    bytes = self.serializer.dumps(vs)
  File "C:\Programs\Apache\Spark\spark-1.2.0-bin-hadoop2.4\python\pyspark\serializers.py", line 393, in dumps
    return cPickle.dumps(obj, 2)
PicklingError: Can't pickle <class '__main__.TestClass'>: attribute lookup __main__.TestClass failed

    at org.apache.spark.api.python.PythonRDD$$anon$1.read(PythonRDD.scala:137)
    at org.apache.spark.api.python.PythonRDD$$anon$1.<init>(PythonRDD.scala:174)
    at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:96)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:230)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)
    at org.apache.spark.scheduler.Task.run(Task.scala:56)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:196)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1214)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1203)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1202)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1202)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:696)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:696)
    at scala.Option.foreach(Option.scala:236)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:696)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessActor$$anonfun$receive$2.applyOrElse(DAGScheduler.scala:1420)
    at akka.actor.Actor$class.aroundReceive(Actor.scala:465)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessActor.aroundReceive(DAGScheduler.scala:1375)
    at akka.actor.ActorCell.receiveMessage(ActorCell.scala:516)
    at akka.actor.ActorCell.invoke(ActorCell.scala:487)
    at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:238)
    at akka.dispatch.Mailbox.run(Mailbox.scala:220)
    at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(AbstractDispatcher.scala:393)
    at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260)
    at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339)
    at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)
    at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)


1 Answer

It turns out you have to define your class in its own module rather than in the main body of the code (e.g. a notebook cell). If you do that and then import the module, pickle is able to pickle and unpickle the objects successfully, and the class works with Spark as you'd expect.
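For example, a minimal sketch of that layout (the file name testclass.py and the sc.addPyFile() call are assumptions for illustration, not part of the original code):

# testclass.py -- the class lives in its own importable module
class TestClass(object):
    def __init__(self):
        self.teststr = 'Hello'
    def __str__(self):
        return self.teststr
    def __repr__(self):
        return self.teststr
    def test(self):
        return 'test: {0}'.format(self.teststr)

# in the notebook / driver script
from testclass import TestClass

sc.addPyFile('testclass.py')   # ship the module to the executors too

trip_rdd = trip_rdd.map(lambda t: TestClass())
print trip_rdd.count()
print trip_rdd.top(3)          # instances now pickle/unpickle without error

This works because pickle only records a reference such as 'testclass.TestClass' for each instance, so as long as the workers can import testclass, objects can be serialized on one side and rebuilt on the other.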
