The I/O Complexity of Attention, or How Optimal is Flash Attention?
CoRR(2024)
Abstract
Self-attention is at the heart of the popular Transformer architecture, yet it
suffers from quadratic time and memory complexity. The breakthrough
FlashAttention algorithm revealed I/O complexity as the true bottleneck in
scaling Transformers. Given a two-level memory hierarchy, consisting of a fast
cache (e.g. GPU on-chip SRAM) and a slow memory (e.g. GPU high-bandwidth
memory), the I/O complexity measures the number of accesses to the slow memory.
FlashAttention computes attention using O(N^2 d^2/M) I/O operations, where N is
the dimension of the N × N attention matrix (i.e. the sequence length), d the
head dimension and M the cache size. However, is this I/O complexity optimal?
The known lower bound only rules out an I/O complexity of o(Nd) when
M = Θ(Nd), since the output that must be written to slow memory has size
Ω(Nd). This leads to the main question of our work: is FlashAttention
I/O-optimal for all values of M?
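To see where the N^2 d^2/M figure comes from, here is a minimal pure-Python sketch of FlashAttention-style tiling. It assumes a simplified cost model in which every scalar moved between slow and fast memory counts as one I/O. The function names `naive_attention` and `tiled_attention` are hypothetical, and the 1/sqrt(d) score scaling is omitted. The block sizes B_c = ceil(M/(4d)) and B_r = min(B_c, d) follow the original FlashAttention paper. Normalization by the running softmax denominator is deferred to the end, a simplification used in later FlashAttention variants. The outer loop loads each K/V block once, while the inner loop re-reads all of Q, O and the softmax statistics. The total is therefore Θ((N/B_c)·N·d) = Θ(N^2 d^2/M) words.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference softmax(Q K^T) V, one query row at a time (no 1/sqrt(d) scaling)."""
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) for k in K]
        top = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - top) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, M):
    """FlashAttention-style tiling with online softmax.

    Returns (output, io_words), where io_words counts scalars moved between
    slow and fast memory under a simplified one-word-per-I/O cost model.
    """
    N, d = len(Q), len(Q[0])
    Bc = max(1, math.ceil(M / (4 * d)))  # key/value block size
    Br = min(Bc, d)                      # query block size
    O = [[0.0] * d for _ in range(N)]    # unnormalized output kept in slow memory
    l = [0.0] * N                        # running softmax denominators
    m = [-math.inf] * N                  # running row maxima
    io = 0
    for j0 in range(0, N, Bc):           # outer loop: each K/V block loaded once
        Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]
        io += 2 * len(Kj) * d
        for i0 in range(0, N, Br):       # inner loop: sweep all query blocks
            rows = range(i0, min(i0 + Br, N))
            io += len(rows) * (2 * d + 2)  # read Q_i, O_i, l_i, m_i
            for r in rows:
                s = [sum(a * b for a, b in zip(Q[r], k)) for k in Kj]
                m_new = max(m[r], max(s))
                p = [math.exp(x - m_new) for x in s]
                scale = math.exp(m[r] - m_new)  # rescale earlier partial sums
                l[r] = l[r] * scale + sum(p)
                O[r] = [O[r][c] * scale
                        + sum(p[t] * Vj[t][c] for t in range(len(Vj)))
                        for c in range(d)]
                m[r] = m_new
            io += len(rows) * (d + 2)      # write O_i, l_i, m_i back
    # Deferred normalization by the final denominators.
    return [[x / l[r] for x in O[r]] for r in range(N)], io


if __name__ == "__main__":
    rng = random.Random(0)
    N, d = 16, 2
    Q, K, V = ([[rng.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
               for _ in range(3))
    out, io = tiled_attention(Q, K, V, M=32)
    ref = naive_attention(Q, K, V)
    err = max(abs(a - b) for ro, rr in zip(out, ref) for a, b in zip(ro, rr))
    print(f"I/O words: {io}, max error vs naive: {err:.2e}")
```

With N = 16, d = 2 and M = 32 the sketch uses B_c = 4 and B_r = 2. It makes four outer passes, moves 704 words and matches `naive_attention` to floating-point precision. Each outer pass re-reads roughly 3Nd words of Q, O and statistics. This is why enlarging M, and with it B_c, is what drives the I/O count down.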
We resolve the above question in full generality by proving an I/O complexity
lower bound that matches the upper bound achieved by FlashAttention, up to
constant factors, for all values of M ≥ d^2. Further, for M < d^2 we give an
algorithm with lower I/O complexity than FlashAttention and show that it is
optimal as well. Moreover, our lower bounds do not assume that the attention
matrix is computed with combinatorial matrix multiplication: we show that even
if one uses fast matrix multiplication, the above I/O complexity bounds cannot
be improved. We do so by introducing a new communication complexity protocol
for matrix compression and connecting communication complexity to I/O
complexity. To the best of our knowledge, this is the first work to establish a
connection between communication complexity and I/O complexity, and we believe
this connection is of independent interest and will find many more
applications in proving I/O complexity lower bounds.
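The rates discussed above can be compared numerically. The sketch below drops all constant factors, and its function names are hypothetical. The M < d^2 rate N^2 d/sqrt(M) is our assumption and is not stated in this abstract. It is the classic Hong-Kung blocked matrix-multiplication rate for forming Q K^T. Under that assumption, the two regimes meet at M = d^2, where both rates equal N^2. At M = Nd the FlashAttention rate collapses to the Ω(Nd) output-size bound.

```python
import math


def flash_io(N, d, M):
    """FlashAttention upper bound N^2 d^2 / M, constant factors dropped."""
    return N * N * d * d / M


def output_lower_bound(N, d):
    """Trivial bound: the N x d output must be written to slow memory."""
    return N * d


def small_cache_io(N, d, M):
    """ASSUMED small-cache rate N^2 d / sqrt(M) (Hong-Kung blocked matmul
    rate for Q K^T); not taken from the abstract."""
    return N * N * d / math.sqrt(M)


def attention_io(N, d, M):
    """Piecewise rate without constants: FlashAttention's rate for M >= d^2,
    the assumed blocked-matmul rate for M < d^2."""
    return flash_io(N, d, M) if M >= d * d else small_cache_io(N, d, M)


if __name__ == "__main__":
    N, d = 1024, 64
    for M in (1024, d * d, N * d):
        print(M, attention_io(N, d, M), flash_io(N, d, M))
```

For example, with N = 1024 and d = 64, `attention_io` returns 1048576 (= N^2) at M = 4096 = d^2 and 65536 (= Nd) at M = 65536 = Nd. At M = 1024 the assumed small-cache rate (2097152) is half the FlashAttention rate (4194304).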