
What does train_on_batch() do in keras model?

I saw a sample of code (too big to paste here) where the author used model.train_on_batch(in, out) instead of model.fit(in, out). The official documentation of Keras says:

Single gradient update over one batch of samples.

But I don't get it. Is it the same as fit(), but instead of doing many feed-forward and backprop steps, it does it once? Or am I wrong?

CerushDope asked Jan 31 '18 19:01



People also ask

What is model train_on_batch?

train_on_batch allows you to expressly update weights based on a collection of samples you provide, without regard to any fixed batch size. You would use this in cases when that is what you want: to train on an explicit collection of samples.

What does model train_on_batch return?

In that example, train_on_batch returns the total loss, the MSE for output one, and the MSE for output two, which is why you get three values.

How do I check validation accuracy in keras?

You can do this by setting the validation_split argument on the fit() function to a percentage of the size of your training dataset. For example, a reasonable value might be 0.2 or 0.33 for 20% or 33% of your training data held back for validation.

What is default batch size in model fit?

Here we are training our network for 10 epochs along with the default batch size of 32. For small and less complex datasets it is recommended to use keras.


2 Answers

Yes. train_on_batch performs a single gradient update on the one batch you pass it, and does so exactly once.

fit, by contrast, splits the data into many batches and iterates over them for many epochs. (Each batch causes one update of the weights.)

The point of using train_on_batch is usually to do more things yourself between batches.
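To make the difference concrete, here is a minimal sketch of a hand-written loop that does roughly what fit does, one train_on_batch call per batch. The toy data, the layer sizes and the names `x`, `y`, `losses` are made up for illustration, and it assumes TensorFlow's bundled Keras is installed.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy regression data, for illustration only.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

batch_size = 32
losses = []
for epoch in range(3):
    for start in range(0, len(x), batch_size):
        xb = x[start:start + batch_size]
        yb = y[start:start + batch_size]
        # One forward pass, one backward pass, one weight update.
        result = model.train_on_batch(xb, yb)
        # With no extra metrics the result is a scalar loss; np.ravel also
        # copes with the list form returned when metrics are configured.
        losses.append(float(np.ravel(result)[0]))
        # Custom per-batch logic (logging, changing the data, ...) goes here.
```

Something like `model.fit(x, y, batch_size=32, epochs=3)` covers the same batches internally (with shuffling by default), but gives you no place to put your own code between updates other than callbacks.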

Daniel Möller answered Sep 22 '22 13:09



It is used when you want to inspect the model or make custom changes after each batch has been trained.

A more precise use case is GANs. You have to update the discriminator, but while updating the combined GAN network you have to keep the discriminator untrainable. So you first train the discriminator, and then train the GAN with the discriminator frozen. See this for more detail: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
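A rough sketch of that alternating scheme follows. It is not the code from the linked article: the tiny Dense models, the sizes and the random "real" data are all hypothetical. Toggling `discriminator.trainable` explicitly inside the loop is an assumption made here so the sketch should behave the same way on older and newer Keras versions.

```python
import numpy as np
from tensorflow import keras

latent_dim, data_dim, batch = 8, 2, 32  # hypothetical toy sizes

generator = keras.Sequential([
    keras.Input(shape=(latent_dim,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(data_dim),
])
discriminator = keras.Sequential([
    keras.Input(shape=(data_dim,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
discriminator.compile(optimizer="adam", loss="binary_crossentropy")

# Freeze the discriminator before building and compiling the stacked model,
# so that training `gan` only updates the generator.
discriminator.trainable = False
gan = keras.Sequential([keras.Input(shape=(latent_dim,)), generator, discriminator])
gan.compile(optimizer="adam", loss="binary_crossentropy")

rng = np.random.default_rng(0)
ones = np.ones((batch, 1), dtype="float32")
zeros = np.zeros((batch, 1), dtype="float32")
d_losses, g_losses = [], []
for step in range(5):
    real = rng.normal(loc=3.0, size=(batch, data_dim)).astype("float32")
    noise = rng.normal(size=(batch, latent_dim)).astype("float32")
    fake = generator.predict_on_batch(noise)

    # 1) One discriminator update: real samples -> 1, generated samples -> 0.
    discriminator.trainable = True
    d_out = discriminator.train_on_batch(
        np.concatenate([real, fake]), np.concatenate([ones, zeros])
    )
    d_losses.append(float(np.ravel(d_out)[0]))

    # 2) One generator update through the frozen discriminator: the
    #    generator is rewarded when its samples are classified as real.
    discriminator.trainable = False
    g_out = gan.train_on_batch(noise, ones)
    g_losses.append(float(np.ravel(g_out)[0]))
```

This kind of interleaving is awkward to express with a single fit call, which is why GAN tutorials often reach for train_on_batch.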

TBhavnani answered Sep 23 '22 13:09
