.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples\DeepLearning\Attention.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_DeepLearning_Attention.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_DeepLearning_Attention.py:


Deep learning with Attention
==============================

This script implements a model that predicts the targets of a dummy regression dataset using MultiHeadAttention. The model has a structure similar to modern large language models, but with far fewer parameters.
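The core building block, repeated three times in the model below, is multi-head self-attention followed by layer normalisation and a position-wise dense layer with a ReLU activation. For readers more familiar with plain PyTorch, a rough equivalent of one such block is sketched here. This is an illustrative approximation, not the DLL implementation: the ``AttentionBlock`` name and the use of ``torch.nn`` primitives are assumptions, and ``nn.LayerNorm`` is applied over the embedding dimension for simplicity.

.. code-block:: Python

    import torch
    import torch.nn as nn

    class AttentionBlock(nn.Module):
        """Illustrative stand-in for one MultiHeadAttention + LayerNorm + Dense block."""

        def __init__(self, embed_dim=9, n_heads=3, dropout=0.5):
            super().__init__()
            # embed_dim must be divisible by n_heads (9 / 3 = 3 dimensions per head).
            self.attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
            self.norm = nn.LayerNorm(embed_dim)
            self.dense = nn.Linear(embed_dim, embed_dim)
            self.relu = nn.ReLU()

        def forward(self, x):
            # Self-attention: queries, keys and values all come from x.
            attn_out, _ = self.attention(x, x, x)
            return self.relu(self.dense(self.norm(attn_out)))

    x = torch.rand(64, 11, 9)          # (batch, sequence, embedding)
    print(AttentionBlock()(x).shape)   # torch.Size([64, 11, 9])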
.. GENERATED FROM PYTHON SOURCE LINES 8-56

.. image-sg:: /auto_examples/DeepLearning/images/sphx_glr_Attention_001.png
   :alt: Attention
   :srcset: /auto_examples/DeepLearning/images/sphx_glr_Attention_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Model summary:
    Input - Output: (10)
    Dense - (Input, Output): (10, 99) - Parameters: 1089
    ReLU - Output: (99)
    Reshape - (Input, Output): (99, (11, 9))
    Layer list - (Input, Output): ((11, 9), (11, 9))
    MultiHeadAttention - (Input, Output): ((11, 9), (11, 9))
    Layer normalisation - Output: ((11, 9)) - Parameters: 22
    Dense - (Input, Output): ((11, 9), (11, 9)) - Parameters: 90
    ReLU - Output: ((11, 9))
    Layer list - (Input, Output): ((11, 9), (11, 9))
    MultiHeadAttention - (Input, Output): ((11, 9), (11, 9))
    Layer normalisation - Output: ((11, 9)) - Parameters: 22
    Dense - (Input, Output): ((11, 9), (11, 9)) - Parameters: 90
    ReLU - Output: ((11, 9))
    Layer list - (Input, Output): ((11, 9), (11, 9))
    MultiHeadAttention - (Input, Output): ((11, 9), (11, 9))
    Layer normalisation - Output: ((11, 9)) - Parameters: 22
    Dense - (Input, Output): ((11, 9), (11, 9)) - Parameters: 90
    ReLU - Output: ((11, 9))
    Flatten - (Input, Output): ((11, 9), 99)
    Dense - (Input, Output): (99, 1) - Parameters: 100
    Total number of parameters: 1525
    Epoch: 1 - Metrics: {'loss': '120151.7969', 'val_loss': '114753.6016'}
    Epoch: 2 - Metrics: {'loss': '119704.5312', 'val_loss': '114314.1484'}
    Epoch: 3 - Metrics: {'loss': '119201.7344', 'val_loss': '113824.6406'}
    Epoch: 4 - Metrics: {'loss': '118604.6953', 'val_loss': '113247.4531'}
    Epoch: 5 - Metrics: {'loss': '117954.3828', 'val_loss': '112598.4062'}
    Epoch: 6 - Metrics: {'loss': '117161.3203', 'val_loss': '111806.2812'}
    Epoch: 7 - Metrics: {'loss': '116119.9453', 'val_loss': '110802.8906'}
    Epoch: 8 - Metrics: {'loss': '114822.7969', 'val_loss': '109570.4297'}
    Epoch: 9 - Metrics: {'loss': '113317.7188', 'val_loss': '108118.1797'}
    Epoch: 10 - Metrics: {'loss': '111563.2422', 'val_loss': '106392.8438'}
    Epoch: 11 - Metrics: {'loss': '109663.4531', 'val_loss': '104543.8594'}
    Epoch: 12 - Metrics: {'loss': '107544.6094', 'val_loss': '102440.6484'}
    Epoch: 13 - Metrics: {'loss': '104881.0859', 'val_loss': '99783.3984'}
    Epoch: 14 - Metrics: {'loss': '101934.4688', 'val_loss': '96954.0703'}
    Epoch: 15 - Metrics: {'loss': '99016.3125', 'val_loss': '94121.6797'}
    Epoch: 16 - Metrics: {'loss': '95894.1953', 'val_loss': '91092.3438'}
    Epoch: 17 - Metrics: {'loss': '92502.0625', 'val_loss': '87795.8984'}
    Epoch: 18 - Metrics: {'loss': '88817.6562', 'val_loss': '84209.4297'}
    Epoch: 19 - Metrics: {'loss': '84871.5938', 'val_loss': '80363.7500'}
    Epoch: 20 - Metrics: {'loss': '80713.7031', 'val_loss': '76318.1719'}
    Epoch: 21 - Metrics: {'loss': '76367.5859', 'val_loss': '72092.0078'}
    Epoch: 22 - Metrics: {'loss': '71826.1797', 'val_loss': '67687.6094'}
    Epoch: 23 - Metrics: {'loss': '67140.3047', 'val_loss': '63153.5859'}
    Epoch: 24 - Metrics: {'loss': '62422.2383', 'val_loss': '58586.2344'}
    Epoch: 25 - Metrics: {'loss': '57642.3281', 'val_loss': '53964.2344'}
    Epoch: 26 - Metrics: {'loss': '52779.7539', 'val_loss': '49273.7656'}
    Epoch: 27 - Metrics: {'loss': '47979.7031', 'val_loss': '44647.8008'}
    Epoch: 28 - Metrics: {'loss': '43282.6523', 'val_loss': '40131.3438'}
    Epoch: 29 - Metrics: {'loss': '38773.8203', 'val_loss': '35814.5859'}
    Epoch: 30 - Metrics: {'loss': '34371.3086', 'val_loss': '31607.8945'}
    Epoch: 31 - Metrics: {'loss': '30334.3965', 'val_loss': '27771.2441'}
    Epoch: 32 - Metrics: {'loss': '26518.8242', 'val_loss': '24164.0723'}
    Epoch: 33 - Metrics: {'loss': '23097.9746', 'val_loss': '20949.0684'}
    Epoch: 34 - Metrics: {'loss': '20035.9473', 'val_loss': '18084.5234'}
    Epoch: 35 - Metrics: {'loss': '17288.5176', 'val_loss': '15543.9863'}
    Epoch: 36 - Metrics: {'loss': '15061.9229', 'val_loss': '13511.8984'}
    Epoch: 37 - Metrics: {'loss': '13252.7373', 'val_loss': '11890.5850'}
    Epoch: 38 - Metrics: {'loss': '11754.1982', 'val_loss': '10576.2695'}
    Epoch: 39 - Metrics: {'loss': '10600.3438', 'val_loss': '9601.2051'}
    Epoch: 40 - Metrics: {'loss': '9748.0674', 'val_loss': '8915.9814'}
    Epoch: 41 - Metrics: {'loss': '9186.5303', 'val_loss': '8506.2422'}
    Epoch: 42 - Metrics: {'loss': '8847.7148', 'val_loss': '8306.1104'}
    Epoch: 43 - Metrics: {'loss': '8682.1396', 'val_loss': '8266.0088'}
    Epoch: 44 - Metrics: {'loss': '8646.6729', 'val_loss': '8334.3096'}
    Epoch: 45 - Metrics: {'loss': '8693.4912', 'val_loss': '8475.1729'}
    Epoch: 46 - Metrics: {'loss': '8789.5615', 'val_loss': '8651.8877'}
    Epoch: 47 - Metrics: {'loss': '8908.7119', 'val_loss': '8839.1221'}
    Epoch: 48 - Metrics: {'loss': '9027.6201', 'val_loss': '9009.4766'}
    Epoch: 49 - Metrics: {'loss': '9151.2285', 'val_loss': '9175.8516'}
    Epoch: 50 - Metrics: {'loss': '9227.2812', 'val_loss': '9270.0098'}
    Epoch: 51 - Metrics: {'loss': '9319.2080', 'val_loss': '9386.7529'}
    Epoch: 52 - Metrics: {'loss': '9412.2686', 'val_loss': '9503.1309'}
    Epoch: 53 - Metrics: {'loss': '9484.5420', 'val_loss': '9592.7100'}
    Epoch: 54 - Metrics: {'loss': '9550.4482', 'val_loss': '9674.5137'}
    Epoch: 55 - Metrics: {'loss': '9614.4785', 'val_loss': '9756.5186'}
    Epoch: 56 - Metrics: {'loss': '9682.2734', 'val_loss': '9843.1836'}
    Epoch: 57 - Metrics: {'loss': '9675.4824', 'val_loss': '9838.4775'}
    Epoch: 58 - Metrics: {'loss': '9682.9180', 'val_loss': '9846.7129'}
    Epoch: 59 - Metrics: {'loss': '9770.1221', 'val_loss': '9949.4365'}
    Epoch: 60 - Metrics: {'loss': '9754.7959', 'val_loss': '9929.6006'}
    Epoch: 61 - Metrics: {'loss': '9783.4854', 'val_loss': '9961.6396'}
    Epoch: 62 - Metrics: {'loss': '9805.7305', 'val_loss': '9986.9004'}
    Epoch: 63 - Metrics: {'loss': '9835.2031', 'val_loss': '10019.3027'}
    Epoch: 64 - Metrics: {'loss': '9814.9189', 'val_loss': '9992.7461'}
    Epoch: 65 - Metrics: {'loss': '9720.7480', 'val_loss': '9872.6475'}
    Epoch: 66 - Metrics: {'loss': '9685.3320', 'val_loss': '9827.2100'}
    Epoch: 67 - Metrics: {'loss': '9688.1553', 'val_loss': '9829.3340'}
    Epoch: 68 - Metrics: {'loss': '9715.8711', 'val_loss': '9863.3848'}
    Epoch: 69 - Metrics: {'loss': '9735.3330', 'val_loss': '9888.7578'}
    Epoch: 70 - Metrics: {'loss': '9773.6123', 'val_loss': '9942.9365'}
    Epoch: 71 - Metrics: {'loss': '9871.1162', 'val_loss': '10060.7793'}
    Epoch: 72 - Metrics: {'loss': '9954.1270', 'val_loss': '10166.9463'}
    Epoch: 73 - Metrics: {'loss': '9982.3018', 'val_loss': '10202.6953'}
    Epoch: 74 - Metrics: {'loss': '9913.5752', 'val_loss': '10125.6738'}
    Epoch: 75 - Metrics: {'loss': '9795.6309', 'val_loss': '9990.7285'}
    Epoch: 76 - Metrics: {'loss': '9712.0986', 'val_loss': '9888.5459'}
    Epoch: 77 - Metrics: {'loss': '9780.6934', 'val_loss': '9967.5156'}
    Epoch: 78 - Metrics: {'loss': '9852.0869', 'val_loss': '10058.8252'}
    Epoch: 79 - Metrics: {'loss': '9907.8232', 'val_loss': '10130.0205'}
    Epoch: 80 - Metrics: {'loss': '9903.3086', 'val_loss': '10128.5176'}
    Epoch: 81 - Metrics: {'loss': '9942.4824', 'val_loss': '10179.6357'}
    Epoch: 82 - Metrics: {'loss': '9894.7334', 'val_loss': '10125.7051'}
    Epoch: 83 - Metrics: {'loss': '9868.2002', 'val_loss': '10095.3818'}
    Epoch: 84 - Metrics: {'loss': '9847.4844', 'val_loss': '10067.8906'}
    Epoch: 85 - Metrics: {'loss': '9801.6484', 'val_loss': '10014.8174'}
    Epoch: 86 - Metrics: {'loss': '9799.7295', 'val_loss': '10009.1035'}
    Epoch: 87 - Metrics: {'loss': '9755.0400', 'val_loss': '9956.5781'}
    Epoch: 88 - Metrics: {'loss': '9753.7598', 'val_loss': '9961.4766'}
    Epoch: 89 - Metrics: {'loss': '9760.5273', 'val_loss': '9972.4043'}
    Epoch: 90 - Metrics: {'loss': '9734.5039', 'val_loss': '9938.7578'}
    Epoch: 91 - Metrics: {'loss': '9816.0869', 'val_loss': '10035.5830'}
    Epoch: 92 - Metrics: {'loss': '9871.4219', 'val_loss': '10103.0039'}
    Epoch: 93 - Metrics: {'loss': '9826.9258', 'val_loss': '10047.7354'}
    Epoch: 94 - Metrics: {'loss': '9796.4736', 'val_loss': '10011.3838'}
    Epoch: 95 - Metrics: {'loss': '9756.5088', 'val_loss': '9961.4307'}
    Epoch: 96 - Metrics: {'loss': '9807.9854', 'val_loss': '10025.3838'}
    Epoch: 97 - Metrics: {'loss': '9760.7334', 'val_loss': '9948.2402'}
    Epoch: 98 - Metrics: {'loss': '9782.9658', 'val_loss': '9970.3867'}
    Epoch: 99 - Metrics: {'loss': '9759.0879', 'val_loss': '9942.3311'}
    Epoch: 100 - Metrics: {'loss': '9787.1035', 'val_loss': '9976.9678'}
    Test mean squared error: 8607.1103515625


|
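The parameter total reported by ``model.summary()`` above can be checked by hand. The ``MultiHeadAttention`` layers report no trainable parameters in this summary, and the 22 parameters per ``Layer normalisation`` correspond to a scale and a shift over the first dimension of the (11, 9) shape:

- input ``Dense`` (10 → 99): 10 × 99 weights + 99 biases = 1089
- per block: ``LayerNorm`` 2 × 11 = 22 and ``Dense`` on the embedding dimension 9 × 9 + 9 = 90, so 112 per block and 3 × 112 = 336 for the three blocks
- output ``Dense`` (99 → 1): 99 weights + 1 bias = 100

Together these give 1089 + 336 + 100 = 1525, matching the summary.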
.. code-block:: Python

    import torch
    import matplotlib.pyplot as plt

    from DLL.DeepLearning.Model import Model
    from DLL.DeepLearning.Layers import MultiHeadAttention, Dense, Flatten, Reshape, LayerList
    from DLL.DeepLearning.Layers.Regularisation import LayerNorm
    from DLL.DeepLearning.Layers.Activations import ReLU
    from DLL.DeepLearning.Optimisers import ADAM
    from DLL.DeepLearning.Losses import MSE
    from DLL.Data.Preprocessing import data_split
    from DLL.Data.Metrics import mean_squared_error

    # Dummy regression task: the target is the sum of the squared inputs.
    n = 1000
    seq_len = 10
    X = 10 * torch.rand((n, seq_len))
    y = (X ** 2).sum(dim=1)
    # Split into training, validation and test sets (0.6 / 0.2 / remaining 0.2).
    X_train, y_train, X_val, y_val, X_test, y_test = data_split(X, y, 0.6, 0.2)

    # One attention block: self-attention, layer normalisation and a dense layer.
    block = LayerList(
        MultiHeadAttention((11, 9), n_heads=3, dropout=0.5),
        LayerNorm(),
        Dense((11, 9)),
        ReLU()
    )

    model = Model((seq_len,))
    model.add(Dense(99, activation=ReLU()))
    model.add(Reshape((11, 9)))
    # Stack three independent copies of the attention block.
    model.add(block.clone())
    model.add(block.clone())
    model.add(block.clone())
    model.add(Flatten())
    model.add(Dense(tuple()))  # scalar output
    model.compile(ADAM(), MSE(), metrics=["loss", "val_loss"])
    model.summary()

    history = model.fit(X_train, y_train, val_data=(X_val, y_val), epochs=100, callback_frequency=1, batch_size=64, verbose=True)

    y_pred = model.predict(X_test)
    print(f"Test mean squared error: {mean_squared_error(y_pred, y_test)}")

    # Plot the training and validation losses on a logarithmic scale.
    plt.figure(figsize=(8, 8))
    plt.semilogy(history["loss"], label="loss")
    plt.semilogy(history["val_loss"], label="validation loss")
    plt.legend()
    plt.show()
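Since each ``block.clone()`` call adds an independent copy of the block with its own parameters, the three repeated ``model.add`` calls could equally be written as a loop. A small stylistic sketch, assuming ``clone()`` behaves as in the example above:

.. code-block:: Python

    # Stack three independent attention blocks; each clone has its own weights.
    for _ in range(3):
        model.add(block.clone())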
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 25.533 seconds)


.. _sphx_glr_download_auto_examples_DeepLearning_Attention.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: Attention.ipynb <Attention.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: Attention.py <Attention.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: Attention.zip <Attention.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_