mlx.nn.Upsample#
- class Upsample(scale_factor: float | Tuple, mode: Literal['nearest', 'linear', 'cubic'] = 'nearest', align_corners: bool = False)#
对输入信号进行空间上采样。
按照惯例,空间维度是维度
1
到x.ndim - 2
。第一个是批量维度,最后一个是特征维度。例如,音频信号是具有 1 个空间维度的 3D 信号,图像是具有 2 个空间维度的 4D 信号,依此类推。
实现了三种上采样算法:最近邻上采样、线性插值和三次插值。所有算法都可以应用于任意数量的空间维度。当应用于多个空间维度时,线性插值将是双线性、三线性等。当有两个空间维度时,三次插值将是双三次插值。
注意
当使用线性或三次插值模式之一时,
align_corners
参数会改变输入图像中角的处理方式。如果align_corners=True
,则输入和输出的上边缘和左边缘将匹配,右下边缘也将匹配。- 参数:
示例
>>> import mlx.core as mx >>> import mlx.nn as nn >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1)) >>> x array([[[[1], [2]], [[3], [4]]]], dtype=int32) >>> n = nn.Upsample(scale_factor=2, mode='nearest') >>> n(x).squeeze() array([[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], dtype=int32) >>> b = nn.Upsample(scale_factor=2, mode='linear') >>> b(x).squeeze() array([[1, 1.25, 1.75, 2], [1.5, 1.75, 2.25, 2.5], [2.5, 2.75, 3.25, 3.5], [3, 3.25, 3.75, 4]], dtype=float32) >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True) >>> b(x).squeeze() array([[1, 1.33333, 1.66667, 2], [1.66667, 2, 2.33333, 2.66667], [2.33333, 2.66667, 3, 3.33333], [3, 3.33333, 3.66667, 4]], dtype=float32)
方法