登录    注册      
    
  


News Message

LLM MQA GQA



LLM MQA  GQA

用户发布的文档

加载速度比较慢比较慢,请稍等,手机环境下,有可能无法显示!


GPT,也就是 Transformer Decoder 结构做文本生成时有一个致命问题。先来看看 Encoder 推理是怎么做的,每个 timestep 都能看到所有 timestep ,推理时所有 timestep 一层层向后计算,一把过。于是内存相关开销就是() , 而计算相关开销就是(2) ,其中 N 为序列长度。

而 Decoder 推理时,最大不同在于自回归结构,可以看到图中每个 timestep 的输出都是下一 timestep 的输入,所以无法像 Encoder 一样一次过,每次都要 attend 之前的所有 timestep.

同样计算一下开销,计算开销是 1+(1+2)+(1+2+3)+...+(1+2+...+) 也就是(3) ,而内存开销则是 (2).

大家用 ChatGPT 接口也会有类似感觉,Context 部分成本很低,也很快,因为它做的类似于 Encoder 的并行。主要成本在生成那块,速度较慢,但也已经是优化过后的了。

下面就来讲讲优化方法。

KV Cache

Decoder 每次前向,当前 timestep 计算 Attention 要用到的部分,如之前 timestep 的 KV (Key 和 Value)值都计算过的,只是之前每次前向完后给计算结果都丢掉,只保留最后输出。

于是一个很自然的想法就是 Cache。这很像斐波那契递归函数,naive 版本,也会出现不断重复计算问题,加个 cache 瞬间提速。

每次前向完,给 KV 都保留下来,用于之后计算

代码表示如下:

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
#q、k、v 当前 timestep 的 query,key,value
# K_prev,V_prev 之前所有 timestep 的 key 和 value
for _ in range(time_step):
...
K = torch.cat([K_prev, k], dim=-2) #[b, h, n, d]
V = torch.cat([V_prev, v], dim=-2) #[b, h, n, d]

logits = torch.einsum("bhd,bhnd->bhn", q, K)
weights = torch.softmax(logits/math.sqrt(d), dim=-1)
outs = torch.einsum("bhn,bhnd->bhd", weights, V)
...

K_prev, V_prev = K, V

于是 Decoder 就被优化成,计算开销变成了(2),存储复杂度则是 (),只给 K 和 V 不断保存在缓存中就行。问题解决了!

但残酷现实会立马跳出来给你一棒子,上面假设 K 和 V 能直接存在缓存中,模型规模小还好,一旦模型规模很大长度很长时,KV 根本就存不进缓存

比如 Llama 7B 模型,hidden size 是 4096,那么每个 timestep 需缓存参数量为 4096232=262144,假设半精度保存就是 512KB,1024 长度那就要 512MB. 而现在英伟达最好的卡 H100 的 SRAM 缓存大概是 50MB,而 A100 则是 40MB. 而 7B 模型都这样,175B 模型就更不用说了。

那为什么我们不直接做大 SRAM 内存呢,不就直接解决问题了吗,但是这样又会产生一个新问题 SRAM 太贵了,所以这条路现在是不太行的。

于是退一步,放不进缓存可以放 DRAM 上去,而 DRAM 内存也就是我们常说的 GPU 显存。

但 DRAM 读取到计算芯片和 SRAM 到计算芯片的速度,差了一个量级的,这会让计算芯片一直在等待。

SRAM是静态随机存储器,速度非常快,但成本较高。DRAM是动态随机存储器,成本较低,但速度比SRAM慢

现在我们遇到了当今芯片领域,冯诺依曼架构下最大的一个问题,也就是:Memory Wall(内存墙)

冯诺依曼架构和 Memory Wall

冯诺依曼架构熟悉有计算机相关基础的,应该都稔熟于胸。输入,输出,计算单元,加上存储单元。

现在随着摩尔定律的见顶,虽然计算和内存的发展速度在变缓,但这并不是最大的问题,最大的问题是存储单元计算单元间的交互。

