mlx.nn.Transformer#
- class Transformer(dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: int | None = None, dropout: float = 0.0, activation: ~typing.Callable[[~typing.Any], ~typing.Any] = <mlx.gc_func object>, custom_encoder: ~typing.Any | None = None, custom_decoder: ~typing.Any | None = None, norm_first: bool = True, checkpoint: bool = False)#
实现了一个标准的 Transformer 模型。
该实现基于Attention Is All You Need。
Transformer 模型包含一个编码器和一个解码器。编码器处理输入序列,解码器生成输出序列。编码器和解码器之间的交互通过注意力机制实现。
- 参数:
dims (int, 可选) – 编码器/解码器输入中期望的特征数量。默认值:
512
。num_heads (int, 可选) – 注意力头的数量。默认值:
8
。num_encoder_layers (int, 可选) – Transformer 编码器中的编码器层数。默认值:
6
。num_decoder_layers (int, 可选) – Transformer 解码器中的解码器层数。默认值:
6
。mlp_dims (int, 可选) – 每个 Transformer 层中 MLP 块的隐藏维度。如果未提供,默认为
4*dims
。默认值:None
。dropout (float, 可选) – Transformer 编码器和解码器的 dropout 值。dropout 在每个注意力层和 MLP 层的激活函数后使用。默认值:
0.0
。activation (函数, 可选) – MLP 隐藏层的激活函数。默认值:
mlx.nn.relu()
。custom_encoder (Module, 可选) – 用于替换标准 Transformer 编码器的自定义编码器。默认值:
None
。custom_decoder (Module, 可选) – 用于替换标准 Transformer 解码器的自定义解码器。默认值:
None
。norm_first (bool, 可选) – 如果为
True
,编码器和解码器层将在注意力和 MLP 操作之前执行层归一化,否则在其之后执行。默认值:True
。checkpoint (bool, 可选) – 如果为
True
,则执行梯度检查点以减少内存使用,但会增加计算量。默认值:False
。
方法