
Convert Estimator to TPUEstimator

Is it possible to convert an Estimator to a TPUEstimator in TensorFlow without significant effort in rewriting its functions? I have a model in Estimator form that works nicely on a CPU, but I don't know a convenient way to convert it to a TPUEstimator without having to rewrite the model_fn and input_fn.

The reason this requires significant work to do manually is that I am using Keras to create my model, and then the following helper function to create the Estimator:

   my_keras_model.compile(
       optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
       loss='categorical_crossentropy',
       metrics=['accuracy'])
   estimator = tf.keras.estimator.model_to_estimator(keras_model=my_keras_model)

It would be great if I could call something like estimator.to_TPU_estimator() -- perhaps someone knows of a solution?

1 Answer

There can't be such a function, because the model_fn specification is different for the two estimators. Some of the differences are pretty deep, such as this one (from the TPU tutorial):

When training on a cloud TPU you must wrap the optimizer in a tf.contrib.tpu.CrossShardOptimizer, which uses an allreduce to aggregate gradients and broadcast the result to each shard (each TPU core).

This means patching the internals of the Keras optimizer and its update ops.
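
For illustration, the wrapping in a hand-written TF 1.x model_fn looks roughly like this (a minimal sketch; the optimizer choice, loss, and global-step handling are placeholders, not taken from the question):

# Minimal sketch (TF 1.x with tf.contrib available): wrap a plain TF optimizer
# so that gradients are aggregated (all-reduced) across TPU shards before
# being applied.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())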

The recommended approach is to have separate model_fn wrappers for the GPU and TPU models, and that seems to be the fastest route for you. In your case, it means rewriting the Keras model_to_estimator function for the TPU estimator.


The first and simplest approximation is this:

def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est

Here, the _save_first_checkpoint call is actually optional, but if you'd like to keep it, import that function from tensorflow.python.keras._impl.keras.estimator.


The real work happens in the _create_keras_tpu_model_fn function, which replaces _create_keras_model_fn. The changes are:

  • the internal tensorflow optimizer must be wrapped with CrossShardOptimizer as mentioned earlier, and

  • the inner function must return TPUEstimatorSpec.

It's possible that a few more lines need to be patched as well, but it looks OK to me. A complete version is below:

import tensorflow as tf
from tensorflow.python.keras._impl.keras.estimator import _save_first_checkpoint, _clone_and_build_model

def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est


def _create_keras_tpu_model_fn(keras_model, custom_objects=None):

  def model_fn(features, labels, mode):
    """model_fn for keras Estimator."""
    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    predictions = dict(zip(model.output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Set loss and metric only during train and evaluate.
    if mode is not tf.estimator.ModeKeys.PREDICT:
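      # TPU-specific change: wrap the underlying TF optimizer so its gradients
      # are aggregated across shards. This assumes the Keras model was compiled
      # with a tf.train.* optimizer, which Keras wraps in a TFOptimizer that
      # exposes it as .optimizer.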
      model.optimizer.optimizer = tf.contrib.tpu.CrossShardOptimizer(model.optimizer.optimizer)

      model._make_train_function()  # pylint: disable=protected-access
      loss = model.total_loss

      if model.metrics:
        eval_metric_ops = {}
        # When each metric maps to an output
        if isinstance(model.metrics, dict):
          for i, output_name in enumerate(model.metrics.keys()):
            metric_name = model.metrics[output_name]
            if callable(metric_name):
              metric_name = metric_name.__name__
            # When some outputs use the same metric
            if list(model.metrics.values()).count(metric_name) > 1:
              metric_name += '_' + output_name
            eval_metric_ops[metric_name] = tf.metrics.mean(
                model.metrics_tensors[i - len(model.metrics)])
        else:
          for i, metric_name in enumerate(model.metrics):
            if callable(metric_name):
              metric_name = metric_name.__name__
            eval_metric_ops[metric_name] = tf.metrics.mean(
                model.metrics_tensors[i])

    if mode is tf.estimator.ModeKeys.TRAIN:
      train_op = model.train_function.updates_op

    # Note: depending on the TF version, TPUEstimatorSpec may expect
    # eval_metrics=(metric_fn, tensors) rather than eval_metric_ops; if so,
    # the metric dict above has to be repackaged into that form.
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return model_fn
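
For completeness, here is a rough usage sketch of the resulting estimator. The TPU name, model directory, my_keras_model and my_input_fn are placeholders. Depending on the TF version, TPUEstimator may also require use_tpu and train_batch_size arguments (which would then have to be threaded through model_to_estimator), and the input_fn receives a params dict carrying the per-shard batch size.

# Rough usage sketch (TF 1.x / tf.contrib); names are placeholders.
tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my-tpu')

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster,
    model_dir='/tmp/tpu_model',
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100, num_shards=8))

estimator = model_to_estimator(keras_model=my_keras_model, config=run_config)
estimator.train(input_fn=my_input_fn, max_steps=1000)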