冯诺依曼架构需要先从内存中调取数据,送入计算单元进行处理,但现在计算单元的速度是显著提升的,而从内存中读取数据的速度却没跟上,所以计算和内存这里就形成了一个瓶颈。因为短板效应,内存读取速度限制了整体速度。计算单元能很快将数据处理完,但新数据却还没到,于是就只能等待,造成利用率不高。这就是内存墙

因为内存墙问题,现在 GPU,一张 A100 卡计算单元的利用率到四五十就不错了,用上各种技巧优化到 60% 已经很高了。而对于 H100 卡问题会更严重,因为它的计算速度相对 A100 提高了 6 倍,而内存读取带宽只增加了 1.6 倍,所以也要大量优化来提高利用率。

内存墙怎么越过呢?

硬件层面上,比如现在已在使用的 HBM(高速带宽内存)提高读取速度,或者更彻底些,抛弃冯诺依曼架构,改变计算单元从内存读数据的方式,不再以计算单元为中心,而以存储为中心,做成计算和存储一体的“存内计算”。

软件层面上的话,最近的很多优化,比如 Flash AttentionPaged Attention 都可以算。Flash Attention 就是减少了计算 Softmax 时从 DRAM 内存读取数据次数,从而提高了效率。

Flash Attention算法背后的主要思想是分割输入,将它们从慢速HBM加载到快速SRAM,然后计算这些块的 attention 输出。在将每个块的输出相加之前,将其按正确的归一化因子进行缩放,从而得到正确的结果。

vLLM 主要用于快速 LLM 推理和服务,其核心是Paged Attention同样,MQA 也是一个软件层面上翻墙的一个方法。这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,Paged Attention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说,Paged Attention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,Paged Attention 内核可以有效地识别和获取这些块。

MHA 到 MQA 到 GQA

MQA 的方法很简单,难的是看到这样的方法后,能立刻想到它为什么好。

一起看看 MQA 和 GQA 是怎么来的。

首先是原始的 MHA(Multi-Head Attention),QKV 三部分有相同数量的头,且一一对应。每次做 Attention,head1 的 QKV 就做好自己运算就可以,输出时各个头加起来就行。

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# 为了方便阅读,我们只保留了 llm-foundry 中关键部分的代码,完整代码请参照源码。
class MultiheadAttention(nn.Module):
def __init__(self,d_model: int,n_heads: int,device: str):
"""
Multi Head init func.
Args:
d_model (int): hidden state size, e.g. 768
n_heads (int): 设定的注意力头数, e.g. 8
device (str): _description_
"""
super().__init__()
self.d_model = d_model
self.n_heads = n_heads

self.Wqkv = nn.Linear( # 【关键】Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
device=device
) # (d_model, 3 * d_model)
self.attn_fn = scaled_multihead_dot_product_attention
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)

def forward(self,x):
"""
forward func.
Args:
x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
Returns:
_type_: _description_
"""
qkv = self.Wqkv(x) # (1, 768, 3 * 768)
query, key, value = qkv.chunk( # 【关键】每个 tensor 都是 (1, 512, 768)
3, dim=2)
context, attn_weights, past_key_value = self.attn_fn(query,key,value,self.n_heads) # (1, 512, 768)
return self.out_proj(context), attn_weights, past_key_value

而 MQA(Multi-Query Attention) 则是,让 Q 仍然保持原来的头数,但 K 和 V 只有一个头,相当于所有的 Q 头共享一组 K 和 V 头,所以叫做 Multi-Query 了。实现改变了会不会影响效果呢?确实会影响但相对它能带来的收益,性能的些微降低是可以接受的。

从上图表中可以看到,MQA 在 encoder 上的提速没有非常明显,但在 decoder 上的提速是很显著的,能带来多大的收益呢,实验发现一般能提高 30%-40% 的吞吐。

收益主要就是由降低了 KV cache 带来的。实际上 MQA 运算量和 MHA 是差不多的,可理解为读取一组 KV 头之后,给所有 Q 头用,但因为之前提到的内存和计算的不对称,所以是有利的。

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
def __init__(self,d_model: int,n_heads: int,device: Optional[str] = None):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.Wqkv = nn.Linear( # 【关键】Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_model
device=device, # 而 key 和 value 则只共享各自的一个 head_dim 的向量
)
self.attn_fn = scaled_multihead_dot_product_attention
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj._is_residual = True # type: ignore
def forward(self,x):
qkv = self.Wqkv(x) # (1, 512, 960)
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
)
context, attn_weights, past_key_value = self.attn_fn(query,key,value,self.n_heads,multiquery=True)
return self.out_proj(context), attn_weights, past_key_value

