mlx.nn.LSTM

目录

mlx.nn.LSTM#

class LSTM(input_size: int, hidden_size: int, bias: bool = True)#

一个 LSTM 循环层。

输入形状为 NLDLD,其中

  • N 是可选的批处理维度

  • L 是序列长度

  • D 是输入的特征维度

具体来说,对于序列中的每个元素,该层计算

\[\begin{split}\begin{aligned} i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\ f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\ g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\ o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\ c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\ h_{t + 1} &= o_t \text{tanh}(c_{t + 1}) \end{aligned}\end{split}\]

隐藏状态 \(h\) 和单元状态 \(c\) 的形状为 NHH,取决于输入是否进行了批处理。

该层返回两个数组,分别是每个时间步的隐藏状态和单元状态,它们的形状均为 NLHLH

参数:
  • input_size (int) – 输入的维度,D

  • hidden_size (int) – 隐藏状态的维度,H

  • bias (bool) – 是否使用偏置。默认值:True

方法