mlx.nn.GRU#
- class GRU(input_size: int, hidden_size: int, bias: bool = True)#
一个门控循环单元(GRU)RNN 层。
输入形状为
NLD
或LD
,其中N
是可选的批处理维度L
是序列长度D
是输入的特征维度
具体来说,对于序列中的每个元素,该层计算
\[\begin{split}\begin{aligned} r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\ z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\ n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\ h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t \end{aligned}\end{split}\]隐藏状态 \(h\) 的形状为
NH
或H
,具体取决于输入是否经过批处理。返回每个时间步的隐藏状态,形状为NLH
或LH
。方法