从上面的代码中可以看到,MHA 和 MQA 之间的区别只在于建立 Wqkv Layer 上:

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Multi Head Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
device=device
)
query, key, value = qkv.chunk( # 【关键】每个 tensor 都是 (1, 512, 768)
3,
dim=2
)
# Multi Query Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_model
device=device, # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
)

而 GQA(Grouped-Query Attention) 呢,是 MHA 和 MQA 的折衷方案,既不想损失性能太多,又想获得 MQA 带来的推理加速好处。具体思想是,不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上面图片就是两组 Q 共享一组 KV。

LLAMA2 中给出了效果对比,可以看到相比起 MQA,GQA的指标看起来还是要好些的。

同时在推理上的加速还和 MQA 类似:

MQA 和 GQA 形式在推理加速方面,主要是通过两方面来完成:

  • 降低了从内存中读取的数据量,所以也就减少了计算单元等待时间,提高了计算利用率;
  • KV cache 变小了 head_num 倍,也就是显存中需要保存的 tensor 变小了,空出来空间就可以加大 batch size,从而又能提高利用率。

如果要用 MQA 和 GQA,可以是从头训练的时候就加上,也可以像 GQA 论文里面一样,用已有的开源模型,挑一些头取个 mean 用来初始化 MQA 或 GQA 继续训练一段时间。

下面是 MQA 推导过程,不感兴趣同学可跳过,感兴趣同学可推一下,理解更透彻。

MQA 的推导

正如在 memory wall 中提到的,现在内存读取相对计算速度太慢导致拖后腿。

那么定义一个变量,, M 是 Memory 表示内存开销,而 A 是 Arithmetic 表示计算开销。如果这个值大于1的话,就会出现很明显的 Memory Wall,而当这个值小于1很多时,表示拿到数据后马上能开动马力计算,内存墙问题就不存在了。因为估算还有各种没考虑因素问题,所以即使等于 1 也不代表就能打满计算单元。

那么先来看看 MHA 下推理时每一个 timestep 这个值的大小,主要参考 MQA 原论文的简化:

PYTHON
1
2
3
4
5
6
7
8
9
10
11
#三个投影矩阵分别为 P_q, P_k, P_v; 维度为 h(头数), a(隐层大小,等于hd), d(每个头大小)
#当前 timestep 输入为 x,维度为 b(batch大小), a
#K_prev, V_prev 为 KV cache的矩阵,维度为 b, h, m(之前的timestep数),d; m+1=n
q = torch.einsum('ba,had->bhd', x, P_q) #M:had+ba, A:ba^2
k = torch.einsum('ba,had->bhd', x, P_k) #M:had+ba, A:ba^2
v = torch.einsum('ba,had->bhd', x, P_v) #M:had+ba, A:ba^2
K = torch.cat([K_prev, k.unsqueeze(2)], dim=-2) #M:bhnd+bhd, A:0
V = torch.cat([V_prev, v.unsqueeze(2)], dim=-2) #M:bhnd+bhd, A:0
logits = torch.einsum("bhd,bhnd->bhn", q, K)#M:bhnd+bhd, A:bhnd
weights = torch.softmax(logits/math.sqrt(d), dim=-1)#M:bhn
outs = torch.einsum("bhn,bhnd->bhd", weights, V)#M:bhn+bhnd, A:bhnd

所以对于 M 来说是

3(+)+4(+)3(2+)+4(+)32+4+7(+2)

对于 A 来说

32+232+2

假设隐层大小和 timestep 数接近,, 那么 A 就是(2) , 因此

=(+22)=(+1)

可以看到要想让这个比例小,可以增大b,也就是增大 batch size,现在推理优化就会将用户的请求收集成 batch 推理,提高利用率。同时前面提到,MQA 可以降低显存使用扩大 batch size,所以能提高一定利用率。

