IDE breakpoint in TensorFlow Dataset API mapped py_function?

I'm using the TensorFlow Dataset API to prepare my data for input into my network. During this process, I have some custom Python functions which are mapped to the dataset using tf.py_function. I want to be able to debug the data going into these functions and what happens to that data inside them. When a py_function is called, it calls back into the main Python process (according to this answer). Since the function is written in Python and runs in the main process, I would expect a regular IDE breakpoint to be able to stop in it. However, this doesn't seem to be the case (in the example below, the breakpoint does not halt execution). Is there a way to drop into a breakpoint within a py_function used by the Dataset map?

Example where the breakpoint does not halt execution

import tensorflow as tf

def add_ten(example, label):
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
asked Dec 10 '19 at 20:12 by golmschenk

1 Answer

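The TensorFlow 2.0 implementation of tf.data.Dataset runs the mapped function on threads created by its C++ runtime, without notifying your debugger. Debuggers built on sys.settrace install their trace function per thread (threading.settrace only covers threads started through Python's threading module), so a thread that TensorFlow creates natively never receives it, and breakpoints inside the py_function are ignored. Use pydevd's settrace() to manually set a tracing function from inside that thread; it connects to your debugger server and starts feeding it the debug data.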

import pydevd
pydevd.settrace()
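To see why the breakpoint is skipped without needing TensorFlow, the sketch below uses only the standard library. It installs a trace function the way a settrace-based debugger would, then calls the same function on the main thread, on a threading.Thread, and on a thread started with the low-level _thread module. The _thread thread is an assumption standing in for a worker thread that a C/C++ runtime creates on its own; the helper run_in_raw_thread is hypothetical and only exists for this demo.

```python
import _thread
import sys
import threading

traced_calls = []  # argument of every add_ten call the tracer observed

def tracer(frame, event, arg):
    # Global trace function, installed the way a settrace-based debugger would.
    if event == "call" and frame.f_code.co_name == "add_ten":
        traced_calls.append(frame.f_locals["x"])
    return None  # no line-level tracing needed for this demo

def add_ten(x):
    return x + 10

def run_in_raw_thread(func, arg):
    # Hypothetical helper: start a thread through _thread instead of threading,
    # as a stand-in (assumption) for a thread created by a C/C++ runtime.
    done = threading.Event()
    result = []
    def body():
        result.append(func(arg))
        done.set()
    _thread.start_new_thread(body, ())
    done.wait()
    return result[0]

threading.settrace(tracer)  # hook applied only to threads started via threading.Thread
sys.settrace(tracer)        # hook for the current (main) thread only

add_ten(1)                                    # main thread: traced
worker = threading.Thread(target=add_ten, args=(2,))
worker.start()
worker.join()                                 # threading.Thread: traced
raw_result = run_in_raw_thread(add_ten, 3)    # raw thread: never traced

sys.settrace(None)
threading.settrace(None)
print(traced_calls, raw_result)
```

The call on the raw thread still runs and returns its result, but the tracer never sees it, which matches the symptom in the question: the code executes, yet the debugger is never told about it.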

Example with your code:

import tensorflow as tf
import pydevd

def add_ten(example, label):
    pydevd.settrace(suspend=False)
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
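Why calling settrace inside add_ten helps can be sketched with the standard library as well. In the code below, the hypothetical attach_tracer() plays the role of pydevd.settrace(suspend=False): it hooks the caller's frame and then enables tracing on whatever thread is currently running, the same pattern the standard library's bdb.Bdb.set_trace uses. This is only an analogy under stated assumptions, not pydevd's actual implementation; the _thread-based helper again stands in (assumption) for a natively created TensorFlow thread.

```python
import _thread
import sys
import threading

snapshots = []  # local variables captured at each traced line of add_ten

def local_tracer(frame, event, arg):
    # Per-frame hook: snapshot the locals on every new line, like a breakpoint would.
    if event == "line":
        snapshots.append(dict(frame.f_locals))
    return local_tracer

def global_tracer(frame, event, arg):
    # Thread-wide hook: only trace new frames of add_ten.
    return local_tracer if frame.f_code.co_name == "add_ten" else None

def attach_tracer():
    # Hypothetical stand-in for pydevd.settrace(suspend=False): hook the caller's
    # already-running frame, then enable tracing for the current thread.
    sys._getframe(1).f_trace = local_tracer
    sys.settrace(global_tracer)

def add_ten(x):
    attach_tracer()
    x_plus_ten = x + 10  # a "breakpoint" here is now reachable
    return x_plus_ten

def run_in_raw_thread(func, arg):
    # Hypothetical helper: a thread started without the threading module,
    # standing in (assumption) for a thread created by a C/C++ runtime.
    done = threading.Event()
    result = []
    def body():
        result.append(func(arg))
        done.set()
    _thread.start_new_thread(body, ())
    done.wait()
    return result[0]

raw_result = run_in_raw_thread(add_ten, 2)
print(raw_result, snapshots[-1])
```

Setting the frame's f_trace matters: sys.settrace alone only affects frames entered afterwards, so without it the rest of the already-running add_ten call would not be traced.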

Note: If you are using an IDE that already bundles pydevd (such as PyDev or PyCharm), you do not have to install pydevd separately; it will be picked up during the debug session.

answered Oct 21 '22 at 19:10 by Daniel Braun