Using sample_weights with fit_generator()

In an autoregressive problem with a continuous target, when zeros make up a large share of the data, it is possible to treat it as a zero-inflated problem (i.e. ZIB). In other words, instead of fitting f(x) directly, we want to fit g(x)*f(x), where f(x) is the function we want to approximate (i.e. y) and g(x) is a function that outputs a value between 0 and 1 depending on whether the value is zero or non-zero.
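To make the formulation concrete, here is a minimal, framework-free sketch, assuming g(x) is read as the probability that y is non-zero. The helper names `nonzero_targets` and `zero_inflated_prediction` are hypothetical: the first builds the 0/1 training labels for g, the second combines the two outputs element-wise as g(x)*f(x).

```python
def nonzero_targets(ys):
    """Binary labels for training g: 1.0 where y is non-zero, else 0.0."""
    return [1.0 if y != 0 else 0.0 for y in ys]


def zero_inflated_prediction(g_probs, f_values):
    """Combine gate and magnitude element-wise: y_hat = g(x) * f(x)."""
    if len(g_probs) != len(f_values):
        raise ValueError("g_probs and f_values must have the same length")
    return [g * f for g, f in zip(g_probs, f_values)]


ys = [0.0, 0.0, 3.5, 0.0, 1.2]
print(nonzero_targets(ys))  # [0.0, 0.0, 1.0, 0.0, 1.0]

# Hypothetical outputs of the two models for three samples
print(zero_inflated_prediction([0.0, 0.5, 1.0], [2.0, 4.0, 3.0]))  # [0.0, 2.0, 3.0]
```

A g(x) close to 0 pulls the prediction toward zero regardless of f(x), which is what lets f(x) concentrate on the magnitude of the non-zero values.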

Currently, I have two models: one that gives me g(x) and another that fits g(x)*f(x).

The first model gives me a set of weights, and this is where I need your help. I can pass them through the sample_weight argument of model.fit(). However, because I work with a very large amount of data, I need to use model.fit_generator(), and fit_generator() does not accept a sample_weight argument.
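For intuition about what per-sample weights do to the loss, the sketch below computes a weighted mean squared error in plain Python: each sample's squared error is multiplied by its weight before averaging over the batch. It is an illustration only; the helper name `weighted_mse` is hypothetical, and the exact normalization Keras applies to weighted losses differs between versions.

```python
def weighted_mse(y_true, y_pred, weights):
    """Mean over the batch of w_i * (y_true_i - y_pred_i) ** 2."""
    if not (len(y_true) == len(y_pred) == len(weights)):
        raise ValueError("all inputs must have the same length")
    total = sum(w * (t - p) ** 2 for t, p, w in zip(y_true, y_pred, weights))
    return total / len(y_true)


# The second sample has squared error 4 but only half the weight: (0 + 0.5 * 4) / 2
print(weighted_mse([1.0, 2.0], [1.0, 4.0], [1.0, 0.5]))  # 1.0
```

A weight of 0 removes a sample from the loss entirely and a small weight mostly does, which is how the output of g(x) can down-weight the samples the first model considers zeros.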

Is there a workaround for using sample weights with fit_generator()? Otherwise, how can I fit g(x)*f(x), given that I already have a trained model for g(x)?

asked Nov 17 '18 at 19:11 by user1050421




1 Answer

You can provide sample weights as the third element of the tuple returned by the generator. From the Keras documentation on fit_generator():

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

  • a tuple (inputs, targets)
  • a tuple (inputs, targets, sample_weights).
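The two accepted output shapes can be illustrated with a small framework-free helper that unpacks one generator output and falls back to uniform weights when only (inputs, targets) is given. `unpack_batch` is a hypothetical name and a simplified stand-in, not Keras's internal code.

```python
def unpack_batch(batch):
    """Split one generator output into (inputs, targets, sample_weights).

    Accepts either (inputs, targets) or (inputs, targets, sample_weights);
    in the two-element case every sample gets weight 1.0.
    """
    if len(batch) == 2:
        inputs, targets = batch
        weights = [1.0] * len(targets)
    elif len(batch) == 3:
        inputs, targets, weights = batch
    else:
        raise ValueError(f"expected a tuple of length 2 or 3, got {len(batch)}")
    return inputs, targets, weights


print(unpack_batch(([[0.1], [0.2]], [0.0, 1.5])))
# ([[0.1], [0.2]], [0.0, 1.5], [1.0, 1.0])
print(unpack_batch(([[0.1], [0.2]], [0.0, 1.5], [0.2, 0.9])))
# ([[0.1], [0.2]], [0.0, 1.5], [0.2, 0.9])
```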

Update: Here is a rough sketch of a generator that returns the input samples and targets as well as the sample weights obtained from model g(x):

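import numpy as np

def gen(x_data, y_data, g, batch_size):
    # Hypothetical arguments: x_data / y_data are array-like datasets that
    # support slicing (e.g. NumPy or memory-mapped arrays), g is the trained
    # Keras model for g(x), and batch_size is the number of samples per batch.
    num_batches = int(np.ceil(len(x_data) / batch_size))
    while True:  # fit_generator() expects the generator to loop indefinitely
        for i in range(num_batches):
            # get the i-th batch data
            inputs = x_data[i * batch_size:(i + 1) * batch_size]
            targets = y_data[i * batch_size:(i + 1) * batch_size]

            # get the sample weights; assuming g ends in a single sigmoid unit,
            # flatten its (batch, 1) output to the (batch,) shape Keras expects
            weights = g.predict(inputs).ravel()

            yield inputs, targets, weights


# x_train, y_train, batch_size: hypothetical names for your data and batch size
num_batches = int(np.ceil(len(x_train) / batch_size))
model.fit_generator(gen(x_train, y_train, g, batch_size),
                    steps_per_epoch=num_batches)  # add epochs, callbacks, etc. as needed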
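Before handing a generator like this to Keras, its batching can be checked without any framework. The sketch below repeats the same slicing logic with plain lists and a stand-in gate function `fake_g` (hypothetical, standing in for the trained g.predict). It confirms that steps_per_epoch = ceil(n / batch_size) batches cover every sample exactly once, including a shorter final batch, and that each batch's weights line up with its targets.

```python
import itertools
import math


def fake_g(batch_inputs):
    """Hypothetical stand-in for g.predict: one value in [0, 1] per sample."""
    return [min(1.0, abs(x)) for x in batch_inputs]


def gen(x_data, y_data, batch_size):
    """Yield (inputs, targets, weights) forever, one batch at a time."""
    num_batches = math.ceil(len(x_data) / batch_size)
    while True:
        for i in range(num_batches):
            inputs = x_data[i * batch_size:(i + 1) * batch_size]
            targets = y_data[i * batch_size:(i + 1) * batch_size]
            weights = fake_g(inputs)
            yield inputs, targets, weights


x = [0.0, 0.3, 2.0, 0.0, 0.7]
y = [0.0, 1.1, 4.0, 0.0, 2.2]
batch_size = 2
steps_per_epoch = math.ceil(len(x) / batch_size)  # 3 batches for 5 samples

# Take exactly one epoch's worth of batches from the infinite generator
for inputs, targets, weights in itertools.islice(gen(x, y, batch_size), steps_per_epoch):
    assert len(inputs) == len(targets) == len(weights)
    print(inputs, targets, weights)
```

In TensorFlow 2.x, fit_generator() is deprecated and model.fit() accepts the same kind of generator (or a keras.utils.Sequence) directly, still reading the sample weights from the third element of each tuple.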
    
    
answered Nov 24 '22 at 11:11 by today
