DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference
arXiv (2024)
Abstract
Given the increasing demand for tree-structured interactions with LLMs, we
introduce DeFT (Decoding with Flash Tree-Attention), an IO-aware tree attention
algorithm tailored for tree-structured inference. Unlike traditional
sequence-based decoding, tree-structured decoding better accommodates modern
task requirements, including self-consistency, few-shot prompting, multi-step
reasoning, and multi-model/head coordination. However, existing sequence-based
inference systems are ill-suited for tree-structured decoding, resulting in
redundancy in computation, memory footprints, and memory access, thereby
undermining inference efficiency. To address this challenge, DeFT maintains
memory-efficient attention calculation with low memory footprints through two
key stages: (1) QKV Preparation: We propose a KV-Guided Grouping Strategy with
Tree Split to intelligently group QKV, optimizing GPU resource utilization
while minimizing memory reads/writes for KV cache between GPU global memory and
on-chip shared memory; (2) Attention Calculation: We compute partial attention
of each QKV group in a fused kernel and employ a Tree-topology-aware Global
Reduction strategy to obtain final attention. By reducing 73-99% KV cache IO
and nearly 100% IO for partial results during attention calculation (e.g.,
Softmax), DeFT achieves up to 2.52/3.82x speedup in the end-to-end/attention
latency across three practical tree-based workloads: namely, few-shot
prompting, multi-step reasoning, and speculative decoding, over
state-of-the-art attention algorithms.
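
To make the "partial attention, then global reduction" idea concrete, the following is a minimal NumPy sketch, not DeFT's fused GPU kernel. It assumes a toy tree with one shared prefix node and two branches. Each query computes attention over each KV chunk on its path separately, and the partial results are merged using their log-sum-exp values, which is a standard way to combine softmax attention over disjoint KV chunks. All function and variable names here (`partial_attention`, `merge_partials`, `prefix_k`, and so on) are hypothetical and chosen for illustration only.

```python
import numpy as np

def partial_attention(q, k, v):
    """Attention of one query over one KV chunk.

    Returns the chunk-local output and the log-sum-exp of the scores,
    which is all that is needed to merge chunks later.
    """
    scores = k @ q / np.sqrt(q.shape[0])   # shape: (chunk_len,)
    m = scores.max()                       # subtract max for numerical stability
    w = np.exp(scores - m)
    denom = w.sum()
    return (w @ v) / denom, m + np.log(denom)

def merge_partials(outputs, lses):
    """Combine chunk-local outputs into the exact full-softmax output."""
    lses = np.asarray(lses)
    weights = np.exp(lses - lses.max())
    weights /= weights.sum()               # each chunk's share of the softmax mass
    return weights @ np.stack(outputs)

def reference_attention(q, k, v):
    """Plain single-pass softmax attention, used only to check the merge."""
    scores = k @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    return (w @ v) / w.sum()

# Toy tree (an assumption for illustration): a shared prefix with two branches.
rng = np.random.default_rng(0)
d = 8
prefix_k, prefix_v = rng.normal(size=(5, d)), rng.normal(size=(5, d))
branch_kv = [(rng.normal(size=(3, d)), rng.normal(size=(3, d))) for _ in range(2)]
queries = rng.normal(size=(2, d))          # one decoding query per branch

results = []
for q, (bk, bv) in zip(queries, branch_kv):
    # The prefix KV chunk is shared by both queries; each branch chunk is private.
    parts = [partial_attention(q, prefix_k, prefix_v),
             partial_attention(q, bk, bv)]
    outs, lses = zip(*parts)
    results.append(merge_partials(outs, lses))
```

In this sketch the shared prefix chunk is still read once per query. The grouping described in the abstract aims to load such shared KV once for all queries that need it, which a CPU-side NumPy loop cannot show.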