Cannot pickle Tensorflow object in Python - TypeError: can't pickle _thread._local objects

I want to pickle the History object returned by Keras's model.fit in TensorFlow, but I am getting an error.

import gzip
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import keras


with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, test_set = pickle.load(f, encoding='latin1')

X_train = np.asarray(train_set[0])
y_train = np.asarray(train_set[1])

X_test = np.asarray(test_set[0])
y_test = np.asarray(test_set[1])

X_valid, X_train = X_train[:5000]/255.0, X_train[5000:]/255.0
y_valid, y_train = y_train[:5000], y_train[5000:]

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300, activation = 'relu'))
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(10, activation = 'softmax'))
model.summary()

model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=1,
                    validation_data =(X_valid, y_valid))

if not os.path.isdir('models'):
    os.mkdir('models')

model.save('models/basic.h5')
with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history, f)

It gives me the following error:

Traceback (most recent call last):
  File "main.py", line 69, in <module>
    pickle.dump(history, f)
TypeError: can't pickle _thread._local objects

PS: To get the code to run, download the fashion_mnist data: https://s3.amazonaws.com/img-datasets/mnist.pkl.g

asked Dec 13 '19 at 16:12 by Al-Baraa El-Hag



People also ask

Can’t pickle local objects in Python?

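Pickle is a Python-specific format, so other languages cannot read pickled data directly, and unpickling generally requires the same (or compatible) class and function definitions that existed when the data was written. A common failure is an AttributeError such as "Can't pickle local object". It occurs because pickle stores functions and classes by reference to their importable, qualified name, and anything defined inside another function (a nested function, a lambda, or a local class) has no such name. The usual fix is to move the definition to module level.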
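A minimal sketch of the "local object" failure, assuming CPython 3. The exception type and wording differ between Python versions, so the helper only reports which exception was raised. The names make_scaler and pickle_error are hypothetical. Moving the function to module level is the usual fix; when a closure only captures arguments, functools.partial over an importable function is one alternative, shown at the end:

```python
import functools
import operator
import pickle


def make_scaler(factor):
    """Return a closure; the inner function is a 'local object' for pickle."""
    def scale(x):
        return x * factor
    return scale


def pickle_error(obj):
    """Return the exception type name if pickling fails, else None."""
    try:
        pickle.dumps(obj)
    except (AttributeError, TypeError, pickle.PicklingError) as exc:
        return type(exc).__name__
    return None


local_fn = make_scaler(3)
print(pickle_error(local_fn))  # fails: 'make_scaler.<locals>.scale' cannot be looked up by name

# Alternative: build the callable from importable pieces instead of a closure.
portable_fn = functools.partial(operator.mul, 3)
restored = pickle.loads(pickle.dumps(portable_fn))
print(restored(5))  # 15
```

The partial object pickles because both functools.partial and operator.mul can be located by their module and name when the data is loaded again.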

What is pickling in Python?

Pickling is the process of converting a Python object into a byte stream so that it can be stored, for example in a file or a database. Unpickling reads that byte stream back to recreate an equivalent Python object.
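A minimal in-memory illustration of that round trip using pickle.dumps and pickle.loads. The record is a hypothetical example, not data from the question:

```python
import pickle

# Hypothetical record to serialize.
record = {"epochs": 1, "loss": [0.62], "labels": ("T-shirt/top", "Trouser")}

data = pickle.dumps(record)    # object -> byte stream
print(type(data).__name__)     # bytes
restored = pickle.loads(data)  # byte stream -> a new, equal object
print(restored == record, restored is record)  # True False
```

pickle.dump and pickle.load do the same thing with an open binary file instead of an in-memory bytes object.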

Is it possible to pickle a keras model with TensorFlow?

In one related GitHub issue thread, a commenter reported that a model built with standalone Keras (a get_model_prev() helper in their example) could be pickled when using an earlier TensorFlow version, while the same model built with tf.keras raised this error. Another participant was able to reproduce the issue.

What is the latest version of TensorFlow in Python?

A related bug report (TensorFlow 2.4.0, git version v2.4.0-0-g582c8d236cb, on Python 3.7.9) describes a similar error: running a simple training process with MultiWorkerMirroredStrategy fails with TypeError: can't pickle _thread.lock objects.


1 Answer

As Karl suggested, the History object itself cannot be pickled, but its history attribute (a plain dict mapping each metric name to a list of per-epoch values) can:

with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history.history, f)
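The answer does not say why the History object fails. A plausible explanation, offered as an assumption, is that History keeps a reference to the trained model (history.model), and somewhere inside the model sits a thread-local object, the _thread._local named in the traceback. The stdlib-only sketch below mimics that situation without TensorFlow: a SimpleNamespace is a hypothetical stand-in for History, and the metric values are made up. The whole object fails to pickle, while its history dict round-trips through a file:

```python
import os
import pickle
import tempfile
import threading
from types import SimpleNamespace

# Hypothetical stand-in for keras.callbacks.History: per-epoch metrics in
# .history plus a reference to other state (here a threading.local, the
# type named in the error). The metric values are made up, not real results.
history = SimpleNamespace(
    history={"loss": [0.62], "accuracy": [0.84],
             "val_loss": [0.31], "val_accuracy": [0.91]},
    model=threading.local(),
)

# Pickling the whole object fails because its state includes the thread-local.
try:
    pickle.dumps(history)
    whole_object_pickled = True
except TypeError as exc:
    whole_object_pickled = False
    print("pickling the object failed:", exc)

# Pickling only the metrics dict works, mirroring pickle.dump(history.history, f).
path = os.path.join(tempfile.mkdtemp(), "basic_history.pickle")
with open(path, "wb") as f:
    pickle.dump(history.history, f)

with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored["val_accuracy"][-1])  # 0.91
```

Because the saved dict contains only strings, lists, and floats, it can be reloaded later (for example, to plot learning curves) without TensorFlow installed.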
answered Sep 22 '22 at 12:09 by Al-Baraa El-Hag
