mlx.nn.Conv3d

目录

mlx.nn.Conv3d#

class Conv3d(in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, dilation: int | tuple = 1, bias: bool = True)#

在多通道输入图像上应用三维卷积。

通道预计位于最后,即输入形状应为 NDHWC,其中:

  • N 是批量维度

  • D 是输入图像深度

  • H 是输入图像高度

  • W 是输入图像宽度

  • C 是输入通道数

参数
  • in_channels (int) – 输入通道数。

  • out_channels (int) – 输出通道数。

  • kernel_size (int or tuple) – 卷积滤波器的大小。

  • stride (int or tuple, optional) – 应用滤波器时的步长大小。默认值: 1

  • dilation (int or tuple, optional) – 卷积的扩张(dilation)。

  • padding (int or tuple, optional) – 输入的零填充位置数。默认值: 0

  • bias (bool, optional) – 如果为 True,则在输出中添加可学习的偏置。默认值: True

方法