Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Updating Unrolled GAN to TF2

I am trying to implement the Unrolled GAN model as described here, with example code. However, it was implemented using TF1, and I have been doing my best to update it to TF2, but I am relatively new to Python and TF (I have only been using them for the past ~6 months).

The lines that I cannot seem to make work (for the moment; there may be more) are these:

# NOTE: under TF2's default eager execution these come back empty; the
# get_collection docstring states that collections are not supported when
# eager execution is enabled, so graph mode is likely required here.
gen_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

These both return empty lists, and I cannot see what I am missing. Even without specifying a scope, get_collection() returns []. Earlier, I define both the generator and the discriminator inside variable scopes like so:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Map noise ``z`` to ``output_dim``-wide samples with a tanh MLP in the "generator" scope."""
    widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("generator"):
        features = slim.stack(z, slim.fully_connected, widths, activation_fn=tf.nn.tanh)
        out = slim.fully_connected(features, output_dim, activation_fn=None)
    return out

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    """Score ``x`` with a tanh MLP ending in one linear unit (the logit).

    Variables live in the "discriminator" scope; pass ``reuse=True`` to share
    the ones created by an earlier call.
    """
    widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        features = slim.stack(x, slim.fully_connected, widths, activation_fn=tf.nn.tanh)
        logit = slim.fully_connected(features, 1, activation_fn=None)
    return logit

Is there a problem with the definition of the scope?

Here is my updated code in full, in case I missed something elsewhere:

%pylab inline
from collections import OrderedDict
import tensorflow as tf
import tensorflow_probability as tfp
ds = tfp.distributions
# slim = tf.contrib.slim
import tf_slim as slim

from keras.optimizers import Adam

try:
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    generate_movie = True
except ImportError:
    # moviepy is optional; only a missing package is tolerated here, and
    # generate_movie records whether it is available.
    print("Warning: moviepy not found.")
    generate_movie = False


def remove_original_op_attributes(graph):
    """Detach every operation in *graph* from its ``_original_op`` link.

    Each op returned by ``graph.get_operations()`` has its private
    ``_original_op`` attribute overwritten with ``None``.
    """
    ops = graph.get_operations()
    for operation in ops:
        operation._original_op = None
        
def graph_replace(*args, **kwargs):
    """Monkey patch graph_replace so that it works with TF 1.0 and later.

    Clears ``_original_op`` on every op of the default graph, then forwards all
    arguments unchanged to ``_graph_replace`` and returns its result.

    NOTE(review): ``_graph_replace`` is not defined anywhere in this file. It
    presumably came from TF1's ``tf.contrib.graph_editor``, which does not exist
    in TF2, so it must be provided before this function is called.
    """
    # TF2 removed the top-level ``tf.get_default_graph``; the v1 compat alias remains.
    remove_original_op_attributes(tf.compat.v1.get_default_graph())
    return _graph_replace(*args, **kwargs)




def extract_update_dict(update_ops):
    """Extract variables and their new values from Assign and AssignAdd ops.

    Args:
        update_ops: list of Assign and AssignAdd ops, typically computed using
            Keras' opt.get_updates()

    Returns:
        OrderedDict mapping each variable's value tensor to its updated value,
        in the order the ops were given.

    Raises:
        KeyError: if an op writes to a variable not among the global variables.
        ValueError: if an op is neither ``Assign`` nor ``AssignAdd``.
    """
    name_to_var = {v.name: v for v in tf.compat.v1.global_variables()}
    updates = OrderedDict()
    for update in update_ops:
        op = update.op
        # inputs[0] is the variable being written; inputs[1] the value (or delta).
        var = name_to_var[op.inputs[0].name]
        value = op.inputs[1]
        if op.type == 'Assign':
            updates[var.value()] = value
        elif op.type == 'AssignAdd':
            # AssignAdd carries a delta, so the post-update value is var + delta.
            updates[var.value()] = var + value
        else:
            # NOTE(review): resource variables may emit AssignVariableOp /
            # AssignAddVariableOp, which end up here -- confirm the variable type.
            raise ValueError(f"Update op type ({op.type}) must be of type Assign or AssignAdd")
    return updates



def sample_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    """Draw ``batch_size`` 2-D points from a ring-shaped mixture of Gaussians.

    The ``n_mixture`` equally weighted components (uniform categorical logits)
    have per-axis standard deviation ``std``. They are centred at evenly spaced
    angles on a circle of the given ``radius``.
    """
    # endpoint=False: including 2*pi would put the last centre on top of the
    # first (angle 0), leaving only n_mixture - 1 distinct modes.
    thetas = np.linspace(0, 2 * np.pi, n_mixture, endpoint=False)
    # float32 centres keep the samples float32; float64 data would give the
    # discriminator float64 variables that the float32 generator path reuses.
    xs = (radius * np.sin(thetas)).astype(np.float32)
    ys = (radius * np.cos(thetas)).astype(np.float32)
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag([xi, yi], [std, std]) for xi, yi in zip(xs, ys)]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)



def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Build the generator: ``n_layer`` tanh layers of width ``n_hidden``, then a linear output.

    All variables are created under the "generator" variable scope.
    """
    hidden_sizes = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("generator"):
        hidden = slim.stack(z, slim.fully_connected, hidden_sizes, activation_fn=tf.nn.tanh)
        sample = slim.fully_connected(hidden, output_dim, activation_fn=None)
    return sample

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    """Build the discriminator: tanh MLP followed by a single linear unit (the logit).

    Variables are created under the "discriminator" scope; ``reuse=True``
    shares those created by a previous call.
    """
    hidden_sizes = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope("discriminator", reuse=reuse):
        hidden = slim.stack(x, slim.fully_connected, hidden_sizes, activation_fn=tf.nn.tanh)
        log_score = slim.fully_connected(hidden, 1, activation_fn=None)
    return log_score



params = dict(
    batch_size=512,
    disc_learning_rate=1e-4,
    gen_learning_rate=1e-3,
    beta1=0.5,
    epsilon=1e-8,
    max_iter=25000,
    viz_every=5000,
    z_dim=256,
    x_dim=2,
    unrolling_steps=5,
)


# This TF1-style code (variable scopes, graph collections, graph rewriting,
# Sessions) needs graph mode. TF2 enables eager execution by default, and
# collections are not supported under eager execution -- which is why
# gen_vars/disc_vars came back empty. Must run before any tensors are created.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

data = sample_mog(params['batch_size'])

noise = ds.Normal(tf.zeros(params['z_dim']),
                  tf.ones(params['z_dim'])).sample(params['batch_size'])
# Construct generator and discriminator nets
with slim.arg_scope([slim.fully_connected], weights_initializer=tf.keras.initializers.Orthogonal(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    # reuse=True: score the fakes with the same discriminator variables.
    fake_score = discriminator(samples, reuse=True)

# Saddle objective: the discriminator minimizes it, the generator maximizes it.
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=tf.cast(real_score, dtype=tf.float32), labels=tf.cast(tf.ones_like(real_score), dtype=tf.float32)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=tf.cast(fake_score, dtype=tf.float32), labels=tf.cast(tf.zeros_like(fake_score), dtype=tf.float32)))

gen_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "generator")
disc_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

# Vanilla discriminator update. Keras 2's signature is get_updates(loss, params);
# passing [] as params would not update any discriminator variable.
d_opt = Adam(learning_rate=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
updates = d_opt.get_updates(loss, disc_vars)
d_train_op = tf.group(*updates, name="d_train_op")

# Unroll optimization of the discriminator
if params['unrolling_steps'] > 0:
    # Get dictionary mapping from variables to their update value after one optimization step
    update_dict = extract_update_dict(updates)
    cur_update_dict = update_dict
    for _ in range(params['unrolling_steps'] - 1):
        # Compute variable updates given the previous iteration's updated variable
        cur_update_dict = graph_replace(update_dict, cur_update_dict)
    # Final unrolled loss uses the parameters at the last time step
    unrolled_loss = graph_replace(loss, cur_update_dict)
else:
    unrolled_loss = loss

# Optimize the generator on the unrolled loss (maximize it, hence the minus sign).
g_train_opt = tf.compat.v1.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)


sess = tf.compat.v1.InteractiveSession()
sess.run(tf.compat.v1.global_variables_initializer())

like image 363
Whitehot Avatar asked Jun 09 '21 10:06

Whitehot


1 Answer

The implementation of get_collection:

def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches using `re.match`. Items without
      a `name` attribute are never returned if a scope is supplied and the
      choice of `re.match` means that a `scope` without special tokens filters
      by prefix.

  Returns:
    The list of values in the collection with the given `key`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  # Delegates to the current default graph: `key` selects the collection and
  # `scope` only filters its items by name.
  return get_default_graph().get_collection(key, scope)

It looks like in this code, the key and scope arguments are swapped. If you provide "generator" or "discriminator" as the key with no scope, i.e.:

# Here the strings are used as collection *keys*; these collections only hold
# whatever was explicitly added to them (see the update below).
gen_vars = tf.compat.v1.get_collection("generator")
disc_vars = tf.compat.v1.get_collection("discriminator")

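You should get results (I was able to reproduce this locally with TensorFlow 2.2.0). The only issue I could not quite identify is that, when providing a scope, the function returns an empty list again, regardless of the scope value you provide. For example, tf.compat.v1.GraphKeys.GLOBAL_VARIABLES should return everything, but that is not the case. (Note that, per the docstring above, `scope` is only a `re.match` filter on item names, not a second collection key, and that collections are not supported at all when eager execution, the TF2 default, is enabled.)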

# NOTE(review): the second argument is `scope`, a name filter applied with
# re.match, not another collection key. A GraphKeys constant passed here only
# keeps items whose names start with that string, which likely explains the [].
gen_vars = tf.compat.v1.get_default_graph().get_collection('generator', tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) # returns []
gen_vars = tf.compat.v1.get_default_graph().get_collection('generator', tf.compat.v1.GraphKeys.GLOBAL_VARIABLES) # returns []
# NOTE(review): this stores the 'generator' collection in disc_vars;
# 'discriminator' was probably intended.
disc_vars = tf.compat.v1.get_collection('generator') # returns a list of tensors

Update

It looks like even creating the variables inside the `variable_scope` context manager doesn't add them to the graph collection. I had to call tf.compat.v1.add_to_collection('generator', x) and tf.compat.v1.add_to_collection('discriminator', log_d) in the respective functions to get those results.

Update #2

I searched around and it doesn't appear there's a context manager which enables you to add variables declared within it to a Tensorflow collection. For the sake of completeness of this answer though, I have implemented one:

from contextlib import contextmanager

@contextmanager
def collection_scope(collection_name):
    """Add EagerTensors newly bound in the caller's ``with`` body to a collection.

    On entry, the EagerTensor locals of the function containing the ``with``
    statement are recorded. On a normal exit, every EagerTensor local that was
    not present on entry is appended to the graph collection
    ``collection_name``. The tensor itself is appended, in binding order. If
    the body raises, nothing is added.

    HACK: this inspects the calling frame's locals. Only tensors assigned to
    local variables of that function are captured.
    """
    import inspect
    from tensorflow.python.framework.ops import EagerTensor

    def _eager_locals():
        # Frame chain: this helper -> collection_scope generator ->
        # contextlib __enter__/__exit__ -> function containing the ``with``.
        frame = inspect.currentframe().f_back.f_back.f_back
        try:
            return {val.ref(): val for val in frame.f_locals.values()
                    if isinstance(val, EagerTensor)}
        finally:
            del frame  # avoid keeping a reference cycle through the frame

    collection = tf.compat.v1.get_collection_ref(collection_name)
    before = _eager_locals()
    yield
    # Store the tensors themselves rather than their hashable ``ref()`` keys,
    # so get_collection() returns tensors.
    collection.extend(tensor for key, tensor in _eager_locals().items() if key not in before)

You can then drop this in your functions in place of the variable scope (tf.compat.v1.variable_scope) context managers. For example, instead of:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Tanh-MLP generator whose variables are created in the 'generator' variable scope."""
    layer_widths = [n_hidden] * n_layer
    with tf.compat.v1.variable_scope('generator'):
        act = slim.stack(z, slim.fully_connected, layer_widths, activation_fn=tf.nn.tanh)
        result = slim.fully_connected(act, output_dim, activation_fn=None)
    return result

Do the following:

def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Tanh-MLP generator built inside collection_scope('generator')."""
    layer_widths = [n_hidden] * n_layer
    with collection_scope('generator'):
        act = slim.stack(z, slim.fully_connected, layer_widths, activation_fn=tf.nn.tanh)
        result = slim.fully_connected(act, output_dim, activation_fn=None)
    return result

With this change, all tensors declared within the scope of the context manager will be added to the collection "generator" - tf.compat.v1.get_collection('generator') will return the correct list of tensors.

like image 105
danielcahall Avatar answered Oct 21 '22 22:10

danielcahall