Commit 145c33a

Rewrite scan-based GRU based on nn.GRU (#8914)
1 parent 980ead5 commit 145c33a

File tree (2 files changed: +69 -66 lines)

  test/test_gru.py
  torch_xla/experimental/gru.py
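
With this change the scan-based GRU subclasses nn.GRU, so it accepts the same constructor arguments and exposes the same parameter names (weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}), which lets a strict state_dict load from an upstream nn.GRU work directly. A minimal usage sketch based on the hunks below; the illustrative sizes and the assumption that forward returns (output, h_n) like nn.GRU are mine, not shown in this diff:

  import torch
  import torch.nn as nn
  import torch_xla
  from torch_xla.experimental.gru import GRU

  # Illustrative sizes, not taken from the commit.
  seq_len, batch, input_size, hidden_size, num_layers = 8, 4, 16, 32, 2

  # Upstream GRU on CPU and the scan-based GRU with matching hyperparameters.
  cpu_gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, bias=True, dropout=0.0)
  scan_gru = GRU(input_size, hidden_size, num_layers=num_layers, bias=True, dropout=0.0)

  # Same parameter names as nn.GRU, so a strict state_dict load just works.
  scan_gru.load_state_dict(cpu_gru.state_dict(), strict=True)
  scan_gru = scan_gru.to('xla')

  # Input layout is (seq_len, batch, input_size); batch_first is not supported.
  x = torch.randn(seq_len, batch, input_size).to('xla')
  output, h_n = scan_gru(x)  # assumed to mirror nn.GRU's (output, h_n) return
  torch_xla.sync()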

test/test_gru.py

Lines changed: 44 additions & 12 deletions
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-
 import torch_xla
 from torch_xla.experimental.gru import GRU
 
@@ -27,16 +26,9 @@ def build_models(self, input_size, hidden_size, num_layers, bias):
         input_size, hidden_size, num_layers=num_layers, bias=bias, dropout=0.0)
 
     # Copy parameters from the upstream GRU to our scan-based GRU.
-    for layer in range(num_layers):
-      scan_gru.weight_ih[layer].data.copy_(
-          getattr(gru, f'weight_ih_l{layer}').data)
-      scan_gru.weight_hh[layer].data.copy_(
-          getattr(gru, f'weight_hh_l{layer}').data)
-      if gru.bias:
-        scan_gru.bias_ih[layer].data.copy_(
-            getattr(gru, f'bias_ih_l{layer}').data)
-        scan_gru.bias_hh[layer].data.copy_(
-            getattr(gru, f'bias_hh_l{layer}').data)
+    # This ensures that the scan-based GRU has the same parameters as the
+    # upstream GRU and both models are parameterized the same way.
+    scan_gru.load_state_dict(gru.state_dict(), strict=True)
 
     return gru, scan_gru
 
@@ -78,7 +70,7 @@ def check_gradients(self,
     for layer in range(num_layers):
       for name in params_to_check:
         param1 = getattr(gru, f'{name}_l{layer}')
-        param2 = getattr(scan_gru, name)[layer]
+        param2 = getattr(scan_gru, f'{name}_l{layer}')
         torch.testing.assert_close(
             param1.grad,
             param2.grad,
@@ -88,6 +80,46 @@
             atol=atol,
             rtol=rtol)
 
+  def test_scan_gru_and_upstream_gru_parameter_independency(self):
+    """
+    Ensures that the parameters of the scan-based GRU and upstream GRU are independent even though the parameters of the scan-based GRU are initialized using the upstream GRU.
+    """
+    input_size, hidden_size, num_layers = 16, 32, 2
+    gru, scan_gru = self.build_models(input_size, hidden_size, num_layers, True)
+    gru = gru.cpu()
+    scan_gru = scan_gru.to('xla')
+    torch_xla.sync()
+
+    with torch.no_grad():
+      gru_weight_ih_l0 = gru.state_dict()['weight_ih_l0']
+      scan_gru_weight_ih_l0 = scan_gru.state_dict()['weight_ih_l0']
+
+      # Compare the parameters of the GRU and scan-based GRU before changing.
+      torch.testing.assert_close(
+          gru_weight_ih_l0,
+          scan_gru_weight_ih_l0,
+          msg=lambda msg: f"weight_ih_l0 mismatch. {msg}",
+          check_device=False)
+
+      # Change the parameters of the GRU with random numbers.
+      gru_weight_ih_l0.uniform_(-1, 1)
+
+      # Assert not close after the change.
+      try:
+        torch.testing.assert_close(
+            gru_weight_ih_l0,
+            scan_gru_weight_ih_l0,
+            msg=lambda msg: f"weight_ih_l0 mismatch. {msg}",
+            check_device=False)
+        raise AssertionError(
+            "weight_ih_l0 should not be close after changing the GRU parameters."
+        )
+      except AssertionError as e:
+        if str(e).startswith("weight_ih_l0 mismatch."):
+          pass
+        else:
+          raise e
+
   @parameterized.parameters(True, False)
   def test_scan_gru_vs_pytorch_xla_for_loop(self, bias):
     """

torch_xla/experimental/gru.py

Lines changed: 25 additions & 54 deletions
@@ -1,11 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import overload
 
 from torch_xla.experimental.scan import scan
 
 
-class GRU(nn.Module):
+class GRU(nn.GRU):
   r"""
   PyTorch/XLA GRU implemented using scan.
 
@@ -52,47 +53,24 @@ class GRU(nn.Module):
 
   """
 
+  @overload
   def __init__(self,
-               input_size,
-               hidden_size,
-               num_layers=1,
-               bias=True,
-               dropout=0.0):
-    super().__init__()
-
-    self.input_size = input_size
-    self.hidden_size = hidden_size
-    self.num_layers = num_layers
-    self.bias = bias
-    self.dropout = dropout
-
-    # Create parameters for each layer.
-    # For layer 0, the input dimension is `input_size`, otherwise it's `hidden_size`.
-    self.weight_ih = nn.ParameterList()
-    self.weight_hh = nn.ParameterList()
-    if bias:
-      self.bias_ih = nn.ParameterList()
-      self.bias_hh = nn.ParameterList()
-
-    for layer in range(num_layers):
-      layer_input_size = input_size if layer == 0 else hidden_size
-      # weight_ih: combines weights for reset, update, and new gates.
-      w_ih = nn.Parameter(torch.Tensor(3 * hidden_size, layer_input_size))
-      w_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
-      self.weight_ih.append(w_ih)
-      self.weight_hh.append(w_hh)
-      if bias:
-        b_ih = nn.Parameter(torch.Tensor(3 * hidden_size))
-        b_hh = nn.Parameter(torch.Tensor(3 * hidden_size))
-        self.bias_ih.append(b_ih)
-        self.bias_hh.append(b_hh)
-    self.reset_parameters()
-
-  def reset_parameters(self):
-    # Initialize parameters uniformly as in the upstream PyTorch GRU.
-    stdv = 1.0 / (self.hidden_size**0.5)
-    for weight in self.parameters():
-      weight.data.uniform_(-stdv, stdv)
+               input_size: int,
+               hidden_size: int,
+               num_layers: int = 1,
+               bias: bool = True,
+               dropout: float = 0.0):
+    pass
+
+  def __init__(self, *args, **kwargs):
+    assert not kwargs.get('batch_first', False), \
+        "GRU only supports batch_first=False (seq_len, batch, input_size)."
+    assert not kwargs.get('bidirectional', False), \
+        "GRU only supports unidirectional GRU."
+    assert kwargs.get('proj_size', 0) == 0, \
+        "GRU only supports no projection."
+
+    super().__init__(*args, **kwargs)
 
   def forward(self, input, hx=None):
     """
@@ -119,12 +97,12 @@ def forward(self, input, hx=None):
     for layer in range(self.num_layers):
       init = {
           'h': hx[layer],
-          'w_ih': self.weight_ih[layer],
-          'w_hh': self.weight_hh[layer]
+          'w_ih': getattr(self, f'weight_ih_l{layer}'),
+          'w_hh': getattr(self, f'weight_hh_l{layer}')
       }
       if self.bias:
-        init['b_ih'] = self.bias_ih[layer]
-        init['b_hh'] = self.bias_hh[layer]
+        init['b_ih'] = getattr(self, f'bias_ih_l{layer}', None)
+        init['b_hh'] = getattr(self, f'bias_hh_l{layer}', None)
 
       # Define the step function for scanning over time.
       # x_t: (batch, current_input_size)
@@ -155,15 +133,8 @@ def step_fn(carry, x_t):
         # Update hidden state
        h_new = (1 - z) * n + z * h
 
-        carry_new = {
-            'h': h_new,
-            'w_ih': w_ih,
-            'w_hh': w_hh,
-        }
-        if b_ih is not None:
-          carry_new['b_ih'] = b_ih
-        if b_hh is not None:
-          carry_new['b_hh'] = b_hh
+        carry_new = {**carry, 'h': h_new}
+
         return carry_new, h_new
 
       # Use scan to iterate over the time dimension.
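
Inside the step function the only change is how the carry is rebuilt: the per-layer weights and biases ride along in the carry untouched, and dict unpacking swaps in just the new hidden state. A self-contained sketch of that carry pattern, with gate math consistent with the h_new update shown above and a plain Python loop standing in for torch_xla.experimental.scan (illustration only; gate ordering follows nn.GRU's reset/update/new convention):

  import torch
  import torch.nn.functional as F

  def gru_step(carry, x_t):
    # One GRU time step over a batch; carry holds the hidden state plus the
    # layer's weights and (optional) biases, matching the `init` dict above.
    h, w_ih, w_hh = carry['h'], carry['w_ih'], carry['w_hh']
    b_ih, b_hh = carry.get('b_ih'), carry.get('b_hh')

    gates_x = F.linear(x_t, w_ih, b_ih)  # (batch, 3 * hidden_size)
    gates_h = F.linear(h, w_hh, b_hh)    # (batch, 3 * hidden_size)
    x_r, x_z, x_n = gates_x.chunk(3, dim=1)
    h_r, h_z, h_n = gates_h.chunk(3, dim=1)

    r = torch.sigmoid(x_r + h_r)   # reset gate
    z = torch.sigmoid(x_z + h_z)   # update gate
    n = torch.tanh(x_n + r * h_n)  # candidate hidden state
    h_new = (1 - z) * n + z * h

    # Same simplification as the diff: reuse the carry, replace only 'h'.
    return {**carry, 'h': h_new}, h_new

  def run_layer(carry, xs):
    # Loop stand-in for scan over the time dimension; xs: (seq_len, batch, input).
    outs = []
    for x_t in xs.unbind(0):
      carry, h_t = gru_step(carry, x_t)
      outs.append(h_t)
    return carry, torch.stack(outs, dim=0)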
