
Commit 354bda8

refactor
1 parent 26951a1 commit 354bda8

File tree

1 file changed: +165 −77 lines

Lines changed: 165 additions & 77 deletions
@@ -1,6 +1,5 @@
-from copy import deepcopy
-from math import inf
-from typing import (Any, Iterator, Optional, Type, Union)
+import copy
+from typing import (Any, Iterator, Optional, Type, Union, List, Dict)
 
 import torch
 import torch.nn as nn
@@ -10,8 +9,6 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 
-from .fsdp.xla_fully_sharded_data_parallel import _calc_grad_norm
-
 
 class ZeroRedundancyOptimizer(Optimizer):
   r"""
@@ -50,51 +47,138 @@ def __init__(
       grad_clipping: bool = True,
       max_norm: Optional[float] = None,
       pin_layout: bool = True,
+      cc_op_groups: Optional[Any] = None,
+      lazy_init: bool = False,
       **defaults: Any,
   ):
-    self.params = list(params)
-    super().__init__(self.params, defaults)
-    if isinstance(self.params[0], dict):
-      self.params = [p for pg in self.params for p in pg['params']]
-
-    self.device = self.params[0].device
+    super().__init__(params, defaults)
 
-    self.rank = xm.get_ordinal()
-    self.world_size = xm.xrt_world_size()
-    self.cc_op_groups = [list(range(self.world_size))]
+    self.global_world_size = xm.xrt_world_size()
+    self.global_rank = xm.get_ordinal()
+    self._cc_op_groups = [list(range(self.global_world_size))
+                         ] if cc_op_groups is None else cc_op_groups
 
+    self.optimizer_class = optimizer_class
+    self.defaults = defaults
     self.optimizer_dtype = optimizer_dtype if optimizer_dtype is not None else torch.float32
     self.grad_clipping = grad_clipping
     self.max_norm = max_norm if max_norm is not None else 1.0
     self.pin_layout = pin_layout
 
+    self.inited = False
+    if not lazy_init:
+      self.init_zero()
+
+  def init_zero(self):
+    self.local_world_size = len(self.cc_op_groups[0])
+    self.local_rank = self.global_rank // len(self.cc_op_groups)
     # Shard parameters for use in optimizer
-    self.sharded_params = []
-    self._shard_parameters()
+    sharded_param_groups = self._shard_parameters()
     # Optimizer initialization
-    self.base_optimizer = optimizer_class(iter(self.sharded_params), **defaults)
+    self.base_optimizer = self.optimizer_class(sharded_param_groups,
+                                               **self.defaults)
+    self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
+    self.inited = True
+
+  @property
+  def cc_op_groups(self):
+    return self._cc_op_groups
+
+  @cc_op_groups.setter
+  def cc_op_groups(self, new_cc_op_groups):
+    assert not self.inited, "already inited, cannot change cc_op_groups"
+    self._cc_op_groups = new_cc_op_groups
+
+  @staticmethod
+  def _sync_param_groups(
+      src_param_groups: List[Dict[Any, Any]],
+      dst_param_groups: List[Dict[Any, Any]],
+  ) -> None:
+    r"""
+    Syncs the attributes from the source parameter groups to the
+    destination parameter groups, except the parameters.
+
+    Example attributes include learning rate or scheduler attributes. The
+    two parameter groups should have the same length (i.e. same number of
+    parameter groups).
+
+    Arguments:
+        src_param_groups (list[dict]): parameter groups giving the
+            attribute settings to copy.
+        dst_param_groups (list[dict]): parameter groups giving the
+            attribute settings to set.
+    """
+    assert len(src_param_groups) == len(dst_param_groups), \
+        "Mismatch between number of source and destination parameter groups"
+    for src_param_group, dst_param_group in zip(src_param_groups,
+                                                dst_param_groups):
+      # Sync all attributes except the parameters
+      for attr in filter(lambda x: x != "params", src_param_group.keys()):
+        dst_param_group[attr] = src_param_group[attr]
 
   def _shard_tensor(self, tensor: torch.Tensor):
     """
     Get the shard of the input tensor.
     """
-    assert tensor.shape[0] % self.world_size == 0, "Not support padding now."
-    tensor = tensor.chunk(self.world_size)[self.rank]
+    assert tensor.shape[
+        0] % self.local_world_size == 0, "Not support padding now."
+    tensor = tensor.chunk(self.local_world_size)[self.local_rank]
     return tensor
 
   def _shard_parameters(self):
     """
     Shard all parameters.
     """
-    xm.unlazy(self.params)
-    for param in self.params:
-      shard_data = param.data.to(device="cpu")  # move to cpu
-      shard_data = self._shard_tensor(shard_data)  # slice it
-      if shard_data.dtype != self.optimizer_dtype:
-        shard_data = shard_data.to(dtype=self.optimizer_dtype)
-      shard_data = shard_data.to(device=self.device)  # move to xla device
-      shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
-      self.sharded_params.append(shard)
+    all_params = []
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        all_params.append(param)
+
+    self.device = all_params[0].device
+    xm.unlazy(all_params)
+
+    sharded_params_groups = []
+    for param_group in self.param_groups:
+      sharded_params = []
+      for param in param_group['params']:
+        shard_data = param.data.to(device="cpu")  # move to cpu
+        shard_data = self._shard_tensor(shard_data)  # slice it
+        if shard_data.dtype != self.optimizer_dtype:
+          shard_data = shard_data.to(dtype=self.optimizer_dtype)
+        shard_data = shard_data.to(device=self.device)  # move to xla device
+        shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
+        sharded_params.append(shard)
+      sharded_params_group = copy.copy(param_group)
+      sharded_params_group['params'] = sharded_params
+      sharded_params_groups.append(sharded_params_group)
+
+    return sharded_params_groups
+
+  @torch.no_grad()
+  def _calc_grad_norm(
+      self,
+      norm_type: Union[float, int] = 2.0,
+  ) -> torch.Tensor:
+    grads_for_norm = []
+    for param_group in self.base_optimizer.param_groups:
+      for p in param_group['params']:
+        if p.grad is not None:
+          grads_for_norm.append(p.grad.detach())
+    # Norm parameters.
+    if norm_type != 2.0:
+      raise RuntimeError(f"only norm type 2 is supported, getting {norm_type}")
+    total_norm = torch.zeros([], dtype=self.optimizer_dtype, device=self.device)
+    for grad in grads_for_norm:
+      grad_norm = (grad * grad).sum()
+      total_norm += grad_norm
+    # across all ranks as no pipeline parallel
+    total_norm = xm.all_reduce(
+        xm.REDUCE_SUM,
+        total_norm,
+        groups=[list(range(self.global_world_size))],
+        pin_layout=self.pin_layout)
+    total_norm = torch.pow(total_norm, 1.0 / norm_type)
+    return total_norm
 
   @torch.no_grad()
   def _clip_grad_norm(
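
The hunk above replaces the flat self.params list with the optimizer's native param_groups, adds a lazy_init path, and turns cc_op_groups into a property that may only be reassigned before init_zero() runs. Below is a minimal construction sketch, not part of the commit: it assumes the unchanged leading arguments are params and optimizer_class, and the model, learning rate, and 8-device group layout are purely illustrative (the interleaved groups are chosen to be consistent with the local_rank arithmetic in init_zero).

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Hypothetical model and hyperparameters, used only to show the new interface.
model = nn.Linear(128, 128).to(xm.xla_device())
opt = ZeroRedundancyOptimizer(
    [{'params': model.parameters(), 'lr': 1e-3}],  # param groups are preserved now
    optimizer_class=torch.optim.AdamW,
    optimizer_dtype=torch.float32,
    lazy_init=True,  # defer sharding until the sharding groups are chosen
)
# The setter is only legal while self.inited is False; this interleaved layout
# for 8 devices is an assumption, matching local_rank = global_rank // len(cc_op_groups).
opt.cc_op_groups = [[0, 2, 4, 6], [1, 3, 5, 7]]
opt.init_zero()  # shards the parameters, builds base_optimizer, syncs param groups
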
@@ -109,55 +193,53 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    params_with_grad = [p for p in self.sharded_params if p.grad is not None]
-    # Computes the max norm for this shard's gradients and sync's across workers
-    local_norm = _calc_grad_norm(params_with_grad, norm_type)
-    if norm_type == inf:
-      total_norm = xm.all_reduce(
-          xm.REDUCE_MAX,
-          local_norm,
-          groups=self.cc_op_groups,
-          pin_layout=self.pin_layout)
-    else:
-      total_norm = xm.all_reduce(
-          xm.REDUCE_SUM,
-          local_norm**norm_type,
-          groups=self.cc_op_groups,
-          pin_layout=self.pin_layout)
-      total_norm = total_norm**(1.0 / norm_type)
-
-    # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq)
-    clip_coef = torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)
-    for p in params_with_grad:
-      p.grad.detach().mul_(clip_coef)
+    total_norm = self._calc_grad_norm(norm_type)
+
+    clip_coeff = torch.tensor(
+        max_norm, device=self.device) / (
+            total_norm + 1e-6)
+    clip_value = torch.where(clip_coeff < 1, clip_coeff,
+                             torch.tensor(1., device=self.device))
+    for param_group in self.base_optimizer.param_groups:
+      for p in param_group['params']:
+        if p.grad is not None:
+          p.grad.detach().mul_(clip_value)
 
   @torch.no_grad()
   def step(self, closure=None, **kwargs):
     """
     Performs a single optimizer step and syncs parameters across all ranks.
     """
+    assert self.inited, "must call init_zero() first"
+
     loss = None
     if closure is not None:
       with torch.enable_grad():
        loss = closure()
 
+    # sync to base optimizer
+    self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
+
     # Reduce full gradients across ranks
     # Assign gradient shards to the respective parameter shards
-    for param, shard in zip(self.params, self.sharded_params):
-      if param.grad is not None:
-        grad_shard = xm.reduce_scatter(
-            xm.REDUCE_SUM,
-            param.grad,
-            scale=1.0 / self.world_size,
-            scatter_dim=0,
-            shard_count=self.world_size,
-            pin_layout=self.pin_layout,
-            groups=self.cc_op_groups,
-        )
-
-        if grad_shard.dtype != self.optimizer_dtype:
-          grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
-        shard.grad = grad_shard
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
+          grad_shard = xm.reduce_scatter(
+              xm.REDUCE_SUM,
+              param.grad,
+              scale=1.0 / self.local_world_size,
+              scatter_dim=0,
+              shard_count=self.local_world_size,
+              pin_layout=self.pin_layout,
+              groups=self.cc_op_groups,
+          )
+
+          if grad_shard.dtype != self.optimizer_dtype:
+            grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
+          shard.grad = grad_shard
 
     if self.grad_clipping:
       # Update unscale/clip with sub partitions
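
Numerically, the new clipping path scales every sharded gradient by min(1, max_norm / (total_norm + 1e-6)), where total_norm is the global 2-norm obtained by all-reducing the per-rank sums of squared gradient elements and taking the square root. A single-process sketch of the same arithmetic, with made-up values and no XLA collectives:

import torch

# Two gradient shards with 2-norms 5 and 12, so the global norm is sqrt(25 + 144) = 13.
grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.0, 12.0])]
max_norm = 5.0
total_norm = torch.sqrt(sum((g * g).sum() for g in grads))  # 13.0
clip_coeff = max_norm / (total_norm + 1e-6)                 # ~0.3846
clip_value = torch.clamp(clip_coeff, max=1.0)               # only scale down, never up
for g in grads:
  g.mul_(clip_value)  # in-place, mirroring p.grad.detach().mul_(clip_value)
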
@@ -169,18 +251,24 @@ def step(self, closure=None, **kwargs):
     self.base_optimizer.zero_grad(set_to_none=True)
 
     # All gather the new weights across the ranks and assign them to the full parameters
-    for param, shard in zip(self.params, self.sharded_params):
-      if param.grad is not None:
-        shard_data = shard.data
-        if param.dtype != self.optimizer_dtype:
-          shard_data = shard_data.to(dtype=param.dtype)
-        xm.all_gather(
-            shard_data,
-            dim=0,
-            output=param.data,
-            pin_layout=self.pin_layout,
-            groups=self.cc_op_groups,
-        )
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
+          shard_data = shard.data
+          if param.dtype != self.optimizer_dtype:
+            shard_data = shard_data.to(dtype=param.dtype)
+          xm.all_gather(
+              shard_data,
+              dim=0,
+              output=param.data,
+              pin_layout=self.pin_layout,
+              groups=self.cc_op_groups,
+          )
+
+    # sync back
+    self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
 
     return loss
 
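Since step() now syncs the wrapper's param_groups into the base optimizer before updating and mirrors them back afterwards, hyperparameters changed on the wrapper (for example by an LR scheduler) reach the sharded base optimizer on the next step. A hedged training-loop fragment, reusing the hypothetical opt and model from the construction sketch above; the batch size and schedule are illustrative:

scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
for _ in range(100):
  inputs = torch.randn(8, 128, device=xm.xla_device())
  loss = model(inputs).sum()  # stand-in objective
  loss.backward()             # full (unsharded) gradients on every rank
  opt.step()                  # sync lr -> base optimizer, reduce-scatter grads,
                              # update shards, all-gather updated parameters
  opt.zero_grad()
  scheduler.step()            # new lr is picked up by the next opt.step()
  xm.mark_step()              # materialize the lazy XLA graph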

@@ -190,7 +278,7 @@ def state_dict(self):
     return state_dict
 
   def load_state_dict(self, state_dict):
-    state_dict = deepcopy(state_dict)
+    state_dict = copy.deepcopy(state_dict)
     base = state_dict.pop('base')
     super().load_state_dict(state_dict)
     self.base_optimizer.load_state_dict(base)
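
As load_state_dict above implies, the state dict carries the sharded base optimizer's state under a 'base' key, which is popped and restored separately. A hedged checkpoint round trip; the per-rank file name is an assumption, reflecting that each rank only holds its own shard's optimizer state:

ckpt_path = f'zero_opt_rank{xm.get_ordinal()}.pt'  # hypothetical per-rank path
torch.save(opt.state_dict(), ckpt_path)            # wrapper state plus base state under 'base'
opt.load_state_dict(torch.load(ckpt_path))         # deep-copies, then restores wrapper and base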
