mlx.core.hadamard_transform

mlx.core.hadamard_transform#

hadamard_transform(a: array, scale: float | None = None, stream: None | Stream | Device = None) array#

沿最终轴执行 Walsh-Hadamard 变换。

等价于

from scipy.linalg import hadamard

y = (hadamard(len(x)) @ x) * scale

支持大小为 n = m*2^k 的变换,其中 m(1, 12, 20, 28) 中,float32 类型支持 2^k <= 8192,float16/bfloat16 类型支持 2^k <= 16384

参数:
  • a (array) – 输入数组或标量。

  • scale (float) – 用于缩放输出的因子。默认为 1/sqrt(a.shape[-1]),以便 Hadamard 矩阵正交。

返回值:

变换后的数组。

返回类型:

array