保存和加载数组#
MLX 支持多种数组序列化格式。
格式 |
扩展名 |
函数 |
备注 |
|---|---|---|---|
NumPy |
|
仅支持单个数组 |
|
NumPy 归档 |
|
多个数组 |
|
Safetensors |
|
多个数组 |
|
GGUF |
|
多个数组 |
load() 函数可以加载任何支持的序列化格式。它根据文件扩展名确定格式。load() 的输出取决于格式。
以下是将单个数组保存到文件的示例
>>> a = mx.array([1.0])
>>> mx.save("array", a)
数组 a 将被保存到文件 array.npy 中(请注意扩展名是自动添加的)。包含扩展名是可选的;如果省略,则会自动添加。您可以使用以下方式加载数组
>>> mx.load("array.npy")
array([1], dtype=float32)
以下是将多个数组保存到单个文件的示例
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)
为了与 numpy.savez() 兼容,MLX 的 savez() 将数组作为参数。如果省略关键字参数,则会提供默认名称。可以使用以下方式加载:
>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
在这种情况下,load() 返回一个将名称映射到数组的字典。
函数 save_safetensors() 和 save_gguf() 与 savez() 类似,但它们接受一个将字符串名称映射到数组的 dict 作为输入。
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})