mlx.nn.GRU

目录

mlx.nn.GRU#

class GRU(input_size: int, hidden_size: int, bias: bool = True)#

一个门控循环单元(GRU)RNN 层。

输入形状为 NLDLD,其中

  • N 是可选的批处理维度

  • L 是序列长度

  • D 是输入的特征维度

具体来说,对于序列中的每个元素,该层计算

\[\begin{split}\begin{aligned} r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\ z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\ n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\ h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t \end{aligned}\end{split}\]

隐藏状态 \(h\) 的形状为 NHH,具体取决于输入是否经过批处理。返回每个时间步的隐藏状态,形状为 NLHLH

参数:
  • input_size (int) – 输入的维度,即 D

  • hidden_size (int) – 隐藏状态的维度,即 H

  • bias (bool) – 是否使用偏置项。默认值: True

方法