Locally Weighted Regression on a Sine Function

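This script demonstrates Locally Weighted Regression (LWR) on a noisy sine curve. For each test point, a separate linear model is fit in which a Gaussian kernel weights every training sample by its distance from that test point, so nearby samples dominate the local fit.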

Figure: test-set predictions (red) and true values (blue) for 16 kernel bandwidths. Subplot titles: Tau = 0.203 - R2 = 0.906, Tau = 0.251 - R2 = 0.904, Tau = 0.31 - R2 = 0.901, Tau = 0.384 - R2 = 0.894, Tau = 0.475 - R2 = 0.881, Tau = 0.587 - R2 = 0.855, Tau = 0.727 - R2 = 0.807, Tau = 0.899 - R2 = 0.725, Tau = 1.112 - R2 = 0.61, Tau = 1.376 - R2 = 0.495, Tau = 1.702 - R2 = 0.408, Tau = 2.106 - R2 = 0.348, Tau = 2.605 - R2 = 0.301, Tau = 3.223 - R2 = 0.263, Tau = 3.988 - R2 = 0.232, Tau = 4.933 - R2 = 0.207
import math

import torch
import matplotlib.pyplot as plt

from DLL.MachineLearning.SupervisedLearning.LinearModels import LinearRegression
from DLL.Data.Metrics import r2_score
from DLL.Data.Preprocessing import data_split


# Seed the global torch RNG so the noise below is reproducible between runs
# (and presumably the train/test split too, if data_split draws from torch's
# RNG -- not visible here, TODO confirm).
torch.manual_seed(0)

# 1000 evenly spaced inputs on [0, 2*pi], unsqueezed to shape (1000, 1) so they
# form a single-feature design matrix; targets are sin(x) plus Gaussian noise
# with standard deviation 0.2.
X = torch.linspace(0, 2 * torch.pi, 1000).unsqueeze(1)
y = torch.sin(X).squeeze() + 0.2 * torch.randn_like(X.squeeze())
X_train, y_train, X_test, y_test, _, _ = data_split(X, y)

def get_weight(train, test, tau):
    """Compute Gaussian-kernel weights of the training samples around one query point.

    Args:
        train: Training inputs; rows are samples (distances are summed over dim 1).
        test: A single query point, broadcast against every row of ``train``.
        tau: Kernel bandwidth; larger values spread the weight over more samples.

    Returns:
        One weight per training row: ``exp(-||train_i - test||^2 / (2 * tau^2))``.
    """
    squared_distance = ((train - test) ** 2).sum(dim=1)
    denominator = 2.0 * tau * tau
    return torch.exp(-squared_distance / denominator)

def get_pred(tau):
    """Predict every test point with its own locally weighted linear model.

    For each row of the module-level ``X_test``, a fresh ``LinearRegression`` is
    fit on the full ``X_train`` / ``y_train`` set, with per-sample weights from
    ``get_weight`` centred on that row. The fitted model then predicts that
    single row.

    Args:
        tau: Gaussian kernel bandwidth forwarded to ``get_weight``.

    Returns:
        A 1-D tensor holding one prediction per test point, in ``X_test`` order.
    """
    def _predict_at(x_query):
        kernel_weights = get_weight(X_train, x_query, tau)
        local_model = LinearRegression()
        local_model.fit(X_train, y_train, sample_weight=kernel_weights)
        # predict expects a 2-D batch, so lift the single point to one row.
        return local_model.predict(x_query.reshape(1, -1))[0]

    per_point = [_predict_at(x_query) for x_query in X_test]
    return torch.stack(per_point).reshape(-1)

# Lay out an m-row by n-column grid, one subplot per kernel bandwidth.
n = 4
m = 4
fig, axes = plt.subplots(m, n, figsize=(15, 10))
plt.subplots_adjust(hspace=0.5)
axes = axes.ravel()

# Log-spaced bandwidths in [tau_min, tau_max]. torch.logspace raises its
# default base (10) to the given endpoints, so the endpoints must be base-10
# logarithms. Natural logs here would yield taus of roughly [0.203, 4.933]
# instead of the intended range.
tau_min, tau_max = 0.5, 2.0
taus = torch.logspace(math.log10(tau_min), math.log10(tau_max), m * n)
for tau, ax in zip(taus, axes):
    y_pred = get_pred(tau)
    ax.set_title(f"Tau = {round(tau.item(), 3)} - R2 = {round(r2_score(y_test, y_pred), 3)}")
    ax.scatter(X_test, y_pred, s=10, c="r", label="prediction")
    ax.scatter(X_test, y_test, s=10, c="b", label="true")
    ax.grid()

# Every subplot uses the same two labelled series, so a single figure-level
# legend is enough to make the labels visible without cluttering each axis.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper right")
plt.show()

Total running time of the script: (0 minutes 9.062 seconds)

Gallery generated by Sphinx-Gallery