金山相勤网

Llama也中招,混合精度下位置编码竟有大坑,百川智能给出修复妄想

2024-11-18 18:27:56 来源:

位置编码技术是中招置编一种可能让神经收集建模句子中 Token 位信托息的技术  。在 Transformer 大行其道的混合时期 ,由于 Attention 妄想无奈建模每一个 token 的精度位信托息,位置编码(Position embedding) 成为 Transformer 颇为紧张的下位想一个组件 。钻研职员也提出了林林总总的码竟位置编码妄想来让收集建模位信托息,Rope 以及 Alibi 是坑百当初最被普遍接管的两种位置编码妄想 。

可是川智出修最近来自百川智能的钻研发现,Rope 以及 alibi 位置编码的复妄主流实如今低精度(特意是 bfloat16) 下存在位置编码碰撞的 bug, 这可能会影响模子的磨炼以及推理 。而且当初大部份主流开源模子的中招置编实现都存在该下场,连 llama 民间代码也中招了 。混合

还患上从位置编码算法提及

为了弄清晰这个下场,精度患上先从位置编码的下位想算法道理提及,在 Transformer 妄想中,码竟所有 Attention Block 的坑百输入都市先经由位置编码  ,再输入收集妨碍后续处置。川智出修隧道的 Attention 妄想是无奈精确感知到每一个 token 的位信托息的,而对于语言的良多使命来说,语句的挨次对于语义信息的影响黑白常大的,为了建模 token 之间的位置关连,Transfomer 原始论文中引入地位编码来建模位信托息。

图 1 - 施加 Positon Embedding 展现图	。

为了让模子更好地建模句子的位信托息,钻研职员提出了多种位置编码妄想,meta 开源的 llama [4] 模子接管了 Rope [5] 妄想,使患上 Rope 成为在开源社区被普遍接管的一种位置编码妄想 。而 Alibi 编码因其精采的外推性也被普遍运用。

清晰低精度下的位置编码碰撞以前,先往返忆一下相关算法道理 。

Sinusoidal 位置编码

这是 Transformer 原始论文中提出的位置编码措施  。它经由运用差距频率的正弦以及余弦函数来为每一个位置发生一个配合的编码。抉择三角函数来天生位置编码有两个精采的性子:

1)编码相对于位信托息 ,数学上可能证实 PE (pos+k) 可能被 PE (pos) 线性展现 , 这象征着位置编码中搜罗了相对于位信托息 。

图 2- 句子长度为 50 的位置编码
,编码维度 128�,每一行代表一个 Position Embedding	。

2)短途衰减 :差距位置的 position encoding 点乘服从会随着相对于位置的削减而递减 [1] 。

图 3 - 差距位置的位置编码点积可视化。

Rope

Rope 是当初开源社区运用最普遍的一种位置编码妄想, 经由相对于位置编码的方式实现相对于位置编码,在引入相对于位信托息的同时坚持了相对于位置编码的优势(不需要像相对于位置编码同样去操作 attention matrix)。令 f_q, f_k 为 位置编码的函数,m 展现位置 ,x_m 展现该位置 token 对于应的 embedding ,咱们愿望经由位置编码后的 embedding 点积仅以及相对于位置无关 ,则可能有公式:

下面公式中 g 是某个函数 ,展现内积的服从只以及 x_m 以及 x_n 的值,以及两者位置的相对于关连 (m-n) 无关在 2 维的情景下可能推导出(详细推导历程可参考原论文) :

由于矩阵乘法线性累加的性子 ,可能拓展到多维的情景可患上 :

为了引入短途衰减的特色,Rope 中 \theta 的选取抉择了 Transformer 原始论文中 sinusoidal 公式。

Alibi

Alibi 是google宣告在 ICLR2022 的一篇使命 ,Alibi 主要处置了位置编码外推下场差的痛点 ,算法脑子颇为的重大,而且颇为直不雅  。与直接加在 embedding 上的相对于位置编码差距 ,Alibi 的脑子是在 attention matrix 上施加一个与距离成正比的表彰偏置 ,表彰偏置随着相对于距离的削减而削减。在详细实现时,对于每一个 head 会有一个超参 m 来操作表彰偏置随着相对于距离削减的幅度(斜率)。

图 4 - Alibi attention bias 展现图

图 4 - Alibi attention bias 展现图

论文服从展现 Alibi 极大的提升了模子的外推功能 ,16k token 的输入依然可能很好的反对于。

图 5 - Alibi 外推下场比力。

混合精度下位置编码的 bug

从下面的算法道理中 ,不论是 rope 的 cos (m\theta) 仍是 alibi 的 i-1(m, i 代表 postion id), 需要为每一个位置天生一个整型的 position_id, 在高下文窗口比力大的时候 ,百川智能发现当初主流的位置编码实如今混合精度下都存在由于低精度(float16/bfloat16) 浮点数展现精度缺少导致位置编码碰撞的下场 。特意当模子磨炼(推理)时高下文长度越来越长,低精度展现带来的位置编码碰撞下场越来越严正,进而影响模子的下场  ,下面以 bfloat16 为例来剖析这个 bug。

浮点数展现精度

浮点数在合计机中展现由标志位(sign) ,指数位 (exponent) ,尾数位 (fraction) 三部份组成 ,对于一个老例的数值展现,可能由如下公式来合计其代表的数值(其中 offset 是指数位的偏置) :

由公式可知,尾数位的长度抉择了浮点数的展现精度 。深度学习中罕用的 float32/float16/bfloat16 内存中的展现分说如下图所示 :

