I have a Keras Model which calculates two tensors, r1 and r2, of the same shape. I would like to have the model calculate (r1 - r2)**2.
I can take the sum of these tensors with keras.layers.add([r1, r2]). I can take a product with keras.layers.multiply([r1, r2]). If there were a subtract function, I'd write
r = keras.layers.subtract([r1, r2])
square_diff = keras.layers.multiply([r, r])
but there doesn't appear to be a keras.layers.subtract function.
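For readers on a newer version: later Keras 2 releases, and the `tf.keras` bundled with TensorFlow 2, do provide a `subtract` merge function. Like `add` and `multiply`, it takes a *list* of tensors. The sketch below is the two-line approach above written against `tf.keras`, assuming TensorFlow 2 is installed; the 128-wide inputs and the random test batch are illustrative only.

```python
import numpy as np
import tensorflow as tf

# Two inputs of the same shape; the batch dimension is left unspecified.
r1 = tf.keras.Input(shape=(128,))
r2 = tf.keras.Input(shape=(128,))

# subtract and multiply take a list of tensors, just like add.
r = tf.keras.layers.subtract([r1, r2])
square_diff = tf.keras.layers.multiply([r, r])

model = tf.keras.Model(inputs=[r1, r2], outputs=square_diff)

# Compare the model output with NumPy on a small random batch.
a = np.random.rand(4, 128).astype("float32")
b = np.random.rand(4, 128).astype("float32")
result = model.predict([a, b], verbose=0)
print(np.allclose(result, (a - b) ** 2))  # True
```

This uses only built-in layers, so no custom function needs to be serialized with the model.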
In lieu of that, I've been trying to multiply one of my inputs by a constant -1 tensor and then add, but I can't figure out how to create that -1 tensor. I've tried a number of variants on
negative_one = keras.backend.constant(np.full(r1.get_shape(), -1))
none of which work, presumably because the shape of r1 is (?, 128) (i.e. the first dimension is the batch size, and the second represents 128 hidden elements).
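The -1 tensor does not actually need the full (?, 128) shape: a scalar constant broadcasts against every dimension, including the unknown batch dimension. Here is a minimal sketch of the "multiply by -1, then add" idea, assuming TensorFlow 2's `tf.keras`; the multiplication is wrapped in a `Lambda` layer so it becomes part of the model graph.

```python
import numpy as np
import tensorflow as tf

r1 = tf.keras.Input(shape=(128,))  # shape (?, 128)
r2 = tf.keras.Input(shape=(128,))

# A scalar -1 (shape ()) broadcasts against (?, 128), so no full-shape constant is needed.
minus_r2 = tf.keras.layers.Lambda(lambda t: tf.constant(-1.0) * t)(r2)
diff = tf.keras.layers.add([r1, minus_r2])
square_diff = tf.keras.layers.multiply([diff, diff])

model = tf.keras.Model(inputs=[r1, r2], outputs=square_diff)

# Any batch size works at prediction time.
a = np.random.rand(3, 128).astype("float32")
b = np.random.rand(3, 128).astype("float32")
result = model.predict([a, b], verbose=0)
```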
What is the correct way in Keras to take the difference of two tensors?
As dhinckley mentioned, you should use a Lambda layer. I would suggest defining your custom function first; this makes the code a little clearer:
import keras.backend as K
from keras.layers import Lambda

def squared_differences(pair_of_tensors):
    # Unpack the two input tensors and return their element-wise squared difference
    x, y = pair_of_tensors
    return K.square(x - y)

square_diff = Lambda(squared_differences)([r1, r2])
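To see this Lambda in context, the sketch below wraps it in a complete two-input model. It assumes TensorFlow 2's `tf.keras` and swaps `K.square` for `tf.square`, since the standalone `keras.backend` math functions may not be available in newer Keras versions; the input data is illustrative.

```python
import numpy as np
import tensorflow as tf

def squared_differences(pair_of_tensors):
    # Unpack the two inputs and return their element-wise squared difference.
    x, y = pair_of_tensors
    return tf.square(x - y)

r1 = tf.keras.Input(shape=(128,))
r2 = tf.keras.Input(shape=(128,))
square_diff = tf.keras.layers.Lambda(squared_differences)([r1, r2])
model = tf.keras.Model(inputs=[r1, r2], outputs=square_diff)

a = np.arange(256, dtype="float32").reshape(2, 128)  # first row starts 0, 1, 2, 3
b = np.ones((2, 128), dtype="float32")
result = model.predict([a, b], verbose=0)
print(result[0, :4])  # [1. 0. 1. 4.]
```

Because the Lambda receives both tensors as a list, a single layer performs the whole (r1 - r2)**2 computation.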
I'm not qualified to say whether this is the correct way, but the following code will calculate (r1 - r2)**2 as you request. The key enablers here are the Keras functional API and a Lambda layer that inverts the sign of an input tensor.
import numpy as np
from keras.layers import Input, Lambda
from keras.models import Model
from keras.layers import add
r1 = Input(shape=(1,2,2))
r2 = Input(shape=(1,2,2))
# Lambda that negates r2, so that adding it to r1 computes r1 - r2
minus_r2 = Lambda(lambda x: -x)(r2)
subtracted = add([r1, minus_r2])
out = Lambda(lambda x: x**2)(subtracted)
model = Model([r1, r2], out)
a = np.arange(4).reshape([1,1,2,2])
b = np.ones(4).reshape([1,1,2,2])
print(model.predict([a,b]))
# [[[[ 1.  0.]
#    [ 1.  4.]]]]
print((a-b)**2)
# [[[[ 1.  0.]
#    [ 1.  4.]]]]