Learning Associative Memories with Gradient Descent

Bibliographic Details
Main Authors: Cabannes, Vivien, Simsek, Berfin, Bietti, Alberto
Format: Journal Article
Language: English
Published: 28-02-2024
Description
Summary: This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the "classification margins." Yet, we show that imbalance in token frequencies and memory interferences due to correlated embeddings lead to oscillatory transitory regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.
DOI: 10.48550/arxiv.2402.18724
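
Illustrative sketch (not part of the catalog record or taken from the paper's code): the summary describes an associative memory built from outer products of token embeddings and trained by gradient descent on the cross-entropy loss. A minimal NumPy version of that kind of setup might look like the following. The scoring rule u_y^T W e_x, the random Gaussian embeddings, the uniform token frequencies, the random token-to-token mapping, and every name and hyperparameter here are assumptions made for illustration only.

```python
import numpy as np


def train_associative_memory(n_tokens=5, dim=32, steps=2000, lr=0.5, seed=0):
    """Full-batch gradient descent on an associative memory matrix W.

    Hypothetical setup: token x is scored against token y by u_y^T W e_x,
    and W is trained with cross-entropy so that each x recalls target[x].
    """
    rng = np.random.default_rng(seed)
    E = rng.standard_normal((n_tokens, dim)) / np.sqrt(dim)  # input embeddings e_x (rows)
    U = rng.standard_normal((n_tokens, dim)) / np.sqrt(dim)  # output embeddings u_y (rows)
    target = rng.permutation(n_tokens)       # assumed mapping x -> y to memorize
    p = np.full(n_tokens, 1.0 / n_tokens)    # assumed uniform token frequencies
    onehot = np.eye(n_tokens)[target]        # onehot[x, y] = 1 if y == target[x]

    W = np.zeros((dim, dim))
    losses = []
    for _ in range(steps):
        logits = E @ W.T @ U.T               # logits[x, y] = u_y^T W e_x
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)

        # Frequency-weighted cross-entropy, recorded before the update.
        losses.append(-np.sum(p * np.log(probs[np.arange(n_tokens), target])))

        # The gradient is a weighted sum of outer products u_y e_x^T.
        G = p[:, None] * (probs - onehot)    # G[x, y]
        grad = U.T @ G.T @ E                 # shape (dim, dim)
        W -= lr * grad

    # Classification margin: target logit minus the best competing logit.
    logits = E @ W.T @ U.T
    correct = logits[np.arange(n_tokens), target]
    competing = np.where(onehot == 1, -np.inf, logits).max(axis=1)
    margins = correct - competing
    return W, np.array(losses), margins
```

With dim larger than n_tokens, this corresponds loosely to the overparameterized regime mentioned in the summary. Making p non-uniform or raising lr is one way to experiment with the token-frequency imbalance and large-step-size effects the summary refers to, though this toy example makes no claim to reproduce the paper's results.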