I'm using the TensorFlow Dataset API to prepare my data for input into my network. During this process, I have some custom Python functions which are mapped to the dataset using tf.py_function. I want to be able to debug the data going into these functions and what happens to that data inside them. When a py_function is called, it calls back into the main Python process (according to this answer). Since this function is in Python, and in the main process, I would expect a regular IDE breakpoint to be able to stop in it. However, this doesn't seem to be the case (example below, where the breakpoint does not halt execution). Is there a way to drop into a breakpoint within a py_function used by the Dataset map?
Example where the breakpoint does not halt execution
import tensorflow as tf
def add_ten(example, label):
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label
examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels = [ 0, 0, 1, 1, 1, 1, 0, 0]
examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
The TensorFlow 2.0 implementation of tf.data.Dataset opens a C thread for each call without notifying your debugger, so breakpoints set in the IDE are not hit inside the mapped function.
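The effect can be reproduced without TensorFlow. The following is a minimal sketch, assuming the debugger relies on Python's per-thread trace hooks (sys.settrace and threading.settrace). A thread started through the low-level _thread module stands in for a thread created outside threading.Thread, and the names tracer, work, plain_worker, opt_in_worker and demo are made up for this illustration.

```python
import _thread
import sys
import threading

seen = []  # labels of the threads in which the tracer observed a call to work()


def tracer(frame, event, arg):
    # Stand-in for a debugger's trace function: record calls to work() only.
    if event == "call" and frame.f_code.co_name == "work":
        seen.append(frame.f_locals["label"])
    return None  # no per-line tracing needed for this demonstration


def work(label):
    return label


def plain_worker(label, done):
    # Does nothing special; it is traced only if the thread was set up for tracing.
    work(label)
    done.set()


def opt_in_worker(label, done):
    # Installs the trace function from inside the thread, much like calling
    # a debugger's settrace at the top of the mapped function.
    sys.settrace(tracer)
    work(label)
    sys.settrace(None)
    done.set()


def run_raw(target, label):
    # Start a thread that bypasses threading.Thread and therefore its trace hook.
    done = threading.Event()
    _thread.start_new_thread(target, (label, done))
    assert done.wait(timeout=5)


def demo():
    seen.clear()
    threading.settrace(tracer)  # applies only to threads started via threading.Thread
    try:
        done = threading.Event()
        t = threading.Thread(target=plain_worker, args=("threading.Thread", done))
        t.start()
        t.join()
        run_raw(plain_worker, "raw thread")
        run_raw(opt_in_worker, "raw thread + settrace")
    finally:
        threading.settrace(None)
    return list(seen)


print(demo())
```

The call made in the plain raw thread is not recorded, while the raw thread that installs the trace function itself is. That is the same idea as calling the debugger's settrace inside the function that tf.data runs.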
Use pydevd's settrace() to manually set a tracing function that connects to your default debugger server and starts feeding it the debug data.
import pydevd
pydevd.settrace()
Example with your code:
import tensorflow as tf
import pydevd
def add_ten(example, label):
    pydevd.settrace(suspend=False)
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label
examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels = [ 0, 0, 1, 1, 1, 1, 0, 0]
examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
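In the example, settrace is called for every element that passes through add_ten. If that turns out to be noisy, one option is to connect only the first time each thread enters the function. This is a rough sketch, not part of pydevd: trace_once is a hypothetical helper, and the connect callable is whatever you would otherwise call directly, for instance lambda: pydevd.settrace(suspend=False).

```python
import threading

_state = threading.local()  # one "connected" flag per thread


def trace_once(connect):
    """Call connect() the first time the current thread asks, then do nothing.

    Returns True if connect() was called on this invocation, False otherwise.
    """
    if getattr(_state, "connected", False):
        return False
    connect()
    _state.connected = True
    return True
```

Inside add_ten, the first line would then become trace_once(lambda: pydevd.settrace(suspend=False)). The flag is kept per thread because the mapped function may run on more than one thread.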
Note: If you are using an IDE which already bundles pydevd (such as PyDev or PyCharm), you do not have to install pydevd separately; it will be picked up during the debug session.