操作

mlx.nn.Module.apply

mlx.nn.Module.apply_to_modules

mlx.nn.Module.children

mlx.nn.Module.eval

mlx.nn.Module.filter_and_map

mlx.nn.Module.freeze

mlx.nn.Module.leaf_modules

mlx.nn.Module.load_weights

mlx.nn.Module.modules

mlx.nn.Module.named_modules

mlx.nn.Module.parameters

mlx.nn.Module.save_weights

mlx.nn.Module.set_dtype

mlx.nn.Module.train

mlx.nn.Module.trainable_parameters

mlx.nn.Module.unfreeze

mlx.nn.Module.update

mlx.nn.Module.update_modules

mlx.nn.ALiBi

mlx.nn.AvgPool1d

mlx.nn.AvgPool2d

mlx.nn.AvgPool3d

mlx.nn.BatchNorm

mlx.nn.CELU

mlx.nn.Conv1d

mlx.nn.Conv2d

mlx.nn.Conv3d

mlx.nn.ConvTranspose1d

mlx.nn.ConvTranspose2d

mlx.nn.ConvTranspose3d

mlx.nn.Dropout

mlx.nn.Dropout2d

mlx.nn.Dropout3d

mlx.nn.Embedding

mlx.nn.ELU

mlx.nn.GELU

mlx.nn.GLU

mlx.nn.GroupNorm

mlx.nn.GRU

mlx.nn.HardShrink

mlx.nn.HardTanh

mlx.nn.Hardswish

mlx.nn.InstanceNorm

mlx.nn.LayerNorm

mlx.nn.LeakyReLU

mlx.nn.Linear

mlx.nn.LogSigmoid

mlx.nn.LogSoftmax

mlx.nn.LSTM

mlx.nn.MaxPool1d

mlx.nn.MaxPool2d

mlx.nn.MaxPool3d

mlx.nn.Mish

mlx.nn.MultiHeadAttention

mlx.nn.PReLU

mlx.nn.QuantizedEmbedding

mlx.nn.QuantizedLinear

mlx.nn.RMSNorm

mlx.nn.ReLU

mlx.nn.ReLU6

mlx.nn.RNN

mlx.nn.RoPE

mlx.nn.SELU

mlx.nn.Sequential

mlx.nn.Sigmoid

mlx.nn.SiLU

mlx.nn.SinusoidalPositionalEncoding

mlx.nn.Softmin

mlx.nn.Softshrink

mlx.nn.Softsign

mlx.nn.Softmax

mlx.nn.Softsign

mlx.nn.Softplus

mlx.nn.Step

mlx.nn.Tanh

mlx.nn.Transformer

mlx.nn.Upsample

函数

mlx.nn.elu

mlx.nn.celu

mlx.nn.gelu

mlx.nn.gelu_approx

mlx.nn.gelu_fast_approx

mlx.nn.glu

mlx.nn.hard_shrink

mlx.nn.hard_tanh

mlx.nn.hardswish

mlx.nn.leaky_relu

mlx.nn.log_sigmoid

mlx.nn.log_softmax

mlx.nn.mish

mlx.nn.prelu

mlx.nn.relu

mlx.nn.relu6

mlx.nn.selu

mlx.nn.sigmoid

mlx.nn.silu

mlx.nn.softmax

mlx.nn.softmin

mlx.nn.softplus

mlx.nn.softshrink

mlx.nn.step

mlx.nn.tanh

损失函数

mlx.nn.losses.binary_cross_entropy

mlx.nn.losses.cosine_similarity_loss

mlx.nn.losses.cross_entropy

mlx.nn.losses.gaussian_nll_loss

mlx.nn.losses.hinge_loss

mlx.nn.losses.huber_loss

mlx.nn.losses.kl_div_loss

mlx.nn.losses.l1_loss

mlx.nn.losses.log_cosh_loss

mlx.nn.losses.margin_ranking_loss

mlx.nn.losses.mse_loss

mlx.nn.losses.nll_loss

mlx.nn.losses.smooth_l1_loss

mlx.nn.losses.triplet_loss

初始化器

mlx.nn.init.constant

mlx.nn.init.normal

mlx.nn.init.uniform

mlx.nn.init.identity

mlx.nn.init.glorot_normal

mlx.nn.init.glorot_uniform

mlx.nn.init.he_normal

mlx.nn.init.he_uniform

