Gaussian process with 2D feature array as input - scikit-learn

I need to implement GPR (Gaussian process regression) in Python using the scikit-learn library.

My input X has two features, e.g. X = [x1, x2], and the output is one-dimensional, y = [y1].

I want to use two kernels, RBF and Matern, such that RBF uses the 'x1' feature while Matern uses the 'x2' feature. I tried the following:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern as M, RBF as R

X = np.matrix([[1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.], [1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.],[1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.]]).T

y=[0.84147098,  0.42336002, -4.79462137, -1.67649299,  4.59890619,  7.91486597, 0.84147098,  0.42336002, -4.79462137, -1.67649299,  4.59890619,  7.91486597, 0.84147098,  0.42336002, -4.79462137, -1.67649299,  4.59890619,  7.91486597]

kernel = R(X[0]) * M(X[1])
gp = GaussianProcessRegressor(kernel=kernel)

gp.fit(X, y)

But this gives the following error:

ValueError: Found input variables with inconsistent numbers of samples: [2, 18]
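The two numbers in the error are the sample counts scikit-learn inferred for X and y, because it treats each row of X as one sample. A quick shape check on the question's own data (six [x1, x2] pairs repeated three times) shows where the 2 comes from:

```python
import numpy as np

# The question's 18 training points: six [x1, x2] pairs, repeated three times
pairs = [[1., 2], [3., 4], [5., 1], [6., 5], [4, 7.], [9, 8.]] * 3

X_question = np.matrix(pairs).T  # as in the question: transposed to 2 x 18
X_fixed = np.array(pairs)        # one row per sample, one column per feature

# scikit-learn counts rows as samples: 2 "samples" here vs. 18 targets in y
print(X_question.shape)  # (2, 18)
print(X_fixed.shape)     # (18, 2)
```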

I tried several approaches but could not find a solution. I would really appreciate it if someone could help.

QuantumGirl asked Jun 06 '18 07:06


1 Answer

Your X should not be a (transposed) matrix but a plain 2D array with one row per sample and one column per feature, i.e. of shape (18, 2):

X = np.array([[1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.], [1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.],[1.,2], [3.,4], [5.,1], [6.,5],[4, 7.],[ 9,8.]])

# rest of your code as is

gp.fit(X, y)

# result:

GaussianProcessRegressor(alpha=1e-10, copy_X_train=True,
             kernel=RBF(length_scale=[1, 2]) * Matern(length_scale=[3, 4], nu=1.5),
             n_restarts_optimizer=0, normalize_y=False,
             optimizer='fmin_l_bfgs_b', random_state=None)
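For reference, here is a self-contained version of this fix, using the question's data (six points repeated three times) and the question's kernel unchanged. It also shows what R(X[0]) and M(X[1]) actually do once X is an array: the first and second rows of X simply become per-feature length scales, which is where the [1, 2] and [3, 4] in the result come from.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern as M, RBF as R

# Question's data: six points, each repeated three times
X = np.array([[1., 2], [3., 4], [5., 1], [6., 5], [4, 7.], [9, 8.]] * 3)
y = [0.84147098, 0.42336002, -4.79462137, -1.67649299, 4.59890619, 7.91486597] * 3

kernel = R(X[0]) * M(X[1])  # the question's kernel, unchanged
gp = GaussianProcessRegressor(kernel=kernel)
gp.fit(X, y)  # no longer raises: X has 18 rows, matching len(y)

# The "feature" arguments actually became per-feature length scales
print(gp.kernel.k1.length_scale)  # [1. 2.]  (row X[0])
print(gp.kernel.k2.length_scale)  # [3. 4.]  (row X[1])
print(gp.kernel_)                 # same structure, with optimized hyperparameters
```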

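That said, your kernel definition will not do what you want: R(X[0]) and M(X[1]) just use the first and second rows of X as the length scales of each kernel (hence length_scale=[1, 2] and [3, 4] in the result above), rather than restricting each kernel to one feature. Most probably you have to change it to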

kernel = R([1,0]) * M([0,1]) 

but I am not quite sure about that; be sure to check the documentation for the correct arguments of the RBF and Matern kernels.
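A caveat on the suggestion above: RBF and Matern divide the inputs by the length scale, so a length scale of exactly 0 leads to a division by zero rather than switching a feature off. As far as I know, scikit-learn's built-in kernels always act on every input column and have no option to select a subset of features. One way to get "RBF on x1, Matern on x2" is a small wrapper kernel that slices the columns before delegating to the wrapped kernel. This is only a sketch under assumptions: ColumnKernel is a hypothetical helper, not a scikit-learn class, and it relies on the Kernel base-class interface (__call__, diag, is_stationary, theta, bounds, hyperparameters) in the same way scikit-learn's own composite kernels delegate to the kernels they wrap.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Hyperparameter, Kernel, Matern, RBF


class ColumnKernel(Kernel):
    """Hypothetical helper (not part of scikit-learn): evaluates `kernel`
    only on the input columns listed in `columns`."""

    def __init__(self, kernel, columns):
        self.kernel = kernel
        self.columns = columns  # a list such as [0], so the selection stays 2D

    def _select(self, X):
        return None if X is None else np.asarray(X)[:, self.columns]

    # Expose the wrapped kernel's hyperparameters so the optimizer can tune them
    @property
    def hyperparameters(self):
        return [Hyperparameter("kernel__" + h.name, h.value_type, h.bounds,
                               h.n_elements, h.fixed)
                for h in self.kernel.hyperparameters]

    @property
    def theta(self):
        return self.kernel.theta

    @theta.setter
    def theta(self, theta):
        self.kernel.theta = theta

    @property
    def bounds(self):
        return self.kernel.bounds

    def get_params(self, deep=True):
        params = {"kernel": self.kernel, "columns": self.columns}
        if deep:
            params.update(("kernel__" + k, v)
                          for k, v in self.kernel.get_params().items())
        return params

    def __call__(self, X, Y=None, eval_gradient=False):
        # Column selection does not change the gradients w.r.t. the wrapped
        # kernel's hyperparameters, so they are passed through as-is
        return self.kernel(self._select(X), self._select(Y),
                           eval_gradient=eval_gradient)

    def diag(self, X):
        return self.kernel.diag(self._select(X))

    def is_stationary(self):
        return self.kernel.is_stationary()

    def __repr__(self):
        return f"{self.kernel!r}[columns={self.columns}]"


X = np.array([[1., 2], [3., 4], [5., 1], [6., 5], [4, 7.], [9, 8.]] * 3)
y = [0.84147098, 0.42336002, -4.79462137, -1.67649299, 4.59890619, 7.91486597] * 3

# RBF sees only x1 (column 0); Matern sees only x2 (column 1)
kernel = (ColumnKernel(RBF(length_scale=1.0), [0])
          * ColumnKernel(Matern(length_scale=1.0, nu=1.5), [1]))
gp = GaussianProcessRegressor(kernel=kernel).fit(X, y)
print(gp.kernel_)
```

Because each wrapped kernel keeps its own hyperparameters, the fitted gp.kernel_ carries one optimized length scale for the RBF part (on x1) and one for the Matern part (on x2), which is the split the question asks for.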

desertnaut answered Oct 22 '22 05:10