图 6- bfloat16 的展现格式图 7- float16 的展现格式图 8- float32 的展现格式

可能看到 float16 以及 bfloat16 比照于 float32 都舍身了展现的精度,后续以 bfloat16 为例剖析位置编码中存在的下场(float16 同理)。下表揭示了 bfloat16 在差距数值规模(只截取整数部份)内的展现精度。

可能看到当整数规模逾越 256, bfloat16 就无奈精确展现每一个整数,可能用代码验证一下展现精度带来的下场。

Rope& Alibi 编码的下场

Meta 开源的 llama 模子接管了 Rope 的位置编码方式 , 民间的实现(以及大部份的第三方 llama 系列模子)在 bfloat16 下存在精度下场带来的位置编码碰撞(差距位置的 token 在 bfloat16  下酿成统一个数) 。Llama 民间代码如下 :

Python
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

下面第 18 行中间一句凭证输入序列长度天生每一个位置的 positon idx 在 bfloat16 下发生位置碰撞。

Python
# self.inv_freq.dtype == torch.bfloat16 when bfloat16 is enabled during training
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

在实际磨炼时假如开了 bfloat16, self.inv_freq 的 dtype 会被转为 bfloat16, 可能经由重大的代码来看一下位置碰撞的下场。

Python
t = torch.arange(4096, dtype=torch.float32)
plt.scatter(t[-100:], t[-100:].to(torch.bfloat16).float(),s=0.8)
plt.xlabel('position in float32')
plt.ylabel('position in bfloat16'

凭证 bfloa16 的展现精度可知 ,磨炼(推理)时高下文长度越长,位置编码碰撞的情景越严正 ,长度为 8192 的高下文推理中 ,仅有约莫 10% 的 token 位置编码是精确的,幸好位置编码碰撞有局域性的特质 ,惟独多少多个相邻的 token 才会同享统一个 position Embedding, 在更大的尺度上 ,差距位置的 token 仍是有确定的分说性 。

图 10- 差距高下文窗口下位置编码精确 token 所占比例�。

除了 llama 模子 ,百川智能发现 alibi 位置编码也存在上述下场 ,原因依然在于天生整数的位置索引时会在低精度下发生碰撞下场。

修复妄想

Rope 修复

Rope 的修复相对于重大 ,惟独要保障在天生 position_id 的时候确定在 float32 的精度上即可 。留意:

float32 的 tensor register_buffer 后在磨炼时假如开启了 bfloat16, 也会被转为 bfloat16。

Python
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Alibi 修复

  • alibi 位置编码修复思绪以及 Rope 的修复思绪不同 ,但由于 alibi 的 attention bias 直接加在 attention matrix 下面 ,假如凭证下面的修复思绪,attention matrix 的规范必需以及 attention bias 不同,导致全部 attention 的合计都在 float32 规范上合计,这会极大的拖慢磨炼速率

  • 当初主流的 attention 减速措施 flashattention 不反对于 attention bias 参数, 而 xformers 要求 attention  bias 规范必需与 query.dtype 相同 ,因此像 rope 那样重大的将 attention bias 规范提升到 float32 将会极大的拖慢磨炼速率

  • 针对于该下场百川智能提出了一种新的 alibi attention 妄想, 全部 attention bias 依然在 bfloat16 规范上  ,相似于 sinusiodal 的短途衰减特质 , 可能尽管纵然保障临近 token 位置编码的精确性,对于相对于距离过远的的 token 则可能容忍其发生确定的位置碰撞。原有的 alibi 实现则相同,相对于距离越远的 token 展现越精确 ,相对于距离越近的 token 则会碰撞

图 11- 修复先后 alibi attention_bias 比力	。

修复下场

百川智能仅在推理阶段对于位置编码的精度下场妨碍修复【注  :磨炼阶段可能也存在下场  ,取决于磨炼的详细配置装备部署以及措施】 ,可能看到:

a.在长高下文的推理中 ,模子的 ppl 要清晰优于修复前的 ppl

b.Benchmark 上测试服从展现修复先后差距不大 ,可能是由于 benchmark 上测试文本长度有限 ,很少触发 Position embedding 的碰撞

Benchmark 比力Benchmark 比力

Perplexity

咱们在通用的文本数据上对于更正先后模子在中英文文本上的怀疑度妨碍测试,下场如下:

[0] Dongxu Zhang, & Dong Wang. (2015). Relation Classification via Recurrent Neural Network.

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, & Illia Polosukhin. (2023). Attention Is All You Need.

[2] Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le, & Ruslan Salakhutdinov. (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.

[3] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, & Peter J. Liu. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.

[4] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, & Guillaume Lample. (2023). LLaMA: Open and Efficient Foundation Language Models.

[5] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, & Yunfeng Liu. (2022). RoFormer: Enhanced Transformer with Rotary Position Embedding.

[6] Ofir Press, Noah A. Smith, & Mike Lewis. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

[7] Yutao Sun, Li Dong, Barun Patra, Shuming Ma, Shaohan Huang, Alon Benhaim, Vishrav Chaudhary, Xia Song, & Furu Wei. (2022). A Length-Extrapolatable Transformer.

[8]  https://kazemnejad.com/blog/transformer_architecture_positional_encoding/

[9] Shouyuan Chen, Sherman Wong, Liangjian Chen, & Yuandong Tian. (2023). Extending Context Window of Large Language Models via Positional Interpolation.

[10] https://www.reddit.com/r/LocalLLaMA/co妹妹ents/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/