mlx.nn.RNN#
- class RNN(input_size: int, hidden_size: int, bias: bool = True, nonlinearity: Callable | None = None)#
一个 Elman 循环层。
输入是一个形状为
NLD
或LD
的序列,其中N
是可选的批量维度L
是序列长度D
是输入的特征维度
具体来说,对于序列长度轴上的每个元素,该层应用以下函数
\[h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\]隐状态 \(h\) 的形状为
NH
或H
,取决于输入是否批量化。返回每个时间步的隐状态,形状为NLH
或LH
。- 参数:
方法