GRU算法解析

图1表示了GRU的网络模型,看到每个记忆单元有两个门控单元,更新门和遗忘门,公式(1)-(4)可以表示具体的操作,

GRU

%% GRU算法构成 \(\begin{equation} r_t = \sigma(W_r \cdot [h_{t-1},x_t]+ b_r) \tag{1} \end{equation}\) \(\begin{equation} z_t = \sigma(W_z \cdot [h_{t-1},x_t]+ b_z) \tag{2} \end{equation}\) \(\begin{equation} \hat{h_t} = tanh(W_h\cdot[h_{t-1},x_t]+b_h) \tag{3} \end{equation}\) \(\begin{equation} h_t = (1-z_t)h{t-1} + z_t\hat{h_t} \tag{4} \end{equation}\) 在上述的公式中 \(r_t\)是遗忘门,\(z_t\)是更新门,\(\hat{h_t}\)代表隐藏状态的更新值, \(h_t\)\(x_t\)表示当前时刻的输入和隐藏状态,\(\sigma\)表示的是sigmoid函数。 \(\begin{equation} sigmoid : \sigma(x) = \frac{1}{1-e^{-x}} \tag{} \end{equation}\) 遗忘门和更新门通过sigmoid函数进行激活,隐藏状态的更新值通过tanh函数, \(W_r,W_z,W_h\)分别代表重置门\(r_t\),更新门\(z_t\)以及更新隐藏状态\(\hat{h_t}\)的权重矩阵,\(b_r,b_z,b_h\)分别是\(r_t,z_t,\hat{h_t}\)的偏置矩阵,更新门\(z_t\)代表有多少比例的t-1时刻的隐藏状态\(h_{t-1}\)被更新到当前状态,重置门\(r_t\)重置了\(h_{t-1}\)的信息\(r_t\)

通用的RNN模块调用方法

调用格式torch.nn.RNN() 或者torch.nn.GRU()

其中的参数有如下:

image-20230224232010563

其中需要注意的是,input_size是指输入数据的特征维度,比如传感器的时序数据就是一维的,表格类的可以是2维的。这里不是指时序的序列长度。

调用时可以通过batch_first参数来调整seq和batch_size的位置

image-20230224232355005

bidirectional参数可以选择双向的RNN网络

参数个数的计算

GRU网络的计算

3[(d+h)h+h] L + [(2d+h)+1]h *(L-1)

其中d为输入的大小,h为隐藏层的大小\(h_{dim}\) L为GRU的num_layers

如果时LSTM前面的3 改为4

如果是RNN改为2