mlx.nn.LSTM#
- class LSTM(input_size: int, hidden_size: int, bias: bool = True)#
一个 LSTM 循环层。
输入形状为
NLD
或LD
,其中N
是可选的批处理维度L
是序列长度D
是输入的特征维度
具体来说,对于序列中的每个元素,该层计算
\[\begin{split}\begin{aligned} i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\ f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\ g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\ o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\ c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\ h_{t + 1} &= o_t \text{tanh}(c_{t + 1}) \end{aligned}\end{split}\]隐藏状态 \(h\) 和单元状态 \(c\) 的形状为
NH
或H
,取决于输入是否进行了批处理。该层返回两个数组,分别是每个时间步的隐藏状态和单元状态,它们的形状均为
NLH
或LH
。方法