Linear Attention Sequence Parallelism
arXiv (2024)
Abstract
Sequence Parallelism (SP) is a prevalent strategy for handling long
sequences that exceed the memory limit of a single GPU. However, existing SP
methods do not exploit the properties of linear attention, resulting in
sub-optimal parallelism efficiency and usability for linear attention-based
language models. In this paper, we introduce Linear Attention Sequence Parallelism
(LASP), an efficient SP method tailored to linear attention-based language
models. Specifically, we design an efficient point-to-point communication
mechanism to leverage the right-product kernel trick of linear attention, which
sharply decreases the communication overhead of SP. We also enhance the
practical efficiency of LASP by performing kernel fusion and intermediate state
caching, making the implementation of LASP hardware-friendly on GPU clusters.
Furthermore, we meticulously ensure the compatibility of sequence-level LASP
with all types of batch-level data parallel methods, which is vital for
distributed training on large clusters with long sequences and large batches.
We conduct extensive experiments on two linear attention-based models with
varying sequence lengths and GPU cluster sizes. LASP scales sequence length up
to 4096K using 128 A100 80G GPUs on 1B models, which is 8 times longer than
existing SP methods while being significantly faster. The code is available at
https://github.com/OpenNLPLab/LASP.
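The right-product trick is what keeps communication cheap. Causal linear attention can be computed chunk by chunk. Each chunk's contribution from all earlier tokens is Q_t (Σ K_j^T V_j), so only a running d × d state has to be passed from one device to the next, rather than the keys and values themselves. The sketch below illustrates this idea under simplifying assumptions. It uses unnormalized linear attention with no decay factor, covers the forward pass only, and simulates devices with a Python loop instead of real GPU point-to-point sends. It is not the LASP implementation from the repository above. The function names `linear_attention_reference` and `lasp_style_forward` are hypothetical.

```python
import numpy as np


def linear_attention_reference(q, k, v):
    """Causal linear attention in left-product form: O = (Q K^T * M) V,
    where M is the lower-triangular causal mask. No normalization or decay
    (simplifying assumption for this sketch)."""
    n = q.shape[0]
    mask = np.tril(np.ones((n, n)))
    return ((q @ k.T) * mask) @ v


def lasp_style_forward(q, k, v, world_size):
    """Simulated sequence-parallel forward pass in the spirit of LASP
    (hypothetical helper, not the official code).

    The sequence is split into `world_size` contiguous chunks, one per
    simulated device. Each device computes
      - an intra-chunk term with the masked left product, and
      - an inter-chunk term Q_t @ KV, where KV is the running state
        received from the previous device (right-product trick).
    It then adds K_t^T V_t to the state and "sends" it to the next device.
    Returns the full output and the total bytes sent between devices.
    """
    n, d = q.shape
    assert n % world_size == 0, "sequence length must divide evenly"
    c = n // world_size
    mask = np.tril(np.ones((c, c)))
    kv_state = np.zeros((d, v.shape[1]))  # state seen by the first device
    outputs = []
    bytes_sent = 0
    for rank in range(world_size):
        sl = slice(rank * c, (rank + 1) * c)
        qt, kt, vt = q[sl], k[sl], v[sl]
        intra = ((qt @ kt.T) * mask) @ vt  # tokens within this chunk
        inter = qt @ kv_state  # all tokens in earlier chunks
        outputs.append(intra + inter)
        kv_state = kv_state + kt.T @ vt  # accumulate K^T V for later chunks
        if rank < world_size - 1:
            # Point-to-point send to rank + 1: size depends only on the
            # head dimensions, not on the sequence or chunk length.
            bytes_sent += kv_state.nbytes
    return np.concatenate(outputs, axis=0), bytes_sent


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 64, 8
    q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
    out, sent = lasp_style_forward(q, k, v, world_size=4)
    assert np.allclose(out, linear_attention_reference(q, k, v))
    print(sent)  # 3 sends of an 8x8 float64 state -> 1536
```

The chunked result matches the single-device reference. Doubling `n` while keeping `world_size` and `d` fixed leaves `sent` unchanged, which is the property the abstract attributes to LASP's point-to-point design.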