mlx.nn.RNN

目录

mlx.nn.RNN#

class RNN(input_size: int, hidden_size: int, bias: bool = True, nonlinearity: Callable | None = None)#

一个 Elman 循环层。

输入是一个形状为 NLDLD 的序列,其中

  • N 是可选的批量维度

  • L 是序列长度

  • D 是输入的特征维度

具体来说,对于序列长度轴上的每个元素,该层应用以下函数

\[h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\]

隐状态 \(h\) 的形状为 NHH,取决于输入是否批量化。返回每个时间步的隐状态,形状为 NLHLH

参数:
  • input_size (int) – 输入的维度,D

  • hidden_size (int) – 隐状态的维度,H

  • bias (bool, optional) – 是否使用偏置。默认值:True

  • nonlinearity (callable, optional) – 使用的非线性函数。如果为 None,则使用 func:tanh。默认值:None

方法