Multidimensional Gaussian Process Regression (GPR)

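This script demonstrates Gaussian Process Regression (GPR) on two-dimensional inputs. The kernel is a Radial Basis Function (RBF) kernel with a separate correlation length for each input dimension, plus a white-noise (WhiteGaussian) kernel. The model is trained on noisy samples of f(x1, x2) = x1^2 + x2^2 - 5, its kernel hyperparameters are tuned with the ADAM optimiser by maximising the log marginal likelihood, and it then predicts the outputs on a held-out test set.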
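To make the model concrete, here is a minimal plain-PyTorch sketch of what GPR with an RBF plus white-noise kernel computes: an RBF covariance with one correlation length per input dimension, a noise variance added to the diagonal of the training covariance, and the posterior mean K_* K^-1 y under a zero-mean prior. It does not use DLL. The function names `rbf_kernel` and `gp_posterior_mean` are hypothetical, and the assumption that the RBF amplitude enters as sigma^2 is ours; DLL's `RBF` may be parameterised differently. The hyperparameter values are rounded from the epoch 901 line of the training log and serve only as plausible settings for a smaller 15 x 15 version of the same data.

```python
import torch


def rbf_kernel(a, b, sigma, corr_len):
    # Anisotropic RBF (assumed form): sigma^2 * exp(-0.5 * sum_d ((a_d - b_d) / l_d)^2)
    diff = (a[:, None, :] - b[None, :, :]) / corr_len
    return sigma ** 2 * torch.exp(-0.5 * (diff ** 2).sum(dim=-1))


def gp_posterior_mean(x_train, y_train, x_test, sigma, corr_len, noise_sigma):
    # The white-noise kernel only adds noise_sigma^2 to the diagonal of the training covariance
    K = rbf_kernel(x_train, x_train, sigma, corr_len)
    K = K + noise_sigma ** 2 * torch.eye(len(x_train), dtype=x_train.dtype)
    K_star = rbf_kernel(x_test, x_train, sigma, corr_len)
    # Solve K @ alpha = y with a Cholesky factorisation instead of inverting K
    L = torch.linalg.cholesky(K)
    alpha = torch.cholesky_solve(y_train.unsqueeze(-1), L)
    # Posterior mean under a zero-mean prior: K_star @ K^-1 @ y
    return (K_star @ alpha).squeeze(-1)


torch.manual_seed(0)
n = 15
g1, g2 = torch.meshgrid(
    torch.linspace(0, 1, n, dtype=torch.float64),
    torch.linspace(-1, 1, n, dtype=torch.float64),
    indexing="xy",
)
x = torch.stack((g1.flatten(), g2.flatten()), dim=1)
y = x[:, 0] ** 2 + x[:, 1] ** 2 + 0.1 * torch.randn(len(x), dtype=torch.float64) - 5

# Random 80 / 20 train / test split of the 225 grid points
perm = torch.randperm(len(x))
train_idx, test_idx = perm[:180], perm[180:]

# Standardise y with training statistics, since the GP prior mean is zero
y_mean, y_std = y[train_idx].mean(), y[train_idx].std()
y_train = (y[train_idx] - y_mean) / y_std

mean = gp_posterior_mean(
    x[train_idx], y_train, x[test_idx],
    sigma=5.4, corr_len=torch.tensor([2.0, 2.4], dtype=torch.float64), noise_sigma=0.22,
)
pred = mean * y_std + y_mean  # back to the original scale of y
rmse = torch.sqrt(((pred - y[test_idx]) ** 2).mean())
print(f"test RMSE: {rmse.item():.3f}")
```

Solving through the Cholesky factor avoids forming K^-1 explicitly, which is both cheaper and numerically more stable; float64 is used here to keep the factorisation well behaved.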

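[Figure: MultidimensionalGPR. Left panel: predicted (blue) and true (red) test values as a 3D scatter plot. Right panel: log marginal likelihood versus epoch.]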
/home/runner/work/deep-learning-library/deep-learning-library/DLL/MachineLearning/SupervisedLearning/GaussianProcesses/_GaussianProcessRegressor.py:68: UserWarning: y should have zero mean for optimal results. Use DLL.Data.Preprocessing.StandardScaler to scale the data.
  warn("y should have zero mean for optimal results. Use DLL.Data.Preprocessing.StandardScaler to scale the data.")
Epoch: 1 - Log marginal likelihood: -960.4268798828125 - Parameters: {'rbf_sigma_3': 1.01, 'rbf_corr_len_3_1': 9.99, 'rbf_corr_len_3_2': 9.99, 'white_gaussian_sigma_1': 0.99}
Epoch: 101 - Log marginal likelihood: -854.9034423828125 - Parameters: {'rbf_sigma_3': 1.92, 'rbf_corr_len_3_1': 8.992, 'rbf_corr_len_3_2': 8.604, 'white_gaussian_sigma_1': 0.735}
Epoch: 201 - Log marginal likelihood: -746.0242309570312 - Parameters: {'rbf_sigma_3': 2.743, 'rbf_corr_len_3_1': 8.264, 'rbf_corr_len_3_2': 6.528, 'white_gaussian_sigma_1': 0.564}
Epoch: 301 - Log marginal likelihood: -220.94668579101562 - Parameters: {'rbf_sigma_3': 4.146, 'rbf_corr_len_3_1': 7.378, 'rbf_corr_len_3_2': 4.399, 'white_gaussian_sigma_1': 0.271}
Epoch: 401 - Log marginal likelihood: -109.2108154296875 - Parameters: {'rbf_sigma_3': 4.672, 'rbf_corr_len_3_1': 5.554, 'rbf_corr_len_3_2': 3.647, 'white_gaussian_sigma_1': 0.243}
Epoch: 501 - Log marginal likelihood: -14.64208984375 - Parameters: {'rbf_sigma_3': 5.051, 'rbf_corr_len_3_1': 3.599, 'rbf_corr_len_3_2': 3.223, 'white_gaussian_sigma_1': 0.225}
Epoch: 601 - Log marginal likelihood: 10.4254150390625 - Parameters: {'rbf_sigma_3': 5.231, 'rbf_corr_len_3_1': 2.688, 'rbf_corr_len_3_2': 2.932, 'white_gaussian_sigma_1': 0.224}
Epoch: 701 - Log marginal likelihood: 17.1297607421875 - Parameters: {'rbf_sigma_3': 5.327, 'rbf_corr_len_3_1': 2.279, 'rbf_corr_len_3_2': 2.715, 'white_gaussian_sigma_1': 0.224}
Epoch: 801 - Log marginal likelihood: 19.479248046875 - Parameters: {'rbf_sigma_3': 5.389, 'rbf_corr_len_3_1': 2.085, 'rbf_corr_len_3_2': 2.55, 'white_gaussian_sigma_1': 0.225}
Epoch: 901 - Log marginal likelihood: 20.521728515625 - Parameters: {'rbf_sigma_3': 5.435, 'rbf_corr_len_3_1': 1.997, 'rbf_corr_len_3_2': 2.427, 'white_gaussian_sigma_1': 0.224}
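The log above tracks the log marginal likelihood, log p(y | X) = -1/2 y^T K^-1 y - 1/2 log|K| - (n/2) log(2 pi), where K is the training covariance including the noise term. The sketch below reproduces this kind of hyperparameter tuning in plain PyTorch, with `torch.optim.Adam` standing in for DLL's `ADAM`. It assumes that the logged `rbf_sigma`, `rbf_corr_len` and `white_gaussian_sigma` values correspond to the amplitude, per-dimension correlation lengths and noise standard deviation used here; the function names are hypothetical. Starting values follow what the first log line suggests (correlation lengths 10, amplitude and noise 1), with a learning rate of 0.01.

```python
import math

import torch


def rbf_kernel(a, b, sigma, corr_len):
    # Anisotropic RBF (assumed form): sigma^2 * exp(-0.5 * sum_d ((a_d - b_d) / l_d)^2)
    diff = (a[:, None, :] - b[None, :, :]) / corr_len
    return sigma ** 2 * torch.exp(-0.5 * (diff ** 2).sum(dim=-1))


def log_marginal_likelihood(x, y, sigma, corr_len, noise_sigma):
    # log p(y | X) = -0.5 y^T K^-1 y - 0.5 log|K| - (n / 2) log(2 pi)
    n = len(x)
    K = rbf_kernel(x, x, sigma, corr_len) + noise_sigma ** 2 * torch.eye(n, dtype=x.dtype)
    L = torch.linalg.cholesky(K)
    alpha = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)
    # With K = L L^T, 0.5 * log|K| equals the sum of log(diag(L))
    return -0.5 * (y @ alpha) - torch.log(torch.diagonal(L)).sum() - 0.5 * n * math.log(2 * math.pi)


torch.manual_seed(0)
n = 15
g1, g2 = torch.meshgrid(
    torch.linspace(0, 1, n, dtype=torch.float64),
    torch.linspace(-1, 1, n, dtype=torch.float64),
    indexing="xy",
)
x = torch.stack((g1.flatten(), g2.flatten()), dim=1)
y = x[:, 0] ** 2 + x[:, 1] ** 2 + 0.1 * torch.randn(len(x), dtype=torch.float64) - 5
y = (y - y.mean()) / y.std()  # zero-mean, unit-variance targets

# Optimise the logarithms so that sigma, the correlation lengths and the noise stay positive
log_sigma = torch.zeros((), dtype=torch.float64, requires_grad=True)  # sigma = 1
log_corr_len = torch.full((2,), math.log(10.0), dtype=torch.float64, requires_grad=True)  # l = (10, 10)
log_noise = torch.zeros((), dtype=torch.float64, requires_grad=True)  # noise sigma = 1
optimiser = torch.optim.Adam([log_sigma, log_corr_len, log_noise], lr=0.01)

history = []
for epoch in range(1000):
    optimiser.zero_grad()
    lml = log_marginal_likelihood(x, y, log_sigma.exp(), log_corr_len.exp(), log_noise.exp())
    (-lml).backward()  # Adam minimises, so minimise the negative log marginal likelihood
    optimiser.step()
    history.append(lml.item())
    if epoch % 100 == 0:
        print(f"Epoch: {epoch + 1} - Log marginal likelihood: {lml.item():.3f}")
```

The first log line shows the parameters moving by about 0.01 in one step, which suggests DLL updates them directly; optimising their logarithms instead is a choice made for this sketch so that no hyperparameter can become negative, not a claim about DLL's implementation.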

import torch
import matplotlib.pyplot as plt

from DLL.MachineLearning.SupervisedLearning.GaussianProcesses import GaussianProcessRegressor
from DLL.MachineLearning.SupervisedLearning.Kernels import RBF, WhiteGaussian
from DLL.Data.Preprocessing import data_split, StandardScaler
from DLL.DeepLearning.Optimisers import ADAM


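# Inputs: a 30 x 30 grid over [0, 1] x [-1, 1], flattened to 900 points of shape (900, 2)
n = 30
X, Y = torch.meshgrid(torch.linspace(0, 1, n, dtype=torch.float32), torch.linspace(-1, 1, n, dtype=torch.float32), indexing="xy")
x = torch.stack((X.flatten(), Y.flatten()), dim=1)
# Targets: x1^2 + x2^2 - 5 plus Gaussian noise with standard deviation 0.1
y = X.flatten() ** 2 + Y.flatten() ** 2 + 0.1 * torch.randn(size=Y.flatten().size()) - 5
# 80 / 20 train / test split, no validation set
x_train, y_train, _, _, x_test, y_test = data_split(x, y, train_split=0.8, validation_split=0.0)
# Standardise the targets using statistics from the training set only
transformer = StandardScaler()
y_train = transformer.fit_transform(y_train)
y_test = transformer.transform(y_test)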


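# RBF kernel with one correlation length per input dimension (both initialised to 10) plus a white-noise kernel
model = GaussianProcessRegressor(RBF(correlation_length=torch.tensor([10.0, 10.0])) + WhiteGaussian())
model.fit(x_train, y_train)
# Tune the kernel hyperparameters with ADAM by maximising the log marginal likelihood (logged every 100 epochs)
optimizer = ADAM(0.01)
lml = model.train_kernel(epochs=1000, optimiser=optimizer, callback_frequency=100, verbose=True)["log marginal likelihood"]
# Posterior mean on the test inputs, mapped back to the original scale of y
mean, _ = model.predict(x_test)
z = transformer.inverse_transform(mean)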

fig = plt.figure(figsize=(14, 6))
# Left panel: predicted vs. true test values in 3D, both on the original scale of y
ax = fig.add_subplot(121, projection="3d")
ax.scatter(x_test[:, 0], x_test[:, 1], z, color="blue", label="prediction")
ax.scatter(x_test[:, 0], x_test[:, 1], transformer.inverse_transform(y_test), color="red", label="true value")
ax.legend()

# Right panel: log marginal likelihood during kernel training
ax = fig.add_subplot(122)
ax.plot(lml)
ax.grid()
ax.set_xlabel("Epoch")
ax.set_ylabel("Log marginal likelihood")

plt.show()

Total running time of the script: (0 minutes 36.567 seconds)

Gallery generated by Sphinx-Gallery