mlx.nn.RNN#
- class RNN(input_size: int, hidden_size: int, bias: bool = True, nonlinearity: Callable | None = None)#
一个 Elman 循环层。
输入是一个形状为
NLD或LD的序列,其中N是可选的批量维度L是序列长度D是输入的特征维度
具体来说,对于序列长度轴上的每个元素,该层应用以下函数
\[h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\]隐状态 \(h\) 的形状为
NH或H,取决于输入是否批量化。返回每个时间步的隐状态,形状为NLH或LH。- 参数:
方法