 

Non-linear Regression: Why isn't the model learning?

I just started learning Keras. I am trying to train a non-linear regression model in Keras, but the model doesn't seem to learn much.

#imports
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

#datapoints
X = np.arange(0.0, 5.0, 0.1, dtype='float32').reshape(-1,1)
y = 5 * np.power(X,2) + np.power(np.random.randn(50).reshape(-1,1),3)

#model
model = Sequential()
model.add(Dense(50, activation='relu', input_dim=1))
model.add(Dense(30, activation='relu', kernel_initializer='uniform'))
model.add(Dense(1, activation='linear'))

#training
sgd = SGD(learning_rate=0.1)
model.compile(loss='mse', optimizer=sgd, metrics=['accuracy'])
model.fit(X, y, epochs=1000)

#predictions
predictions = model.predict(X)

#plot
plt.scatter(X, y,edgecolors='g')
plt.plot(X, predictions,'r')
plt.legend([ 'Predictated Y' ,'Actual Y'])
plt.show()

[Figure: resulting plot of predicted vs. actual Y]

What am I doing wrong?

asked Feb 22 '18 18:02 by red5pider




1 Answer

Your learning rate is way too high.
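To illustrate why the step size matters so much on this data, here is a deliberately simplified stand-in (an assumption, not the network from the question): a two-parameter model y_hat = a * x**2 + b fitted to the same kind of data with full-batch gradient descent on the MSE. For a quadratic loss, gradient descent is only stable when the learning rate is below 2 / lambda_max, where lambda_max is the largest eigenvalue of the loss's Hessian. Because x**2 reaches about 24 on this input range, that bound comes out at roughly 0.008, so lr=0.1 overshoots while lr=0.001 settles down. The function names and the seeded generator are illustrative only.

```python
import numpy as np

def max_stable_lr():
    """2 / (largest Hessian eigenvalue) of the MSE loss for the toy model a*x**2 + b.

    The Hessian of mean((a*f + b - y)**2) only depends on the feature f = x**2,
    not on y, so this bound is deterministic.
    """
    f = np.arange(0.0, 5.0, 0.1) ** 2
    hessian = 2 * np.array([[np.mean(f * f), np.mean(f)],
                            [np.mean(f), 1.0]])
    return 2.0 / np.linalg.eigvalsh(hessian).max()

def train_quadratic_fit(lr, steps=2000, seed=0):
    """Full-batch gradient descent on the toy model; returns the final MSE.

    Simplification: this is NOT the question's neural network, just a
    two-parameter model on similar data, used to show how the step size
    interacts with the scale of the gradients.
    """
    rng = np.random.default_rng(seed)  # hypothetical seed, for reproducibility
    x = np.arange(0.0, 5.0, 0.1)
    y = 5 * x**2 + rng.standard_normal(x.shape) ** 3  # same recipe as the question
    f = x**2                                          # feature used by the toy model
    a, b = 0.0, 0.0
    # A too-large step makes a and b blow up; silence the overflow warnings
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(steps):
            err = a * f + b - y
            # Gradients of mean((a*f + b - y)**2) with respect to a and b
            grad_a = 2 * np.mean(err * f)
            grad_b = 2 * np.mean(err)
            a -= lr * grad_a
            b -= lr * grad_b
        return float(np.mean((a * f + b - y) ** 2))

print(max_stable_lr())             # roughly 0.008
print(train_quadratic_fit(0.1))    # diverges: non-finite loss
print(train_quadratic_fit(0.001))  # ends near the noise level
```

The neural network is not a quadratic model, so its exact threshold differs, but the same overshooting behaviour is a plausible explanation of the flat predictions in the question.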

Also, unrelated to your issue: you should not ask for metrics=['accuracy'], since this is a regression setting and accuracy is meaningless.
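Accuracy counts how often a prediction matches its label; with continuous targets, such matches essentially never happen, so the number says nothing about how good the fit is. The sketch below uses a plain exact-match rate as a stand-in for an accuracy-style metric (an assumption: Keras's own accuracy implementations differ in detail) and contrasts it with mean absolute error, a metric that is meaningful for regression (Keras accepts it as metrics=['mae']). The "very good" predictions are made up for illustration.

```python
import numpy as np

def exact_match_rate(y_true, y_pred):
    """Fraction of predictions exactly equal to the target (accuracy-style)."""
    return float(np.mean(y_true == y_pred))

def mean_absolute_error(y_true, y_pred):
    """Average absolute deviation; meaningful for continuous targets."""
    return float(np.mean(np.abs(y_true - y_pred)))

rng = np.random.default_rng(0)
x = np.arange(0.0, 5.0, 0.1)
y_true = 5 * x**2
# Hypothetical near-perfect predictions: each is off by between 0.001 and 0.01
y_pred = y_true + rng.uniform(0.001, 0.01, size=x.shape)

print(exact_match_rate(y_true, y_pred))     # 0.0, despite the excellent fit
print(mean_absolute_error(y_true, y_pred))  # between 0.001 and 0.01
```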

So, with these changes:

sgd = SGD(learning_rate=0.001)
model.compile(loss='mse', optimizer=sgd)  # no 'accuracy' metric for regression

plt.legend(['Predicted Y', 'Actual Y'])  # typo in legend :)

Here are some outputs (results will differ between runs because of the random noise in your y):
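If you want runs to be comparable, one option is to fix the noise in y with a seed. A minimal sketch, assuming NumPy's Generator API (make_data is a hypothetical helper that rebuilds the question's dataset):

```python
import numpy as np

def make_data(seed=None):
    """Rebuild the question's data: y = 5 x^2 + (standard normal noise)^3.

    With a fixed seed the noise, and therefore y, is identical on every call;
    with seed=None it changes each time, as in the original code.
    """
    rng = np.random.default_rng(seed)
    X = np.arange(0.0, 5.0, 0.1, dtype="float32").reshape(-1, 1)
    y = 5 * np.power(X, 2) + np.power(rng.standard_normal((X.shape[0], 1)), 3)
    return X, y

X1, y1 = make_data(seed=42)
X2, y2 = make_data(seed=42)
print(np.array_equal(y1, y2))  # True: same seed, same noise
```

This only pins down the data. Keras also initializes the weights randomly, so the fitted curves can still vary between runs unless its own seed is fixed as well (recent versions provide keras.utils.set_random_seed for that).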

[Figures: two example plots of predicted vs. actual Y from separate runs]

answered Sep 18 '22 18:09 by desertnaut