mlx.nn.Linear

目录

mlx.nn.Linear#

class Linear(input_dims: int, output_dims: int, bias: bool = True)#

对输入应用仿射变换。

具体来说

\[y = x W^\top + b\]

其中:\(W\) 的形状为 [output_dims, input_dims]\(b\) 的形状为 [output_dims]

值从均匀分布 \(\mathcal{U}(-{k}, {k})\) 初始化,其中 \(k = \frac{1}{\sqrt{D_i}}\)\(D_i\) 等于 input_dims

参数:
  • input_dims (int) – 输入特征的维度

  • output_dims (int) – 输出特征的维度

  • bias (bool, optional) – 如果设为 False,则该层不使用偏置项。默认为 True

方法

to_quantized([group_size, bits])

返回一个近似此层的 QuantizedLinear 层。