Tensorflow Checkpoint Custom Map

I am subclassing a Keras model with custom layers. Each layer wraps a dictionary of parameters that is used when generating the layers. It seems these parameter dictionaries are not set before TensorFlow creates the training checkpoint dependency; they are set afterwards, which causes an error. I am not sure how to fix this, and the ValueError being raised gives outdated advice (tf.contrib no longer exists).

ValueError: Unable to save the object {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary was modified outside the wrapper (its final value was {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True}, its value when a checkpoint dependency was added was None), which breaks restoration on object creation.

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.

Here's an example of the Layer that is raising this error:

from typing import Any, Dict, List

import numpy as np
from tensorflow.keras import layers

# ModeKeys, stack_layers and BaseLayer are project-specific helpers
# (BaseLayer is shown below).

class RecurrentConfig(BaseLayer):
    '''Basic configurable recurrent layer'''
    def __init__(self, params: Dict[Any, Any], mode: ModeKeys, layer_name: str = '', **kwargs):
        self.layer_name = layer_name
        self.cell_name = params.pop('cell', 'GRU')
        self.num_layers = params.pop('num_layers', 1)
        kwargs['name'] = layer_name
        super().__init__(params, mode, **kwargs)
        if layer_name == '':
            self.layer_name = self.cell_name
        self.layers: List[layers.Layer] = stack_layers(self.params,
                                                       self.num_layers,
                                                       self.cell_name)

    def call(self, inputs: np.ndarray) -> layers.Layer:
        '''Sequential/functional call through this layer's logic
        Args:
            inputs: Array to be processed within this layer
        Returns:
            inputs processed through this layer'''
        processed = inputs
        for layer in self.layers:
            processed = layer(processed)
        return processed

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        return {
            'units': 32,
            'recurrent_initializer': 'glorot_uniform',
            'dropout': 0,
            'recurrent_dropout': 0,
            'activation': 'tanh',
            'return_sequences': True
        }

BaseLayer.py

'''Basic ABC for a keras style layer'''

from typing import Dict, Any

from tensorflow.keras import layers
from mosaix_py.mosaix_learn.configurable import Configurable

class BaseLayer(Configurable, layers.Layer):
    '''Base configurable Keras layer'''
    def get_config(self) -> Dict[str, Any]:
        '''Return configuration dictionary as part of keras serialization'''
        config = super().get_config()
        config.update(self.params)
        return config

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        raise NotImplementedError('Layer does not implement default params')
asked Jan 14 '20 by Jacob B


2 Answers

The issue I was facing was that I was popping items from a dictionary passed into a layers.Layer:

    self.cell_name = params.pop('cell', 'GRU')
    self.num_layers = params.pop('num_layers', 1)

When a dictionary is passed into a layer it must remain unchanged, because Keras wraps it and tracks it as a checkpoint dependency.

My solution was to move parameter parsing further out of the layer and pass in a finalized dictionary, as sketched below.
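
A minimal sketch of that approach (simplified, not the actual mosaix_py code; finalize_params and cell_layers are names invented for illustration): the structural options are popped from a copy before the layer ever sees the dictionary, and the dict the layer stores is never modified afterwards.

from typing import Any, Dict, Tuple

from tensorflow.keras import layers

def finalize_params(params: Dict[str, Any]) -> Tuple[Dict[str, Any], str, int]:
    '''Pop structural options from a copy, leaving a finalized kwargs dict.'''
    params = dict(params)                      # never mutate the caller's dict
    cell_name = params.pop('cell', 'GRU')
    num_layers = params.pop('num_layers', 1)
    return params, cell_name, num_layers

class RecurrentConfig(layers.Layer):
    def __init__(self, params: Dict[str, Any], **kwargs):
        params, cell_name, num_layers = finalize_params(params)
        super().__init__(**kwargs)
        # The dict assigned here is final: Keras wraps it for checkpoint
        # tracking and it is never touched again, so saving works.
        self.params = params
        cell_cls = getattr(layers, cell_name)  # e.g. layers.GRU
        self.cell_layers = [cell_cls(**params) for _ in range(num_layers)]

    def call(self, inputs):
        processed = inputs
        for layer in self.cell_layers:
            processed = layer(processed)
        return processed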

answered Oct 11 '22 by Jacob B


Your RecurrentConfig object should inherit from tf.keras.layers.Layer instead of BaseLayer. The TF documentation on checkpoints/delayed restorations covers why:

Layer objects in TensorFlow may delay the creation of variables to their first call, when input shapes are available. For example the shape of a Dense layer's kernel depends on both the layer's input and output shapes, and so the output shape required as a constructor argument is not enough information to create the variable on its own. Since calling a Layer also reads the variable's value, a restore must happen between the variable's creation and its first use.

To support this idiom, tf.train.Checkpoint queues restores which don't yet have a matching variable.
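
To make that queuing concrete, here is a small self-contained example of delayed restoration (my own toy example, not the asker's code; the /tmp path is arbitrary):

import tensorflow as tf

# Save a checkpoint from a layer that has already been built.
dense = tf.keras.layers.Dense(units=1)
dense(tf.constant([[1.0]]))                      # building the layer creates kernel/bias
save_path = tf.train.Checkpoint(layer=dense).save('/tmp/delayed_restore/ckpt')

# Restore into a fresh, un-built layer: no variables exist yet, so the
# restore is queued rather than applied immediately.
fresh = tf.keras.layers.Dense(units=1)
status = tf.train.Checkpoint(layer=fresh).restore(save_path)

# The first call creates the variables (the input shape is now known) and
# the queued restore fires, so fresh.kernel now matches dense.kernel.
fresh(tf.constant([[1.0]]))
status.assert_existing_objects_matched()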

answered Oct 11 '22 by charlesreid1