Loading a model inside a while loop in TensorFlow

I am trying to use @tf.function(jit_compile=True) to create a TF graph with a while loop; its pseudocode is below. I can't provide working code because it has a lot of dependencies.

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while()
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while()
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
  

Code 1 above results in a memory explosion because I am loading and calling the model inside the inner while loop. When I load the model outside both while loops (Code 2), I get the error RuntimeError: Cannot get session inside Tensorflow graph function. What is the best way to prevent the memory explosion?
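A minimal sketch of the Code 2 idea, with the model created once in eager mode before tracing and captured by the tf.function as a closure, instead of being loaded inside the traced function. Assumptions: a small Sequential model stands in for tf.keras.models.load_model(...), function2 is a hypothetical placeholder, and the two nested loops are collapsed into a single tf.while_loop.

```python
import tensorflow as tf

# Stand-in for tf.keras.models.load_model(...) (assumption): built eagerly,
# once, outside the tf.function, so its weights exist before tracing.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(4, activation="tanh"),
])

def function2(x):
    # Hypothetical placeholder for the asker's function2.
    return 0.5 * x

@tf.function(jit_compile=True)
def myfunction(inputs, num_steps):
    # `model` is captured from the enclosing Python scope; nothing is
    # loaded while the function is traced or while the loop runs.
    def cond(i, x):
        return i < num_steps

    def body(i, x):
        out2 = model(x)
        return i + 1, function2(out2)

    _, out3 = tf.while_loop(cond, body, (tf.constant(0), inputs))
    return out3

y = myfunction(tf.random.normal((1000, 4)), tf.constant(3))
print(y.shape)  # (1000, 4)
```

The loop state keeps the same shape and dtype on every iteration, which tf.while_loop requires unless shape_invariants are given.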

Edit 1: The inputs are tensors. The problem is that I need to pass a large batch at once. For this, I wrote a while loop, thinking that it would run in parallel (as I understand it, a Keras model() call only processes 32 samples at once). I am not sure why a Keras model does not take the batch size as an input. In the code above, is it preferable to load the weights and other values directly and do the computation manually to get the outputs? In the case of Code 2, will each model call get its own graph because it is called inside a while loop?
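On the batch-size point: the 32-sample default belongs to model.predict() (and fit/evaluate), which split the input into batches and accept a batch_size argument; calling model(x) directly runs the whole tensor in one forward pass. A small sketch, using a hypothetical two-output Dense model:

```python
import numpy as np
import tensorflow as tf

# Hypothetical small model, only to compare the two calling styles.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(2),
])

x = np.random.rand(1000, 8).astype("float32")

# Direct call: all 1000 rows go through a single forward pass.
direct = model(x)

# predict() batches internally; batch_size defaults to 32 but can be set.
batched = model.predict(x, batch_size=256, verbose=0)

print(direct.shape, batched.shape)  # (1000, 2) (1000, 2)
print(np.allclose(direct.numpy(), batched, atol=1e-5))
```

Inside a tf.function the direct model(x) form is the one to use, since predict() runs its own Python-level batching loop.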

Edit 2: function3 computes the gradient of out3 with respect to inputs.
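Inside a tf.function, a gradient of an output with respect to the inputs is usually taken with tf.GradientTape. The sketch below is a hypothetical function3 built on a stand-in Dense model, not the asker's implementation; because inputs is a plain tensor rather than a tf.Variable, it has to be watched explicitly.

```python
import tensorflow as tf

# Stand-in model (assumption); keep a handle on the layer to inspect its kernel.
dense = tf.keras.layers.Dense(3)
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), dense])

@tf.function
def function3(inputs):
    # Hypothetical function3: d(out3)/d(inputs).
    with tf.GradientTape() as tape:
        tape.watch(inputs)  # inputs is a tensor, so the tape must watch it
        out3 = model(inputs)
    # out3 is not a scalar, so this is the gradient of sum(out3).
    return tape.gradient(out3, inputs)

x = tf.random.normal((10, 4))
g = function3(x)
print(g.shape)  # (10, 4)
```

If per-output gradients are needed rather than the gradient of the sum, tape.jacobian or tape.batch_jacobian can be used instead of tape.gradient.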

newbie asked Feb 28 '26 04:02


1 Answer

Two possible solutions:

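According to the TensorFlow documentation for tf.while_loop, your problem may be the tensors stored for back propagation, which are a main source of memory use in loops; setting swap_memory=True lets the loop swap them out from GPU to CPU memory.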

Try:

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while_loop(..., swap_memory=True)
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while_loop(..., swap_memory=True)
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
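A sketch of the first suggestion with the real tf.while_loop signature. Assumptions: the loop body is a hypothetical tanh(h @ w) step standing in for model + function2, and the gradient is taken through the loop so that forward tensors are actually kept for back propagation. It is shown without jit_compile=True, so any interaction with XLA is left untested.

```python
import tensorflow as tf

# Hypothetical weights standing in for the model's parameters.
w = tf.Variable(tf.random.normal((4, 4)) * 0.1)

@tf.function
def loop_with_grad(x, num_steps):
    def cond(i, h):
        return i < num_steps

    def body(i, h):
        return i + 1, tf.tanh(tf.matmul(h, w))

    with tf.GradientTape() as tape:
        tape.watch(x)
        # swap_memory=True allows tensors kept for the backward pass to be
        # swapped from GPU to CPU memory; it only matters on a GPU.
        _, h = tf.while_loop(cond, body, (tf.constant(0), x), swap_memory=True)
        loss = tf.reduce_sum(h)
    return loss, tape.gradient(loss, x)

loss, grad = loop_with_grad(tf.random.normal((8, 4)), tf.constant(5))
print(grad.shape)  # (8, 4)
```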

Or, at the end of your loops, try calling K.clear_session() to reset the state:

from tensorflow.keras import backend as K

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while()
    out3 = inputs
    tf.while_loop(number_samples)
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while()
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
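K.clear_session() is an ordinary Python function, and the Python body of a tf.function only executes while the function is being traced, so inside myfunction the call would run once at trace time rather than at the end of each loop. A sketch of the eager placement, between separate runs that each build a model, with a hypothetical build_model() standing in for load_model():

```python
import tensorflow as tf

def build_model():
    # Hypothetical stand-in for tf.keras.models.load_model(...).
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(2),
    ])

results = []
for trial in range(3):
    model = build_model()
    results.append(model(tf.ones((5, 4))).shape)
    # Runs eagerly here, so Keras' global state is reset after every run.
    del model
    tf.keras.backend.clear_session()

print(results)
```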
Djinn answered Mar 03 '26 08:03


