mlx.nn.init.identity

目录

mlx.nn.init.identity#

identity(dtype: Dtype = mlx.core.float32) Callable[[array], array]#

一个返回单位矩阵的初始化器。

参数:

dtype (Dtype, 可选) – 数组的数据类型。默认值:float32

返回值:

一个初始化器,返回一个与输入具有相同形状的单位矩阵。

返回类型:

Callable[[array], array]

示例

>>> init_fn = nn.init.identity()
>>> init_fn(mx.zeros((2, 2)))
array([[1, 0],
       [0, 1]], dtype=float32)