mlx.core.conv_general

mlx.core.conv_general#

conv_general(input: array, weight: array, /, stride: int | Sequence[int] = 1, padding: int | Sequence[int] | tuple[Sequence[int], Sequence[int]] = 0, kernel_dilation: int | Sequence[int] = 1, input_dilation: int | Sequence[int] = 1, groups: int = 1, flip: bool = False, *, stream: None | Stream | Device = None) array#

对具有多个通道的输入执行通用卷积

参数:
  • input (array) – 输入数组,形状为 (N, ..., C_in)

  • weight (array) – 权重数组,形状为 (C_out, ..., C_in)

  • stride (int or list(int), optional) – 列表,包含核步长。如果只指定一个数字,则所有空间维度使用相同的步长。默认值: 1

  • padding (int, list(int), or tuple(list(int), list(int)), optional) – 列表,包含输入填充。如果只指定一个数字,则所有空间维度使用相同的填充。默认值: 0

  • kernel_dilation (int or list(int), optional) – 列表,包含核膨胀率。如果只指定一个数字,则所有空间维度使用相同的膨胀率。默认值: 1

  • input_dilation (int or list(int), optional) – 列表,包含输入膨胀率。如果只指定一个数字,则所有空间维度使用相同的膨胀率。默认值: 1

  • groups (int, optional) – 输入特征组数。默认值: 1

  • flip (bool, optional) – 翻转权重空间维度的处理顺序。flipFalse 时执行互相关运算,否则执行卷积运算。默认值: False

返回值:

卷积后的数组。

返回类型:

array