
Memory Management about tf.Model in TensorFlow.js

I'm a newbie in TensorFlow.

The "Memory Management: dispose and tf.tidy" section in https://js.tensorflow.org/tutorials/core-concepts.html says that we have to manage memory explicitly, using dispose() and tf.tidy().

However, the classes in tfjs-layers (e.g. tf.Model and Layer) don't seem to have a dispose() method, and tf.tidy() doesn't accept them as return values.

So my questions are:

  • Does tf.Model manage its memory automatically?
  • If not, how can I manage its memory correctly?

Sample code:

import * as tf from "@tensorflow/tfjs"

function defineModel(
    regularizerRate: number,
    learningRate: number,
    stateSize: number,
    actionSize: number,
): tf.Model {
    return tf.tidy(() => { // Compile error here: tf.tidy() won't let me return the model.
        const input = tf.input({
            name: "INPUT",
            shape: [stateSize],
            dtype: "int32" as any, // TODO(mysticatea): https://github.com/tensorflow/tfjs/issues/120
        })
        const temp = applyHiddenLayers(input, regularizerRate)
        const valueOutput = applyValueLayer(temp, regularizerRate)
        const policyOutput = applyPolicyLayer(temp, actionSize, regularizerRate)
        const model = tf.model({
            inputs: [input],
            outputs: [valueOutput, policyOutput],
        })

        // TODO(mysticatea): https://github.com/tensorflow/tfjs/issues/98
        model.compile({
            optimizer: tf.train.sgd(learningRate),
            loss: ["meanSquaredError", "meanSquaredError"],
        })
        model.lossFunctions[1] = softmaxCrossEntropy

        return model
    })
}
mysticatea asked Mar 06 '23 11:03

1 Answer

You should only use tf.tidy() when directly manipulating tensors.

When you are building a model, you are not yet directly manipulating tensors, rather you are setting up the structure of how layers fit together. This means you don't need to wrap your model creation in a tf.tidy().
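As a rough illustration, here is the question's defineModel() rewritten without tf.tidy(). This is a minimal sketch assuming TF.js 1.x or later: the question's applyHiddenLayers/applyValueLayer/applyPolicyLayer helpers are not shown, so single dense layers with made-up sizes stand in for them, and the built-in "categoricalCrossentropy" loss stands in for the custom softmaxCrossEntropy. The function simply returns the model; nothing here needs to be tidied.

```typescript
import * as tf from "@tensorflow/tfjs"

// Sketch only: the layer sizes and losses below are assumptions, not the question's helpers.
function defineModel(learningRate: number, stateSize: number, actionSize: number) {
    // No tf.tidy(): these calls describe how layers connect; they don't compute tensor values.
    const input = tf.input({ shape: [stateSize] })
    const hidden = tf.layers.dense({ units: 32, activation: "relu" })
        .apply(input) as tf.SymbolicTensor
    const valueOutput = tf.layers.dense({ units: 1 })
        .apply(hidden) as tf.SymbolicTensor
    const policyOutput = tf.layers.dense({ units: actionSize, activation: "softmax" })
        .apply(hidden) as tf.SymbolicTensor

    const model = tf.model({ inputs: input, outputs: [valueOutput, policyOutput] })
    model.compile({
        optimizer: tf.train.sgd(learningRate),
        // Hypothetical stand-in for the question's custom softmaxCrossEntropy.
        loss: ["meanSquaredError", "categoricalCrossentropy"],
    })
    return model
}
```

The model's weights are owned by the model itself, which is why wrapping its construction in tf.tidy() is neither needed nor allowed.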

Only when you call "predict()" or "fit()" do you work with concrete Tensor values, and only then do you need to think about memory management.

When "predict()" is called, it returns a Tensor, which you must either dispose of yourself or create inside a "tidy()".
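Both options can be sketched as follows, assuming TF.js 1.x or later (for arraySync()). The tiny single-output model and the helper names predictDispose/predictArgmax are hypothetical, made up only for illustration.

```typescript
import * as tf from "@tensorflow/tfjs"

// Hypothetical toy model: 4 input features, 2 softmax outputs.
const model = tf.sequential({
    layers: [tf.layers.dense({ inputShape: [4], units: 2, activation: "softmax" })],
})

// Option 1: dispose the input and the predicted tensor explicitly.
function predictDispose(states: number[][]): number[][] {
    const input = tf.tensor2d(states)
    const output = model.predict(input) as tf.Tensor
    const values = output.arraySync() as number[][] // copy into plain JS numbers
    input.dispose()
    output.dispose()
    return values
}

// Option 2: let tf.tidy() dispose every intermediate tensor.
// Only the returned tensor survives, and the caller must dispose it later.
function predictArgmax(states: number[][]): tf.Tensor {
    return tf.tidy(() => {
        const output = model.predict(tf.tensor2d(states)) as tf.Tensor
        return tf.argMax(output, 1)
    })
}
```

tf.memory().numTensors is a convenient way to check that neither pattern leaves tensors behind.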

In the case of "fit()", we do all the memory management for you internally. The return value of "fit()" contains plain numbers rather than tensors, so you do not need to wrap it in a "tidy()".
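A minimal sketch of the "fit()" side, assuming TF.js 1.x or later, where fit() returns a Promise of a History object. The toy model and data (learning y = 2x) and the helper name trainOnce are made up for illustration. The per-epoch losses come back as plain numbers, but the xs/ys tensors you create yourself are still yours to dispose.

```typescript
import * as tf from "@tensorflow/tfjs"

// Hypothetical helper: trains a one-weight model for a few epochs.
async function trainOnce(epochs: number): Promise<number[]> {
    const model = tf.sequential({ layers: [tf.layers.dense({ inputShape: [1], units: 1 })] })
    model.compile({ optimizer: tf.train.sgd(0.05), loss: "meanSquaredError" })

    // Training data tensors are created by us, so we dispose them ourselves.
    const xs = tf.tensor2d([[1], [2], [3], [4]])
    const ys = tf.tensor2d([[2], [4], [6], [8]])

    // fit() tidies its own intermediate tensors; the history holds plain numbers.
    const history = await model.fit(xs, ys, { epochs })

    xs.dispose()
    ys.dispose()
    return history.history.loss as number[]
}
```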

Nikhil answered Mar 23 '23 12:03