Regress the max function over a neural network

I'm teaching myself neural networks. There is a function that I can't get my neural network to learn: f(x) = max(x_1, x_2). It seems like a very simple function with 2 inputs and 1 output, yet a 3-layer neural network trained on a thousand samples for 2000 epochs gets it completely wrong. I'm using deeplearning4j.

Is there any reason why the max function would be very hard to learn for a neural network or am I just tuning it wrong?

Atol asked Sep 12 '25

1 Answer

Just wanted to point out: if you use relu instead of tanh, there is actually an exact solution, and I would guess that if you shrank the network down to exactly that size (1 hidden layer with 3 nodes), you would always end up with these weights, modulo permutations of the nodes and scaling of the weights (first layer scaled by gamma, second by 1/gamma):

max(a,b) = ((1, 1, -1)) * relu( ((1,-1), (0,1), (0,-1)) * ((a,b)) )

where * is the matrix multiplication.

This equation translates the following human-readable version into NN-language:

max(a,b) = relu(a-b) + b = relu(a-b) + relu(b) - relu(-b)

I have not actually tested it; my point is that it should theoretically be very easy for a network to learn this function.
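
As a sanity check of the construction itself (independent of any training), here is a minimal sketch in plain NumPy; the question uses deeplearning4j, but the check is framework-independent, and the helper names relu and net_max are mine. It hard-codes the weights from the equation above and compares the output against max on random inputs:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

# First layer (3 hidden units x 2 inputs): computes a-b, b and -b.
W1 = np.array([[1.0, -1.0],
               [0.0,  1.0],
               [0.0, -1.0]])
# Second layer (1 output x 3 hidden units): relu(a-b) + relu(b) - relu(-b).
W2 = np.array([1.0, 1.0, -1.0])

def net_max(a, b):
    return W2 @ relu(W1 @ np.array([a, b]))

for a, b in np.random.uniform(-5.0, 5.0, size=(10, 2)):
    assert np.isclose(net_max(a, b), max(a, b))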

EDIT: I just tested this, and the result was as I expected:

[[-1.0714666e+00 -7.9943770e-01  9.0549403e-01]
 [ 1.0714666e+00 -7.7552663e-08  2.6146751e-08]]

and

[[ 0.93330014]
 [-1.250879  ]
 [ 1.1043695 ]]

which are the first-layer and second-layer weights, respectively. Transposing the second and multiplying it with the first set of weights, one ends up with a normalized version that can be compared to my theoretical result very easily:

[[-9.9999988e-01  9.9999988e-01  1.0000000e+00]
 [ 9.9999988e-01  9.7009000e-08  2.8875675e-08]]
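
To reproduce that normalization step, a short NumPy sketch (my reconstruction, not the script used for the training run above) that scales each hidden unit's input weights by its output weight:

import numpy as np

# Learned weights quoted above: first layer is (2 inputs x 3 hidden units),
# second layer is (3 hidden units x 1 output).
W1 = np.array([[-1.0714666e+00, -7.9943770e-01,  9.0549403e-01],
               [ 1.0714666e+00, -7.7552663e-08,  2.6146751e-08]])
W2 = np.array([[ 0.93330014],
               [-1.250879  ],
               [ 1.1043695 ]])

# Broadcast the transposed second layer over the first layer's columns;
# up to rounding this reproduces the normalized matrix shown above.
print(W1 * W2.T)

Up to swapping the roles of a and b (max is symmetric) and the permutations/scalings mentioned above, these columns match the exact solution.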
David S. answered Sep 14 '25