保存和加载数组

保存和加载数组#

MLX 支持多种数组序列化格式。

序列化格式#

格式

扩展名

函数

备注

NumPy

.npy

save()

仅支持单个数组

NumPy 归档

.npz

savez()savez_compressed()

多个数组

Safetensors

.safetensors

save_safetensors()

多个数组

GGUF

.gguf

save_gguf()

多个数组

load() 函数可以加载任何支持的序列化格式。它根据文件扩展名确定格式。load() 的输出取决于格式。

以下是将单个数组保存到文件的示例

>>> a = mx.array([1.0])
>>> mx.save("array", a)

数组 a 将被保存到文件 array.npy 中(请注意扩展名是自动添加的)。包含扩展名是可选的;如果省略,则会自动添加。您可以使用以下方式加载数组

>>> mx.load("array.npy")
array([1], dtype=float32)

以下是将多个数组保存到单个文件的示例

>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)

为了与 numpy.savez() 兼容,MLX 的 savez() 将数组作为参数。如果省略关键字参数,则会提供默认名称。可以使用以下方式加载:

>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

在这种情况下,load() 返回一个将名称映射到数组的字典。

函数 save_safetensors()save_gguf()savez() 类似,但它们接受一个将字符串名称映射到数组的 dict 作为输入。

>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})