根据假设 ,这个比例会接近 1,会导致一定 Memory Wall,如果 n 很长的话问题就更明显。

而 MQA 的情况下

PYTHON
1
2
3
4
5
6
7
8
9
10
#投影矩阵 P_k, P_v 维度变为 a(隐层大小,等于hd), d(每个头大小)
#K_prev, V_prev 为 KV cache的矩阵,维度为 b, m(之前的timestep数),d; m+1=n
q = torch.einsum('ba,had->bhd', x, P_q) #M:had+ba, A:ba^2
k = torch.einsum('ba,ad->bd', x, P_k) #M:ad+ba, A:bad
v = torch.einsum('ba,ad->bd', x, P_v) #M:ad+ba, A:bad
K = torch.cat([K_prev, k.unsqueeze(1)], dim=-2) #M:bnd+bd
V = torch.cat([V_prev, v.unsqueeze(1)], dim=-2) #M:bnd+bd
logits = torch.einsum("bhd,bnd->bhn", q, K)#M:bhd+bnd, A:bhnd
weights = torch.softmax(logits/math.sqrt(d), dim=-1)#M:bhn
outs = torch.einsum("bhn,bnd->bhd", weights, V)#M:bhn+bnd, A:bhnd

会发现 A 整体来说没有变,如之前说的只是共享了 KV, 计算量还是一样的(2) ,M 变化比较大

(+)+2(+)+2(+)+(+)++(+)2+4+2+2+4+2(+2++)

于是系数为

=(+2++2)=(1+1++)

其中后面两项,d 一般比 h 要大,所以可以主要考虑  项。可看到之前占大头的  在分母加了个系数 h,这样就能降低  从而提高效率。

感兴趣的话,可自己推导一下 GQA 的情况,其中() 的分母中会加入一个数 , 其中 g 为 group 数,如果 g 为 1 的情况那就和 MQA 一样了,这块开销主要就有 g 来调整了。

再见美好旧时光

看到这,大概也能明白为什么要用 MQA 了,以及为什么 MQA 最近才突然火起来。

主要就是因为大规模 GPT 式生成模型的落地需求导致的

而在以前根本不需要关心这些,LSTM 只用维护一个状态,不存在要保留 Cache 什么。

到了 Transformer 提出后,虽然最早 Transformer 提出时是用在 Seq2Seq 任务上,也就是 Encoder 和 Decoder 都用,但可能模型量级不大,也没有太多落地需求,所以没引起太大关注。之后火了两年的 BERT 又是 Encoder 结构,直接前向一把过。

也只有到最近 GPT 大模型得到广泛应用时,才发现推理的这个瓶颈,于是大家翻出几年前的 trick,应用起来,发现非常好用。

同样原因,GPT 推理加速这块最近引起很多关注,大家都在想各种方法来提高推理效率。Huggingface 这两天也给 text-generation-inference 库的 license 给改了,应该也是想用这个挣点钱。

参考文献

  1. Fast Transformer Decoding: One Write-Head is All You Need

  2. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

  3. Llama 2: Open Foundation and Fine-Tuned Chat Models

  4. llm-foundry

  5. 原文



Share Http URL:  http://www.wittx.cn/get_news_message.do?new_id=1238














请输入评论





























Best Last Month

欧洲疫情失控,法国进入紧急状态,英国或将再次全国封锁



恒大物业引入235亿港元战略投资

恒大物业引入235亿港元战略投资

Information industry

by wittx


Temporal Difference Learning

Temporal Difference Learning

Computer software and hardware

by wittx


A Self Driving Car Company Bets on Mall Shuttles



全球比特币ATM安装总数达9187台

全球比特币ATM安装总数达9187台

Information industry

by wittx


电池的效率更上一层楼——近一个月顶刊速递



中国人工智能发展报告

中国人工智能发展报告

Information industry

by show


如何使用微软Bing ChatGPT

如何使用微软Bing ChatGPT

Information industry

by wittx


忆阻器原理与发展现状

忆阻器原理与发展现状

Electronic electrician

by wittx


多传感器数据融合算法

多传感器数据融合算法

Information industry

by wittx