Just read twice: closing the recall gap for recurrent language models
arXiv (2024)
Abstract
Recurrent large language models that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to this limited memory, recurrent LMs cannot recall and use all the information in long contexts, leading to brittle in-context learning (ICL) quality. A key challenge for efficient LMs is selecting what information to store versus discard. In this work, we observe that the order in which information is shown to the LM impacts the difficulty of this selection. To formalize this, we show that the hardness of information recall reduces to the hardness of a problem called set disjointness (SD), a quintessential problem in communication complexity that requires a streaming algorithm (e.g., a recurrent model) to decide whether two input sets are disjoint. We show, both empirically and theoretically, that the recurrent memory required to solve SD changes with set order, i.e., whether the smaller set appears first in-context. Our analysis suggests that, to mitigate this reliance on data order, we can either put information in the right order in-context or process prompts non-causally. Towards that end, we propose: (1) JRT-Prompt, in which the context is repeated multiple times in the prompt, effectively showing the model all data orders. This gives 11.0 ± 1.3 points of improvement, averaged across 16 recurrent LMs and 6 ICL tasks, with 11.9× higher throughput than FlashAttention-2 (FA2) for generation prefill (length 32k, batch size 16, NVIDIA H100). We then propose (2) JRT-RNN, which uses non-causal prefix linear attention to process prompts and provides, on average across the tasks, 99% of Transformer quality at 360M parameters / 30B training tokens and 96% at 1.3B parameters / 50B tokens, with 19.2× higher prefill throughput than FA2.
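
To make the set-disjointness argument concrete, the sketch below (a minimal illustration, not code from the paper) streams one set after the other and shows why order matters: the algorithm must buffer whichever set appears first, so placing the smaller set first keeps memory proportional to the smaller set.

```python
from typing import Hashable, Iterable


def disjoint_streaming(first: Iterable[Hashable], second: Iterable[Hashable]) -> bool:
    """Decide set disjointness in a single left-to-right pass.

    The first set must be buffered before the second can be checked, so
    memory scales with the size of whichever set appears first -- the
    order dependence the abstract highlights for recurrent (streaming) models.
    """
    seen = set(first)            # buffer the first set entirely
    for item in second:          # stream the second set one element at a time
        if item in seen:
            return False         # shared element => not disjoint
    return True


# Cheap when the small set comes first: memory ~ |{3, 7}|.
print(disjoint_streaming({3, 7}, range(1_000_000)))   # False (3 is shared)
# Expensive when the large set comes first: ~1,000,000 items are buffered.
print(disjoint_streaming(range(1_000_000), {3, 7}))   # False
```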
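
JRT-Prompt, as described in the abstract, amounts to repeating the in-context data before asking the question. A minimal sketch of such a prompt builder follows; the template, separator, and `repeats` default are illustrative assumptions, not the paper's released prompt format.

```python
def jrt_prompt(context: str, question: str, repeats: int = 2) -> str:
    """Build a 'read twice' prompt by repeating the context before the question.

    With the context shown twice, every piece of information appears both
    before and after every other piece, so a causal recurrent model is
    effectively exposed to all data orders before it must answer.
    NOTE: this template is an illustrative assumption, not the paper's exact format.
    """
    repeated_context = "\n\n".join([context] * repeats)
    return f"{repeated_context}\n\nQuestion: {question}\nAnswer:"


prompt = jrt_prompt(
    context="The Eiffel Tower is in Paris. The Colosseum is in Rome.",
    question="Where is the Colosseum?",
)
print(prompt)
```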
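
JRT-RNN's non-causal prefix linear attention processes the prompt bidirectionally while keeping recurrent decoding. The sketch below is a toy single-head version built on generic linear attention, non-causal over the first `prefix_len` positions and causal afterwards; the actual JRT-RNN layer's feature map, normalization, and training objective differ and are not shown here.

```python
import numpy as np


def prefix_linear_attention(q, k, v, prefix_len, phi=lambda x: np.maximum(x, 0.0) + 1e-6):
    """Toy single-head linear attention that is non-causal over the prompt.

    q, k, v: (seq_len, d) arrays. Positions < prefix_len (the prompt) all
    attend to the full prompt; later positions are processed causally by
    updating a running (d, d) state, as in standard linear attention.
    This is a simplified stand-in, not the exact JRT-RNN layer.
    """
    q, k = phi(q), phi(k)
    seq_len, _ = q.shape
    out = np.zeros_like(v)

    # Non-causal prefix: one shared key-value summary for the whole prompt.
    S = k[:prefix_len].T @ v[:prefix_len]     # (d, d) key-value summary
    z = k[:prefix_len].sum(axis=0)            # (d,) normalizer
    for t in range(prefix_len):
        out[t] = (q[t] @ S) / (q[t] @ z)

    # Causal continuation: recurrent state updates, constant memory per step.
    for t in range(prefix_len, seq_len):
        S += np.outer(k[t], v[t])
        z += k[t]
        out[t] = (q[t] @ S) / (q[t] @ z)
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(12, 8))                  # 12 tokens, 8-dim toy embeddings
y = prefix_linear_attention(x, x, x, prefix_len=8)
print(y.shape)                                # (12, 8)
```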