优化器

优化器

mlx.optimizers.Optimizer.state

mlx.optimizers.Optimizer.apply_gradients

mlx.optimizers.Optimizer.init

mlx.optimizers.Optimizer.update

常用优化器

mlx.optimizers.SGD

mlx.optimizers.RMSprop

mlx.optimizers.Adagrad

mlx.optimizers.Adafactor

mlx.optimizers.AdaDelta

mlx.optimizers.Adam

mlx.optimizers.AdamW

mlx.optimizers.Adamax

mlx.optimizers.Lion

mlx.optimizers.MultiOptimizer

调度器

mlx.optimizers.cosine_decay

mlx.optimizers.exponential_decay

mlx.optimizers.join_schedules

mlx.optimizers.linear_schedule

mlx.optimizers.step_decay

mlx.optimizers.clip_grad_norm

mlx.core.distributed.Group

mlx.core.distributed.is_available

mlx.core.distributed.init

mlx.core.distributed.all_sum

mlx.core.distributed.all_gather

mlx.core.distributed.send

mlx.core.distributed.recv

mlx.core.distributed.recv_like

树工具

mlx.utils.tree_flatten

mlx.utils.tree_unflatten

mlx.utils.tree_map

mlx.utils.tree_map_with_path

mlx.utils.tree_reduce

C++ API 参考

进一步阅读

MLX 中的自定义扩展

Metal 调试器

自定义 Metal 内核

在 C++ 中使用 MLX

.rst

.pdf

操作#

abs(a, /, *[, stream])

逐元素的绝对值。

add(a, b[, stream])

逐元素相加。

addmm(c, a, b, /[, alpha, beta, stream])

带加法和可选缩放的矩阵乘法。

all(a, /[, axis, keepdims, stream])

沿给定轴的“与”归约。

allclose(a, b, /[, rtol, atol, equal_nan, ...])

两个数组的近似比较。

any(a, /[, axis, keepdims, stream])

沿给定轴的“或”归约。

arange(-> array)

重载函数。

arccos(a, /, *[, stream])

逐元素的反余弦。

arccosh(a, /, *[, stream])

逐元素的反双曲余弦。

arcsin(a, /, *[, stream])

逐元素的反正弦。

arcsinh(a, /, *[, stream])

逐元素的反双曲正弦。

arctan(a, /, *[, stream])

逐元素的反正切。

arctan2(a, b, /, *[, stream])

逐元素的两个数组比值的反正切。

arctanh(a, /, *[, stream])

逐元素的反双曲正切。

argmax(a, /[, axis, keepdims, stream])

沿轴的最大值索引。

argmin(a, /[, axis, keepdims, stream])

沿轴的最小值索引。

argpartition(a, /, kth[, axis, stream])

返回分割数组的索引。

argsort(a, /[, axis, stream])

返回排序数组的索引。

array_equal(a, b[, equal_nan, stream])

数组相等性检查。

as_strided(a, /[, shape, strides, offset, ...])

创建具有给定形状和步幅的数组视图。

atleast_1d(*arys[, stream])

将所有数组转换为至少一个维度。

atleast_2d(*arys[, stream])

将所有数组转换为至少两个维度。

atleast_3d(*arys[, stream])

将所有数组转换为至少三个维度。

bitwise_and(a, b[, stream])

逐元素的按位与。

bitwise_invert(a[, stream])

逐元素的按位反转。

bitwise_or(a, b[, stream])

逐元素的按位或。

bitwise_xor(a, b[, stream])

逐元素的按位异或。

block_masked_mm(a, b, /[, block_size, ...])

带块掩码的矩阵乘法。

broadcast_arrays(*arrays[, stream])

将数组相互广播。

broadcast_to(a, /, shape, *[, stream])

将数组广播到给定形状。

ceil(a, /, *[, stream])

逐元素的向上取整 (ceil)。

clip(a, /, a_min, a_max, *[, stream])

将数组值限制在给定最小值和最大值之间。

concatenate(arrays[, axis, stream])

沿给定轴连接数组。

contiguous(a, /[, allow_col_major, stream])

强制数组行为行连续。

conj(a, *[, stream])

返回输入的逐元素的复共轭。

conjugate(a, *[, stream])

convolve(a, v, /[, mode, stream])

1D 数组的离散卷积。

conv1d(input, weight, /[, stride, padding, ...])

