Paper Weekly 16: Retro - Improving Language Models by Retrieving from Trillions of Tokens

This paper introduces the Retrieval-Enhanced Transformer (Retro), which enhances auto-regressive language models by conditioning on document chunks retrieved from a large corpus.

Intro

Retro achieves performance comparable to GPT-3 and Jurassic-1 on the Pile despite using 25x fewer parameters. It combines a frozen BERT retriever, a differentiable encoder, and a chunked cross-attention mechanism to predict tokens based on an order of magnitude more data than is typically consumed during training. Retro can be trained from scratch, or pre-trained transformers can be rapidly retrofitted with retrieval while still achieving good performance. This approach opens up new avenues for improving language models through explicit memory at an unprecedented scale.

Scaling the training data to trillions of tokens improves the performance of language models on machine translation and downstream tasks. Later research demonstrates that large language models can even memorize parts of their training data, which suggests that enhancing models with retrieval may lead to further improvements. Information retrieval for text has historically relied on inverted-index matching and latent topic modeling; with the success of deep learning, retrieval systems have shifted to dense learned representations based on neural networks. DPR learns dense retrievers for open-domain question answering, while FiD and RAG combine retrieval with encoder-decoder transformers to improve question-answering benchmarks. Retro shares components with kNN-LM and DPR, but models longer sequences and retrieves different documents for different chunks of a sequence. Notably, Retro allows for repeated retrieval while generating a sequence and performs retrieval throughout the pre-training process, unlike other methods.

Training Dataset

The researchers use a multi-lingual version of MassiveText for both training and retrieval data, consisting of text documents from multiple sources and languages and totalling over 5 trillion tokens. The dataset is tokenized with SentencePiece using a vocabulary of 128,000 tokens, and during training the model retrieves from a 600B-token subset of the training data. To prevent test-set leakage, they use the MinHash scheme to compute 13-gram Jaccard similarities between training and test documents and remove training documents with high similarity to validation or test documents. They also remove the Wikitext103 validation and test articles from their Wikipedia training data.
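
As a rough illustration of this leakage filter, the sketch below computes exact 13-gram Jaccard similarity between training and evaluation documents and drops training documents above a similarity threshold. The paper uses MinHash to approximate this at scale; the 0.8 threshold and the whitespace tokenization here are illustrative assumptions.

```python
# Minimal sketch of the test-set leakage filter (illustrative, not the paper's code).
# The paper approximates 13-gram Jaccard similarity with MinHash at scale;
# here we compute the exact similarity for clarity.

def ngrams(tokens, n=13):
    """Return the set of n-grams (as tuples) of a token sequence."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def jaccard(a, b):
    """Jaccard similarity between two sets."""
    if not a and not b:
        return 0.0
    return len(a & b) / len(a | b)

def filter_leaked_documents(train_docs, eval_docs, n=13, threshold=0.8):
    """Drop training documents whose n-gram Jaccard similarity with any
    validation/test document reaches the (assumed) threshold."""
    eval_grams = [ngrams(doc.split(), n) for doc in eval_docs]
    kept = []
    for doc in train_docs:
        grams = ngrams(doc.split(), n)
        if all(jaccard(grams, eg) < threshold for eg in eval_grams):
            kept.append(doc)
    return kept
```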

Nearest Neighbor Retrieval

The retrieval database is structured as a key-value memory: each value consists of two contiguous chunks of tokens $[N, F]$, where $N$ is the chunk whose BERT embedding serves as the key and $F$ is its continuation in the original document. They use the SCaNN library, which can query approximate nearest neighbors in $O(\log N)$ time over the pre-computed BERT embeddings, enabling on-the-fly retrieval during training.
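
The sketch below illustrates this key-value structure with a brute-force nearest-neighbor lookup over pre-computed chunk embeddings. The class name and the exact cosine-similarity search are illustrative stand-ins for the approximate search the paper performs with SCaNN.

```python
import numpy as np

# Minimal sketch of the chunk-level retrieval database (illustrative).
# Keys are frozen BERT embeddings of chunks N; values are the pairs [N, F],
# where F is the continuation of N in its source document.

class ChunkDatabase:
    def __init__(self, keys: np.ndarray, values: list):
        # keys: (num_chunks, d) array of BERT chunk embeddings
        # values: list of (N_tokens, F_tokens) pairs, aligned with keys
        self.keys = keys / np.linalg.norm(keys, axis=1, keepdims=True)
        self.values = values

    def retrieve(self, query_emb: np.ndarray, k: int = 2):
        """Return the k nearest [N, F] neighbors for a query chunk embedding."""
        q = query_emb / np.linalg.norm(query_emb)
        scores = self.keys @ q            # cosine similarity against all keys
        top_k = np.argsort(-scores)[:k]   # indices of the k nearest chunks
        return [self.values[i] for i in top_k]
```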

Retro Model Architecture

The model architecture is an encoder-decoder transformer relying on the cross-attention mechanism from \citet{vaswani2017attention}. Retrieved tokens are first processed by an encoder Transformer to produce the set of encoded neighbors $E$. The decoder blocks consist of three residual operators: a fully-connected layer (FFW), a standard sequence-level self-attention layer (Attn), and a chunked cross-attention layer (CCA) that integrates information from the retrieval encoder.
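
The following is a minimal sketch of how such a decoder block could be arranged in PyTorch. The module names and dimensions are assumptions, and ordinary cross-attention over the flattened neighbor encodings stands in for the chunked variant described in the next section.

```python
import torch
import torch.nn as nn

# Minimal sketch of one Retro decoder block (hypothetical layout, not the paper's code).
# Each block applies three residual operators: self-attention (Attn), cross-attention
# into the retrieval encoder's outputs (standing in for CCA), and a feed-forward layer (FFW).

class RetroDecoderBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm3 = nn.LayerNorm(d_model)
        self.ffw = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, h, e, causal_mask=None):
        # h: (batch, seq_len, d_model) intermediate activations
        # e: (batch, num_neighbor_tokens, d_model) encoded retrieved neighbors
        x = self.norm1(h)
        h = h + self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)[0]
        x = self.norm2(h)
        h = h + self.cross_attn(x, e, e, need_weights=False)[0]
        h = h + self.ffw(self.norm3(h))
        return h

# Example usage (shapes only):
# block = RetroDecoderBlock()
# h = torch.randn(2, 12, 512)   # 2 sequences of 12 tokens
# e = torch.randn(2, 30, 512)   # flattened encoded neighbors
# out = block(h, e)
```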

Chunked Cross-Attention

To perform the CCA operation, they first split the intermediate activations into $l-1$ attending chunks, denoted $H^+$. Each attending chunk $H_u^+$ consists of the last token of chunk $C_u$ and the first $m-1$ tokens of chunk $C_{u+1}$, and it cross-attends to $E_u$, the encoded neighbors retrieved for chunk $C_u$. The first $m-1$ tokens of the sequence cannot attend to any neighbor of a previous chunk, so CCA is defined as the identity there, setting $\text{CCA}(H,E)_j \triangleq h_j$ for $j \in [1, m-1]$. For the last token, they set $h_{lm} \triangleq \text{CA}(h_{lm}, E_l)$. Note that chunked cross-attention as defined above is autoregressive: the output of CCA at position $i$ depends only on the tokens from position $0$ to $i$ that are input to CCA. When sampling, the nearest neighbors are retrieved at the end of each chunk, based on the BERT embedding of that chunk, and are then used for generating the next chunk; this makes the sampling overhead linear in the number of chunks, in contrast to the quadratic cost in most transformer architectures.
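
Below is a minimal sketch of the CCA data flow, assuming a sequence of $l$ chunks of length $m$ and a generic `cross_attend(q, kv)` layer standing in for the actual cross-attention; the function name and shapes are illustrative, not the paper's implementation.

```python
import torch

# Minimal sketch of the chunked cross-attention (CCA) index arithmetic (illustrative).
# H: (l * m, d) intermediate activations for a sequence of l chunks of length m.
# E: (l, neighbor_len, d) encoded neighbors, where E[u] was retrieved for chunk u.
# cross_attend(q, kv) is any cross-attention layer returning an output shaped like q.

def chunked_cross_attention(H, E, m, cross_attend):
    l = H.shape[0] // m
    out = H.clone()
    # Attending chunk H+_u: last token of chunk u plus first m-1 tokens of chunk u+1,
    # which cross-attends to E[u], the neighbors of chunk u.
    for u in range(l - 1):
        start = (u + 1) * m - 1              # index of the last token of chunk u
        h_plus = H[start:start + m]          # shape (m, d)
        out[start:start + m] = cross_attend(h_plus, E[u])
    # The first m-1 tokens have no preceding chunk to attend to, so CCA is the
    # identity there (out already holds H at those positions).
    # The final token attends to the neighbors of the last chunk.
    out[l * m - 1] = cross_attend(H[l * m - 1 : l * m], E[l - 1])[0]
    return out

# Example stand-in for cross_attend (content-free, shapes only):
# cross_attend = lambda q, kv: q + kv.mean(dim=0, keepdim=True)
```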

Evaluation and Results

The researchers first use the longest common substring to measure the overlap between training and evaluation data: for each evaluation chunk $C$, they compute the ratio $r(C)$ of the longest-common-substring length to the chunk length, and a threshold $\alpha$ on this ratio is used to filter out chunks that were potentially leaked during training. They then adopt bits-per-byte (bpb) to evaluate the performance of language models. For any given language model, one can obtain the negative log-likelihood $l(C)$ of each chunk and the number of bytes $N(C)$ it encodes.

$$\forall\,\alpha\in[0,1],\quad \mathbf{C}_\alpha\triangleq\{C\in\mathbf{C}\,:\,r(C)\leq\alpha\},\qquad \text{bpb}(\alpha)\triangleq\frac{\sum_{C\in\mathbf{C}_\alpha}l(C)}{\sum_{C\in\mathbf{C}_\alpha}N(C)}$$

This metric is widely adopted in work on language modeling: since the byte count in the denominator is fixed for a given evaluation set and independent of the tokenizer, a lower bpb directly indicates a better model.
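
A minimal sketch of this filtered metric, assuming per-chunk records with an overlap ratio `r`, a negative log-likelihood `nll` in nats, and a byte count `n_bytes`; the field names and the nats-to-bits conversion are assumptions for illustration.

```python
import math

# Minimal sketch of the filtered bits-per-byte evaluation (illustrative).

def filtered_bpb(chunks, alpha):
    """bpb(alpha): bits-per-byte over chunks whose overlap ratio r <= alpha."""
    selected = [c for c in chunks if c["r"] <= alpha]
    total_bits = sum(c["nll"] / math.log(2) for c in selected)  # nats -> bits
    total_bytes = sum(c["n_bytes"] for c in selected)
    return total_bits / total_bytes

# Example: bpb over the whole evaluation set vs. a low-overlap subset.
# chunks = [{"r": 0.1, "nll": 210.0, "n_bytes": 256}, ...]
# print(filtered_bpb(chunks, alpha=1.0), filtered_bpb(chunks, alpha=0.125))
```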

The evaluation is conducted on various datasets, including C4, Wikitext103, Curation Corpus, Lambada, and the Pile. Language modeling performance is measured in bits-per-byte and shows that Retro outperforms the baseline models on all datasets. The gains are larger on Wikitext103 and C4, which overlap with the retrieval database. Increasing the size of the retrieval database and the number of retrieved chunks further improves language modeling performance. The study also compares the Retro model with Jurassic-1 and Gopher on the Pile test sets and shows that it outperforms the baseline and Jurassic-1 on most subsets. However, Retro underperforms on the dm_mathematics and ubuntu_irc subsets, possibly because the retrieved neighbors are not helpful there.
