
Pytorch to Keras code equivalence

Tags: keras, pytorch

Given the PyTorch code below, what would the Keras equivalent be?

import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, state_size, action_size):
        super(Network, self).__init__()

        # Inputs = 5, Outputs = 3, Hidden = 30
        self.fc1 = nn.Linear(5, 30)
        self.fc2 = nn.Linear(30, 3)

    def forward(self, state):
        # ReLU on the hidden layer, no activation on the output layer
        x = F.relu(self.fc1(state))
        outputs = self.fc2(x)
        return outputs

Is it this?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='relu'))
model.add(Dense(units=30, activation='relu'))
model.add(Dense(units=3, activation='linear'))

or is it this?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='linear'))
model.add(Dense(units=30, activation='relu'))
model.add(Dense(units=3, activation='linear'))

or is it this?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='relu'))
model.add(Dense(units=30, activation='linear'))
model.add(Dense(units=3, activation='linear'))

Thanks

Milind Dalvi asked Oct 21 '17 18:10



1 Answer

As far as I know, none of them is correct: each one stacks three Dense layers, whereas the PyTorch network has only two (fc1 and fc2). A correct Keras equivalent would be:

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(30, input_shape=(5,), activation='relu'))
model.add(Dense(3))

model.add(Dense(30, input_shape=(5,), activation='relu'))

This layer takes input arrays of shape (*, 5) and outputs arrays of shape (*, 30). You can also use input_dim instead of input_shape: input_dim=5 is equivalent to input_shape=(5,).

model.add(Dense(3))

After the first layer, you don't need to specify the size of the input anymore. Moreover, if you don't specify anything for activation, no activation will be applied (equivalent to linear activation).
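One quick way to sanity-check a translation like this is to compare parameter counts: a fully connected layer with n inputs and m outputs has n*m weights plus m biases. The sketch below is plain Python rather than either framework, and the helper names dense_params and count_params are made up for illustration.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer (hypothetical helper)."""
    return n_in * n_out + n_out


def count_params(layer_sizes):
    """Total parameters for stacked dense layers, given sizes [in, h1, ..., out]."""
    return sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))


# PyTorch network from the question: Linear(5, 30) -> Linear(30, 3)
pytorch_total = count_params([5, 30, 3])

# Each candidate in the question stacks three Dense layers: 5 -> 30 -> 30 -> 3
candidate_total = count_params([5, 30, 30, 3])

print(pytorch_total, candidate_total)
```

The two-layer network comes to 273 parameters, while each three-layer candidate comes to 1203, so none of the candidates can be the same model regardless of which activations they use.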


An alternative would be:

from keras.models import Sequential
from keras.layers import Activation, Dense

model = Sequential()
model.add(Dense(30, input_dim=5))
model.add(Activation('relu'))
model.add(Dense(3))
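The two forms are interchangeable because passing activation='relu' to a layer amounts to applying ReLU to that layer's linear output. Here is a minimal sketch of that idea in plain Python, not Keras itself: the dense and relu functions are illustrative stand-ins, and the weights are random placeholders stored one row per output unit for simplicity.

```python
import random


def relu(values):
    """Element-wise max(0, v)."""
    return [max(0.0, v) for v in values]


def dense(inputs, weights, biases, activation=None):
    """One fully connected layer; weights holds one row per output unit."""
    out = [sum(w * x for w, x in zip(row, inputs)) + b
           for row, b in zip(weights, biases)]
    return activation(out) if activation is not None else out


random.seed(0)
x = [random.uniform(-1, 1) for _ in range(5)]                       # one input sample
W = [[random.uniform(-1, 1) for _ in range(5)] for _ in range(30)]  # 30 hidden units
b = [random.uniform(-1, 1) for _ in range(30)]

fused = dense(x, W, b, activation=relu)  # analogous to Dense(30, activation='relu')
separate = relu(dense(x, W, b))          # analogous to Dense(30) then Activation('relu')

print(fused == separate)
```

Both paths run the same arithmetic in the same order, so the comparison prints True.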

Hopefully this makes sense!

Wasi Ahmad answered Sep 28 '22 19:09