对具有多个通道的输入进行 1D 卷积

conv2d(input, weight, /[, stride, padding, ...])

对具有多个通道的输入进行 2D 卷积

conv3d(input, weight, /[, stride, padding, ...])

对具有多个通道的输入进行 3D 卷积

conv_transpose1d(input, weight, /[, stride, ...])

对具有多个通道的输入进行 1D 转置卷积

conv_transpose2d(input, weight, /[, stride, ...])

对具有多个通道的输入进行 2D 转置卷积

conv_transpose3d(input, weight, /[, stride, ...])

对具有多个通道的输入进行 3D 转置卷积

conv_general(input, weight, /[, stride, ...])

对具有多个通道的输入进行通用卷积

cos(a, /, *[, stream])

逐元素的余弦。

cosh(a, /, *[, stream])

逐元素的双曲余弦。

cummax(a, /[, axis, reverse, inclusive, stream])

返回沿给定轴的元素的累积最大值。

cummin(a, /[, axis, reverse, inclusive, stream])

返回沿给定轴的元素的累积最小值。

cumprod(a, /[, axis, reverse, inclusive, stream])

返回沿给定轴的元素的累积乘积。

cumsum(a, /[, axis, reverse, inclusive, stream])

返回沿给定轴的元素的累积和。

degrees(a, /, *[, stream])

将角度从弧度转换为度。

dequantize(w, /, scales, biases[, ...])

使用提供的 scalesbiases 以及 group_sizebits 配置,对矩阵 w 进行解量化。

diag(a, /[, k, stream])

提取对角线或构建对角矩阵。

diagonal(a[, offset, axis1, axis2, stream])

返回指定的对角线。

divide(a, b[, stream])

逐元素相除。

divmod(a, b[, stream])

逐元素的商和余数。

einsum(subscripts, *operands[, stream])

对操作数执行爱因斯坦求和约定。

einsum_path(subscripts, *operands)

计算给定爱因斯坦求和的收缩顺序。

equal(a, b[, stream])

逐元素相等。

erf(a, /, *[, stream])

逐元素的误差函数。

erfinv(a, /, *[, stream])

逐元素的 erf() 的逆。

exp(a, /, *[, stream])

逐元素的指数。

expm1(a, /, *[, stream])

逐元素的指数减去 1。

expand_dims(a, /, axis, *[, stream])

在给定轴添加大小为一的维度。

eye(n[, m, k, dtype, stream])

创建单位矩阵或通用对角矩阵。

flatten(a, /[, start_axis, end_axis, stream])

展平数组。

floor(a, /, *[, stream])

逐元素的向下取整 (floor)。

floor_divide(a, b[, stream])

逐元素的整数相除。

full(shape, vals[, dtype, stream])

构造具有给定值的数组。

gather_mm(a, b, /, lhs_indices, rhs_indices, *)

带矩阵级 gather 的矩阵乘法。

gather_qmm(x, w, /, scales, biases[, ...])

执行带矩阵级 gather 的量化矩阵乘法。

greater(a, b[, stream])

逐元素的 greater than。

greater_equal(a, b[, stream])

逐元素的 greater or equal。

hadamard_transform(a[, scale, stream])

沿最终轴执行 Walsh-Hadamard 变换。

identity(n[, dtype, stream])

创建方单位矩阵。

imag(a, /, *[, stream])

返回复数数组的虚部。

inner(a, b, /, *[, stream])

1-D 数组的普通向量内积,在高维中是最后轴上的求和积。

isfinite(a[, stream])

返回一个布尔数组,指示哪些元素是有限的。

isclose(a, b, /[, rtol, atol, equal_nan, stream])

返回一个布尔数组,其中两个数组在容差范围内逐元素相等。

isinf(a[, stream])

返回一个布尔数组,指示哪些元素是 +/- 无穷大。

isnan(a[, stream])

返回一个布尔数组,指示哪些元素是 NaN。

isneginf(a[, stream])

返回一个布尔数组,指示哪些元素是负无穷大。

isposinf(a[, stream])

返回一个布尔数组,指示哪些元素是正无穷大。

kron(a, b, *[, stream])

计算两个数组 ab 的 Kronecker 积。

left_shift(a, b[, stream])

逐元素的左移。

less(a, b[, stream])

逐元素的小于。

less_equal(a, b[, stream])

逐元素的小于或等于。