Unable to interpret a line of Python code that creates an LSTM cell using TensorFlow

I am trying to figure out how a piece of fully functional Python code works. One block creates an LSTM cell using TensorFlow. I don't understand the line marked by the comment below.

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

def get_lstm_weights(n_hidden, forget_bias, dim, scope="rnn_cell"):
    # Create LSTM cell
    cell = tf.contrib.rnn.LSTMCell(num_units=n_hidden, reuse=None, forget_bias=forget_bias)
    #--------------------------------------
    # I DO NOT UNDERSTAND THE NEXT LINE
    cell(tf.zeros([1, dim + 1]), (tf.zeros([1, n_hidden]), tf.zeros([1, n_hidden])), scope=scope)
    # -------------------------------------
    cell = tf.contrib.rnn.LSTMCell(num_units=n_hidden, reuse=True, forget_bias=forget_bias)

    # Create output weights
    weights = {
        'W_1': tf.Variable(tf.truncated_normal([n_hidden, dim], stddev=0.05)),
        'b_1': tf.Variable(0.1*tf.ones([dim])),
    }

    return cell, weights

asked Jan 22 '26 12:01 by mo adib


1 Answer

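Note that tf.contrib.rnn.LSTMCell is a callable class: its instances can be called like functions, because the class (through its base classes) defines a __call__ method.

The line you are struggling with does exactly that. It calls cell with a zero input of shape [1, dim + 1] and a zero initial state (c, h), where c and h each have shape [1, n_hidden]. The return value is discarded; the point of the call is its side effect, which is to create the cell's weight and bias variables under scope. The second LSTMCell, created with reuse=True, can then share those variables when it is later called in the same scope, instead of creating new ones.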
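
A minimal pure-Python illustration of a callable class follows. Scaler is a hypothetical toy class, not part of TensorFlow, but calling one of its instances triggers __call__ in the same way that calling cell does in the question.

```python
class Scaler:
    """Hypothetical toy callable class: its instances behave like functions."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        # Runs whenever an instance is called, e.g. double(21).
        return self.factor * x


double = Scaler(2)
print(double(21))           # 42: Python looks up Scaler.__call__
print(double.__call__(21))  # 42: the explicit equivalent
print(callable(double))     # True
```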

If you want to see exactly what this does, you can inspect the __call__ method that LSTMCell inherits from RNNCell in TensorFlow 1.x, together with the build and call methods defined on tf.contrib.rnn.LSTMCell itself.

answered Jan 24 '26 02:01 by Stewart_R