不直接调用 nn.Transformer,而是手写位置编码、多头注意力、Encoder / Decoder,并用一个反转序列的 toy task 跑通训练与解码。
这一篇主要整理自
pytorch_using/transformer.py。这份代码的价值不在于"重新发明一个工业级 Transformer",而在于把 Transformer 拆成可验证、可训练、可调试的模块。
1. 为什么要手写一版 Transformer
直接用 nn.Transformer 当然更快,但我自己一直觉得,想真的理解 Transformer,至少要完整看一遍这些模块是怎么拼起来的:
- 位置编码
- padding mask
- causal mask
- scaled dot-product attention
- multi-head attention
- feed-forward
- encoder / decoder layer
把这些都走通一遍之后,再回去看高级封装,心里会稳很多。
2. 这份实现统一采用 batch_first
代码一开头就明确了形状约定:
token id: [B, S]embedding 后: [B, S, D]这个约定非常好,因为后面所有 shape 变化都能围绕它来理解。
3. 位置编码:让模型知道"顺序"
这份实现里的 PositionalEncoding 很标准:
Python3
点击展开代码
展开代码
这段最关键的是:
- 输入输出都保持
[B, S, D] - 位置向量不是参数,而是 buffer
- 它解决的是"Attention 本身不带顺序感"的问题
4. Mask:谁该被遮住
这份代码把两种最重要的 mask 都单独实现了:
4.1 Padding Mask
Python3
点击展开代码
展开代码
它解决的是:
补齐出来的 PAD 不应该参与有效注意力。
4.2 Causal Mask
Python3
点击展开代码
展开代码
它解决的是:
Decoder 在生成当前 token 时,不能偷看未来位置。
5. Attention 的核心主线
这一段我特别喜欢原代码里写的复习清单,因为它几乎就是最短背诵版:
QK^T/ sqrt(d_k)masked_fillsoftmax@ V真正实现就是:
Python3
点击展开代码
展开代码
这一段如果 shape 能看懂,Transformer 就已经通了一半。
6. Multi-Head Attention 真正增加了什么
多头注意力的重点,不只是"多做几次 attention",而是:
- 先把同一个表示投影到不同子空间
- 每个头学不同的关注模式
- 最后再拼回来
原代码里把这条 shape 变化写得很清楚:
[B, S, D]→ [B, S, H, Dh]→ [B, H, S, Dh]→ attention→ [B, S, D]这是理解多头机制最值得反复看的地方。
7. Encoder / Decoder 是怎么组起来的
这份实现保持了 Transformer 最经典的结构:
EncoderLayer
- self-attention
- residual + layer norm
- FFN
- residual + layer norm
DecoderLayer
- masked self-attention
- cross-attention
- FFN
- 每段后都有 residual + layer norm
这时候 Transformer 就不再神秘了,它就是把这些标准模块一层层堆起来。
8. 用 toy task 跑通:反转序列
我很喜欢这份代码没有直接上复杂任务,而是先做了一个最小的可验证任务:
把输入序列反转。
数据构造函数也写得很清楚:
Python3
点击展开代码
展开代码
这段非常适合理解 seq2seq 训练里的两个关键点:
tgt_input和tgt_output是错位的- Decoder 训练时吃的是前一个位置的真实 token
9. 训练循环和 greedy decode
最后这段代码把完整流程跑通了:
- 训练时:
src -> encodertgt_input -> decoderlogits -> CrossEntropyLoss
- 推理时:
- 从
BOS开始 - 每次取最后一个位置的 logits
- 贪心生成下一个 token
- 从
这就是最小版的 seq2seq 生成闭环。
10. 这一阶段该记住什么
如果只保留最少几句话:
- Transformer 不是黑盒,它是多个标准模块的组合。
- 位置编码、mask、多头注意力是最关键的三个部件。
- 理解 shape 变化,比死背公式更重要。
- 一个 toy task 足够把整条训练与推理链跑通。
我觉得这份手写实现最有价值的地方,不是"性能",而是它把 Transformer 变成了一套可以亲手拆开的积木。
专题阅读
Pytorch
这篇文章属于同一条阅读链。你可以直接在这里切换,不用再回到列表页重新找。
部分信息可能已经过时
留言区
留言
欢迎纠错、补充、交流。昵称和评论内容必填;如果你愿意,也可以留下联系方式,仅站主可见。