 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import overload
 
 from torch_xla.experimental.scan import scan
 
 
-class GRU(nn.Module):
+class GRU(nn.GRU):
   r"""
   PyTorch/XLA GRU implemented using scan.
 
@@ -52,47 +53,24 @@ class GRU(nn.Module):
 
   """
 
+  @overload
   def __init__(self,
-               input_size,
-               hidden_size,
-               num_layers=1,
-               bias=True,
-               dropout=0.0):
-    super().__init__()
-
-    self.input_size = input_size
-    self.hidden_size = hidden_size
-    self.num_layers = num_layers
-    self.bias = bias
-    self.dropout = dropout
-
-    # Create parameters for each layer.
-    # For layer 0, the input dimension is `input_size`, otherwise it's `hidden_size`.
-    self.weight_ih = nn.ParameterList()
-    self.weight_hh = nn.ParameterList()
-    if bias:
-      self.bias_ih = nn.ParameterList()
-      self.bias_hh = nn.ParameterList()
-
-    for layer in range(num_layers):
-      layer_input_size = input_size if layer == 0 else hidden_size
-      # weight_ih: combines weights for reset, update, and new gates.
-      w_ih = nn.Parameter(torch.Tensor(3 * hidden_size, layer_input_size))
-      w_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
-      self.weight_ih.append(w_ih)
-      self.weight_hh.append(w_hh)
-      if bias:
-        b_ih = nn.Parameter(torch.Tensor(3 * hidden_size))
-        b_hh = nn.Parameter(torch.Tensor(3 * hidden_size))
-        self.bias_ih.append(b_ih)
-        self.bias_hh.append(b_hh)
-    self.reset_parameters()
-
-  def reset_parameters(self):
-    # Initialize parameters uniformly as in the upstream PyTorch GRU.
-    stdv = 1.0 / (self.hidden_size**0.5)
-    for weight in self.parameters():
-      weight.data.uniform_(-stdv, stdv)
+               input_size: int,
+               hidden_size: int,
+               num_layers: int = 1,
+               bias: bool = True,
+               dropout: float = 0.0):
+    pass
+
+  def __init__(self, *args, **kwargs):
+    assert not kwargs.get('batch_first', False), \
+        "GRU only supports batch_first=False (seq_len, batch, input_size)."
+    assert not kwargs.get('bidirectional', False), \
+        "GRU only supports unidirectional GRU."
+    assert kwargs.get('proj_size', 0) == 0, \
+        "GRU only supports no projection."
+
+    super().__init__(*args, **kwargs)
 
   def forward(self, input, hx=None):
     """
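
The constructor change above makes the class a thin wrapper over `nn.GRU`: the typed `__init__` exists only as an `@overload` signature for type checkers, and the real `__init__` forwards everything to `nn.GRU` after rejecting unsupported options. A minimal sketch of what that looks like to a caller, assuming the class is importable as `GRU` from the module this diff edits (the file path is not shown in this excerpt):

# Illustrative only; `GRU` refers to the scan-based class defined above.
gru = GRU(input_size=16, hidden_size=32, num_layers=2, dropout=0.1)

# Unsupported nn.GRU options are rejected up front by the assertions.
try:
  GRU(16, 32, batch_first=True)
except AssertionError as err:
  print(err)  # GRU only supports batch_first=False (seq_len, batch, input_size).
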
@@ -119,12 +97,12 @@ def forward(self, input, hx=None):
     for layer in range(self.num_layers):
       init = {
           'h': hx[layer],
-          'w_ih': self.weight_ih[layer],
-          'w_hh': self.weight_hh[layer]
+          'w_ih': getattr(self, f'weight_ih_l{layer}'),
+          'w_hh': getattr(self, f'weight_hh_l{layer}')
       }
       if self.bias:
-        init['b_ih'] = self.bias_ih[layer]
-        init['b_hh'] = self.bias_hh[layer]
+        init['b_ih'] = getattr(self, f'bias_ih_l{layer}', None)
+        init['b_hh'] = getattr(self, f'bias_hh_l{layer}', None)
 
       # Define the step function for scanning over time.
       # x_t: (batch, current_input_size)
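
For reference, the `getattr` lookups above work because `nn.GRU` (which this class now subclasses) registers its per-layer parameters under the attribute names `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, and `bias_hh_l{k}`. A quick check against a stock `nn.GRU`, shown here only as an illustration and not part of this diff:

import torch.nn as nn

ref = nn.GRU(input_size=16, hidden_size=32, num_layers=2)
print([name for name, _ in ref.named_parameters()])
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
#  'weight_ih_l1', 'weight_hh_l1', 'bias_ih_l1', 'bias_hh_l1']
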
@@ -155,15 +133,8 @@ def step_fn(carry, x_t):
         # Update hidden state
         h_new = (1 - z) * n + z * h
 
-        carry_new = {
-            'h': h_new,
-            'w_ih': w_ih,
-            'w_hh': w_hh,
-        }
-        if b_ih is not None:
-          carry_new['b_ih'] = b_ih
-        if b_hh is not None:
-          carry_new['b_hh'] = b_hh
+        carry_new = {**carry, 'h': h_new}
+
         return carry_new, h_new
 
       # Use scan to iterate over the time dimension.
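
End to end, a hedged usage sketch under the assumption that the scan-based GRU keeps `nn.GRU`'s calling convention (time-major input of shape `(seq_len, batch, input_size)`, returning `(output, h_n)`) and runs on an XLA device; the shapes and names below are illustrative:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
gru = GRU(input_size=16, hidden_size=32, num_layers=2).to(device)

x = torch.randn(10, 4, 16, device=device)   # (seq_len, batch, input_size)
h0 = torch.zeros(2, 4, 32, device=device)   # (num_layers, batch, hidden_size)
output, h_n = gru(x, h0)
print(output.shape, h_n.shape)  # torch.Size([10, 4, 32]) torch.Size([2, 4, 32])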