mlx.nn.Module.set_dtype

mlx.nn.Module.set_dtype#

Module.set_dtype(dtype: ~mlx.core.Dtype, predicate: ~typing.Callable[[~mlx.core.Dtype], bool] | None = <function Module.<lambda>>)#

设置模块参数的数据类型(dtype)。

参数:
  • dtype (Dtype) – 新的数据类型(dtype)。

  • predicate (Callable, 可选) – 用于选择要进行类型转换参数的谓词。默认情况下,只有类型为 floating 的参数会被更新,以避免将整型参数转换为新的数据类型(dtype)。