mlx.nn.quantize#
- quantize(model: Module, group_size: int = 64, bits: int = 4, class_predicate: Callable[[str, Module], bool | dict] | None = None)#
根据谓词对模块的子模块进行量化。
默认情况下,所有定义了
to_quantized(group_size, bits)
方法的层都会被量化。包括Linear
和Embedding
层。另请注意,模块会原地更新。- 参数:
model (Module) – 可能被量化的模型叶子模块。
group_size (int) – 量化分组大小 (参见
mlx.core.quantize()
)。默认值:64
。bits (int) – 每个参数的位数 (参见
mlx.core.quantize()
)。默认值:4
。class_predicate (Optional[Callable]) – 一个可调用对象,接收
Module
路径和Module
本身,如果应该量化,则返回True
或一个包含 to_quantized 参数的字典,否则返回False
。如果为None
,则所有定义了to_quantized(group_size, bits)
方法的层都会被量化。默认值:None
。