-from copy import deepcopy
-from math import inf
-from typing import (Any, Iterator, Optional, Type, Union)
+import copy
+from typing import (Any, Iterator, Optional, Type, Union, List, Dict)
 
 import torch
 import torch.nn as nn
@@ -10,8 +9,6 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 
-from .fsdp.xla_fully_sharded_data_parallel import _calc_grad_norm
-
 
 class ZeroRedundancyOptimizer(Optimizer):
   r"""
@@ -50,51 +47,138 @@ def __init__(
       grad_clipping: bool = True,
       max_norm: Optional[float] = None,
       pin_layout: bool = True,
+      cc_op_groups: Optional[Any] = None,
+      lazy_init: bool = False,
       **defaults: Any,
   ):
-    self.params = list(params)
-    super().__init__(self.params, defaults)
-    if isinstance(self.params[0], dict):
-      self.params = [p for pg in self.params for p in pg['params']]
-
-    self.device = self.params[0].device
+    super().__init__(params, defaults)
 
-    self.rank = xm.get_ordinal()
-    self.world_size = xm.xrt_world_size()
-    self.cc_op_groups = [list(range(self.world_size))]
+    self.global_world_size = xm.xrt_world_size()
+    self.global_rank = xm.get_ordinal()
+    self._cc_op_groups = [list(range(self.global_world_size))
+                         ] if cc_op_groups is None else cc_op_groups
 
+    self.optimizer_class = optimizer_class
+    self.defaults = defaults
     self.optimizer_dtype = optimizer_dtype if optimizer_dtype is not None else torch.float32
     self.grad_clipping = grad_clipping
     self.max_norm = max_norm if max_norm is not None else 1.0
     self.pin_layout = pin_layout
 
+    self.inited = False
+    if not lazy_init:
+      self.init_zero()
+
+  def init_zero(self):
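+    # Size of one collective-op group, and this rank's shard index within it
+    # (assumes the default single group over all ranks, or an interleaved layout).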
+    self.local_world_size = len(self.cc_op_groups[0])
+    self.local_rank = self.global_rank // len(self.cc_op_groups)
     # Shard parameters for use in optimizer
-    self.sharded_params = []
-    self._shard_parameters()
+    sharded_param_groups = self._shard_parameters()
     # Optimizer initialization
-    self.base_optimizer = optimizer_class(iter(self.sharded_params), **defaults)
+    self.base_optimizer = self.optimizer_class(sharded_param_groups,
+                                               **self.defaults)
+    self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
+    self.inited = True
+
+  @property
+  def cc_op_groups(self):
+    return self._cc_op_groups
+
+  @cc_op_groups.setter
+  def cc_op_groups(self, new_cc_op_groups):
+    assert not self.inited, "already inited, cannot change cc_op_groups"
+    self._cc_op_groups = new_cc_op_groups
+
+  @staticmethod
+  def _sync_param_groups(
+      src_param_groups: List[Dict[Any, Any]],
+      dst_param_groups: List[Dict[Any, Any]],
+  ) -> None:
+    r"""
+    Syncs the attributes from the source parameter groups to the
+    destination parameter groups, except the parameters.
+
+    Example attributes include learning rate or scheduler attributes. The
+    two parameter groups should have the same length (i.e. same number of
+    parameter groups).
+
+    Arguments:
+        src_param_groups (list[dict]): parameter groups giving the
+            attribute settings to copy.
+        dst_param_groups (list[dict]): parameter groups giving the
+            attribute settings to set.
+    """
+    assert len(src_param_groups) == len(dst_param_groups), \
+        "Mismatch between number of source and destination parameter groups"
+    for src_param_group, dst_param_group in zip(src_param_groups,
+                                                dst_param_groups):
+      # Sync all attributes except the parameters
+      for attr in filter(lambda x: x != "params", src_param_group.keys()):
+        dst_param_group[attr] = src_param_group[attr]
 
   def _shard_tensor(self, tensor: torch.Tensor):
     """
     Get the shard of the input tensor.
     """
-    assert tensor.shape[0] % self.world_size == 0, "Not support padding now."
-    tensor = tensor.chunk(self.world_size)[self.rank]
+    assert tensor.shape[
+        0] % self.local_world_size == 0, "Not support padding now."
+    tensor = tensor.chunk(self.local_world_size)[self.local_rank]
     return tensor
 
   def _shard_parameters(self):
     """
     Shard all parameters.
     """
-    xm.unlazy(self.params)
-    for param in self.params:
-      shard_data = param.data.to(device="cpu")  # move to cpu
-      shard_data = self._shard_tensor(shard_data)  # slice it
-      if shard_data.dtype != self.optimizer_dtype:
-        shard_data = shard_data.to(dtype=self.optimizer_dtype)
-      shard_data = shard_data.to(device=self.device)  # move to xla device
-      shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
-      self.sharded_params.append(shard)
+    all_params = []
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        all_params.append(param)
+
+    self.device = all_params[0].device
+    xm.unlazy(all_params)
+
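+    # Build a sharded copy of every param group: each rank keeps only its
+    # 1/local_world_size slice of each parameter, cast to optimizer_dtype and
+    # placed back on the XLA device.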
+    sharded_params_groups = []
+    for param_group in self.param_groups:
+      sharded_params = []
+      for param in param_group['params']:
+        shard_data = param.data.to(device="cpu")  # move to cpu
+        shard_data = self._shard_tensor(shard_data)  # slice it
+        if shard_data.dtype != self.optimizer_dtype:
+          shard_data = shard_data.to(dtype=self.optimizer_dtype)
+        shard_data = shard_data.to(device=self.device)  # move to xla device
+        shard = nn.Parameter(shard_data, requires_grad=param.requires_grad)
+        sharded_params.append(shard)
+      sharded_params_group = copy.copy(param_group)
+      sharded_params_group['params'] = sharded_params
+      sharded_params_groups.append(sharded_params_group)
+
+    return sharded_params_groups
+
+  @torch.no_grad()
+  def _calc_grad_norm(
+      self,
+      norm_type: Union[float, int] = 2.0,
+  ) -> torch.Tensor:
+    grads_for_norm = []
+    for param_group in self.base_optimizer.param_groups:
+      for p in param_group['params']:
+        if p.grad is not None:
+          grads_for_norm.append(p.grad.detach())
+    # Norm parameters.
+    if norm_type != 2.0:
+      raise RuntimeError(f"only norm type 2 is supported, getting {norm_type}")
+    total_norm = torch.zeros([], dtype=self.optimizer_dtype, device=self.device)
+    for grad in grads_for_norm:
+      grad_norm = (grad * grad).sum()
+      total_norm += grad_norm
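+    # Each rank holds a disjoint shard of every gradient, so summing the
+    # per-rank squared norms over all ranks yields the global squared 2-norm.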
+    # across all ranks as no pipeline parallel
+    total_norm = xm.all_reduce(
+        xm.REDUCE_SUM,
+        total_norm,
+        groups=[list(range(self.global_world_size))],
+        pin_layout=self.pin_layout)
+    total_norm = torch.pow(total_norm, 1.0 / norm_type)
+    return total_norm
 
   @torch.no_grad()
   def _clip_grad_norm(
@@ -109,55 +193,53 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    params_with_grad = [p for p in self.sharded_params if p.grad is not None]
-    # Computes the max norm for this shard's gradients and sync's across workers
-    local_norm = _calc_grad_norm(params_with_grad, norm_type)
-    if norm_type == inf:
-      total_norm = xm.all_reduce(
-          xm.REDUCE_MAX,
-          local_norm,
-          groups=self.cc_op_groups,
-          pin_layout=self.pin_layout)
-    else:
-      total_norm = xm.all_reduce(
-          xm.REDUCE_SUM,
-          local_norm**norm_type,
-          groups=self.cc_op_groups,
-          pin_layout=self.pin_layout)
-      total_norm = total_norm**(1.0 / norm_type)
-
-    # Now multiply each grad by (max_norm/total_norm), same as torch 1.7 https://tinyurl.com/3wtxhhqq)
-    clip_coef = torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)
-    for p in params_with_grad:
-      p.grad.detach().mul_(clip_coef)
+    total_norm = self._calc_grad_norm(norm_type)
+
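+    # Scale every gradient shard by min(1, max_norm / (total_norm + 1e-6)),
+    # mirroring torch.nn.utils.clip_grad_norm_.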
+    clip_coeff = torch.tensor(
+        max_norm, device=self.device) / (
+            total_norm + 1e-6)
+    clip_value = torch.where(clip_coeff < 1, clip_coeff,
+                             torch.tensor(1., device=self.device))
+    for param_group in self.base_optimizer.param_groups:
+      for p in param_group['params']:
+        if p.grad is not None:
+          p.grad.detach().mul_(clip_value)
 
   @torch.no_grad()
   def step(self, closure=None, **kwargs):
     """
     Performs a single optimizer step and syncs parameters across all ranks.
     """
+    assert self.inited, "must call init_zero() first"
+
     loss = None
     if closure is not None:
       with torch.enable_grad():
        loss = closure()
 
+    # sync to base optimizer
+    self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
+
     # Reduce full gradients across ranks
     # Assign gradient shards to the respective parameter shards
-    for param, shard in zip(self.params, self.sharded_params):
-      if param.grad is not None:
-        grad_shard = xm.reduce_scatter(
-            xm.REDUCE_SUM,
-            param.grad,
-            scale=1.0 / self.world_size,
-            scatter_dim=0,
-            shard_count=self.world_size,
-            pin_layout=self.pin_layout,
-            groups=self.cc_op_groups,
-        )
-
-        if grad_shard.dtype != self.optimizer_dtype:
-          grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
-        shard.grad = grad_shard
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
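+          # reduce_scatter sums the full gradient over the group, scales it by
+          # 1/local_world_size (an average), and returns only this rank's shard.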
+          grad_shard = xm.reduce_scatter(
+              xm.REDUCE_SUM,
+              param.grad,
+              scale=1.0 / self.local_world_size,
+              scatter_dim=0,
+              shard_count=self.local_world_size,
+              pin_layout=self.pin_layout,
+              groups=self.cc_op_groups,
+          )
+
+          if grad_shard.dtype != self.optimizer_dtype:
+            grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
+          shard.grad = grad_shard
 
     if self.grad_clipping:
       # Update unscale/clip with sub partitions
@@ -169,18 +251,24 @@ def step(self, closure=None, **kwargs):
     self.base_optimizer.zero_grad(set_to_none=True)
 
     # All gather the new weights across the ranks and assign them to the full parameters
-    for param, shard in zip(self.params, self.sharded_params):
-      if param.grad is not None:
-        shard_data = shard.data
-        if param.dtype != self.optimizer_dtype:
-          shard_data = shard_data.to(dtype=param.dtype)
-        xm.all_gather(
-            shard_data,
-            dim=0,
-            output=param.data,
-            pin_layout=self.pin_layout,
-            groups=self.cc_op_groups,
-        )
+    for param_group, sharded_param_group in zip(
+        self.param_groups, self.base_optimizer.param_groups):
+      for param, shard in zip(param_group['params'],
+                              sharded_param_group['params']):
+        if param.grad is not None:
+          shard_data = shard.data
+          if param.dtype != self.optimizer_dtype:
+            shard_data = shard_data.to(dtype=param.dtype)
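+          # all_gather concatenates the updated shards from every rank along
+          # dim 0, writing the result directly into the full parameter.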
+          xm.all_gather(
+              shard_data,
+              dim=0,
+              output=param.data,
+              pin_layout=self.pin_layout,
+              groups=self.cc_op_groups,
+          )
+
+    # sync back
+    self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
 
     return loss
 
@@ -190,7 +278,7 @@ def state_dict(self):
     return state_dict
 
   def load_state_dict(self, state_dict):
-    state_dict = deepcopy(state_dict)
+    state_dict = copy.deepcopy(state_dict)
    base = state_dict.pop('base')
     super().load_state_dict(state_dict)
     self.base_optimizer.load_state_dict(base)
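
Note: a minimal usage sketch of the new lazy-init path introduced above. The model, base optimizer class, learning rate, and group layout below are illustrative only, not taken from this change: construct with lazy_init=True, optionally set cc_op_groups before any sharding has happened, then call init_zero() to shard the parameters and build the wrapped optimizer.

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = MyModel().to(device)         # hypothetical model
    opt = ZeroRedundancyOptimizer(
        model.parameters(),
        torch.optim.SGD,                 # any torch optimizer class
        lr=1e-2,                         # forwarded via **defaults
        lazy_init=True)                  # defer sharding
    opt.cc_op_groups = [[0, 2], [1, 3]]  # example 4-device interleaved layout
    opt.init_zero()                      # shard params, build the base optimizer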