-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Fixes NAG optimizer #15543 #16053
Fixes NAG optimizer #15543 #16053
Conversation
|
Hi @zhanghang1989 , is there any difference between import mxnet as mx
import time
T = 1000
N = 1000
while 1:
ti = time.time()
a = mx.nd.arange(N)
for i in range(T):
a += 1
mx.nd.waitall()
print('a += b: ', time.time() - ti)
ti = time.time()
a = mx.nd.arange(N)
for i in range(T):
a[:] += 1
mx.nd.waitall()
print('a[:] += b: ', time.time() - ti)Output: |
@zhanghang1989 The update rule in this PR is the following - mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]
+(param_momentum+1)*(mom_data[i]
-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));this update rule is same as the following psuedocode - which when simplified, translates to ( it is the same rule used in keras as well - https://stats.stackexchange.com/questions/179915/whats-the-difference-between-momentum-based-gradient-descent-and-nesterovs-acc) |
zhanghang1989
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The weight update is correct. Please fix the mentum update in the end.
yes, will change the momentum update state |
I am not familiar with symbol API. Just write some pseudocode to show how NAG works :) |
a3500d7 to
b644e5f
Compare
|
Thanks @zhanghang1989 and @anirudhacharya |
|
@anirudhacharya perl gpu tests are failing : http://jenkins.mxnet-ci.amazon-ml.com/blue/rest/organizations/jenkins/pipelines/mxnet-validation/pipelines/unix-gpu/branches/master/runs/1029/nodes/304/steps/568/log/?start=0 , |
* fix update rules * readable updates in unit test * mom update


Description
Fixes #15543
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
For review - @zhanghang1989 @apeforest @eric-haibin-lin