diff --git a/api/api.go b/api/api.go index 468f346a..8ec6217a 100644 --- a/api/api.go +++ b/api/api.go @@ -101,7 +101,7 @@ func loginRoute(services service.Services, e *echo.Echo) { } func verificationRoute(services service.Services, e *echo.Echo) { - handler := handler.NewVerification(e, services.Verification) + handler := handler.NewVerification(e, services.Verification, services.Device) handler.RegisterRoutes(e.Group("/verification")) } diff --git a/api/config.go b/api/config.go index 22a974b5..9a0d8273 100644 --- a/api/config.go +++ b/api/config.go @@ -38,7 +38,11 @@ func NewServices(config APIConfig, repos repository.Repositories) service.Servic verificationRepos := repository.Repositories{Contact: repos.Contact, User: repos.User, Device: repos.Device} verification := service.NewVerification(verificationRepos) - auth := service.NewAuth(repos, fingerprint, verification) + // device service + deviceRepos := repository.Repositories{Device: repos.Device} + device := service.NewDevice(deviceRepos, fingerprint) + + auth := service.NewAuth(repos, verification, device) apiKey := service.NewAPIKeyStrategy(repos.Auth) cost := service.NewCost(config.Redis) executor := service.NewExecutor() @@ -61,5 +65,6 @@ func NewServices(config APIConfig, repos repository.Repositories) service.Servic Transaction: transaction, User: user, Verification: verification, + Device: device, } } diff --git a/api/handler/http_error.go b/api/handler/http_error.go index 5a339812..044aa47d 100644 --- a/api/handler/http_error.go +++ b/api/handler/http_error.go @@ -90,3 +90,10 @@ func LinkExpired(c echo.Context, message ...string) error { } return c.JSON(http.StatusForbidden, JSONError{Message: "Forbidden", Code: "LINK_EXPIRED"}) } + +func InvalidEmail(c echo.Context, message ...string) error { + if len(message) > 0 { + return c.JSON(http.StatusUnprocessableEntity, JSONError{Message: strings.Join(message, " "), Code: "INVALID_EMAIL"}) + } + return c.JSON(http.StatusUnprocessableEntity, JSONError{Message: "Invalid email", Code: "INVALID_EMAIL"}) +} diff --git a/api/handler/login.go b/api/handler/login.go index dd1657e9..17098f97 100644 --- a/api/handler/login.go +++ b/api/handler/login.go @@ -67,10 +67,14 @@ func (l login) VerifySignature(c echo.Context) error { body.Nonce = string(decodedNonce) resp, err := l.Service.VerifySignedPayload(body) - if err != nil && strings.Contains(err.Error(), "unknown device") { - return Unprocessable(c) - } if err != nil { + if strings.Contains(err.Error(), "unknown device") { + return Unprocessable(c) + } + if strings.Contains(err.Error(), "invalid email") { + return InvalidEmail(c) + } + LogStringError(c, err, "login: verify signature") return BadRequestError(c, "Invalid Payload") } diff --git a/api/handler/verification.go b/api/handler/verification.go index 1dc092a5..02f2df34 100644 --- a/api/handler/verification.go +++ b/api/handler/verification.go @@ -19,12 +19,13 @@ type Verification interface { } type verification struct { - service service.Verification - group *echo.Group + service service.Verification + deviceService service.Device + group *echo.Group } -func NewVerification(route *echo.Echo, service service.Verification) Verification { - return &verification{service, nil} +func NewVerification(route *echo.Echo, service service.Verification, deviceService service.Device) Verification { + return &verification{service, deviceService, nil} } func (v verification) VerifyEmail(c echo.Context) error { @@ -39,7 +40,7 @@ func (v verification) VerifyEmail(c echo.Context) error { func (v verification) VerifyDevice(c echo.Context) error { token := c.QueryParam("token") - err := v.service.VerifyDevice(token) + err := v.deviceService.VerifyDevice(token) if err != nil { LogStringError(c, err, "verification: device verification") return BadRequestError(c) diff --git a/pkg/model/entity.go b/pkg/model/entity.go index d8a8e54f..edce1d9a 100644 --- a/pkg/model/entity.go +++ b/pkg/model/entity.go @@ -20,6 +20,7 @@ type User struct { FirstName string `json:"firstName" db:"first_name"` MiddleName string `json:"middleName" db:"middle_name"` LastName string `json:"lastName" db:"last_name"` + Email string `json:"email"` } // See PLATFORM in Migrations 0005 diff --git a/pkg/model/user.go b/pkg/model/user.go index c9e16855..ae622e16 100644 --- a/pkg/model/user.go +++ b/pkg/model/user.go @@ -10,12 +10,12 @@ type WalletSignaturePayload struct { } type FingerprintPayload struct { - VisitorID string `json:"visitorId" validate:"required"` - RequestID string `json:"requestId" validate:"required"` + VisitorID string `json:"visitorId"` + RequestID string `json:"requestId"` } type WalletSignaturePayloadSigned struct { Nonce string `json:"nonce" validate:"required"` Signature string `json:"signature" validate:"required"` - Fingerprint FingerprintPayload `json:"fingerprint" validate:"required"` + Fingerprint FingerprintPayload `json:"fingerprint"` } diff --git a/pkg/service/auth.go b/pkg/service/auth.go index e1b825ad..e500d038 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -12,7 +12,6 @@ import ( "github.com/String-xyz/string-api/pkg/repository" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" - "github.com/lib/pq" "github.com/pkg/errors" ) @@ -51,7 +50,7 @@ type Auth interface { // if signaure is valid it returns a JWT to authenticate the user VerifySignedPayload(model.WalletSignaturePayloadSigned) (UserCreateResponse, error) - GenerateJWT(model.Device) (JWT, error) + GenerateJWT(string, ...model.Device) (JWT, error) ValidateAPIKey(key string) bool RefreshToken(token string, walletAddress string) (UserCreateResponse, error) InvalidateRefreshToken(token string) error @@ -59,13 +58,13 @@ type Auth interface { type auth struct { repos repository.Repositories - fingerprint Fingerprint verification Verification + device Device } // reusing UserRepos here -func NewAuth(r repository.Repositories, f Fingerprint, v Verification) Auth { - return &auth{r, f, v} +func NewAuth(r repository.Repositories, v Verification, d Device) Auth { + return &auth{r, v, d} } func (a auth) PayloadToSign(walletAddress string) (SignablePayload, error) { @@ -107,56 +106,36 @@ func (a auth) VerifySignedPayload(request model.WalletSignaturePayloadSigned) (U return resp, common.StringError(err) } - created, device, err := a.createDeviceIfNeeded(user.ID, request.Fingerprint.VisitorID, request.Fingerprint.RequestID) - if err != nil { + user.Email = getValidatedEmailOrEmpty(a.repos.Contact, user.ID) + + device, err := a.device.CreateDeviceIfNeeded(user.ID, request.Fingerprint.VisitorID, request.Fingerprint.RequestID) + if err != nil && !strings.Contains(err.Error(), "not found") { return resp, common.StringError(err) } - if created || device.ValidatedAt == nil { - go a.verification.SendDeviceVerification(user.ID, device.ID, device.Description) + // Send verification email if device is unknown and user has a validated email + if user.Email != "" && !isDeviceValidated(device) { + go a.verification.SendDeviceVerification(user.ID, user.Email, device.ID, device.Description) return resp, common.StringError(errors.New("unknown device")) } // Create the JWT - jwt, err := a.GenerateJWT(device) + jwt, err := a.GenerateJWT(user.ID, device) if err != nil { return resp, common.StringError(err) } - return UserCreateResponse{JWT: jwt, User: user}, nil -} -func (a auth) createDeviceIfNeeded(userID, visitorID, requestID string) (bool, model.Device, error) { - device, err := a.repos.Device.GetByUserIdAndFingerprint(userID, visitorID) - if err == nil { - return false, device, nil - } - // create device only if the error is not found - if err != nil && err == repository.ErrNotFound { - visitor, fpErr := a.fingerprint.GetVisitor(visitorID, requestID) - if fpErr != nil { - return false, model.Device{}, common.StringError(fpErr) - } - device, dErr := a.createDevice(userID, visitor) - return dErr == nil, device, dErr + // Invalidate device if it is unknown and was validated so it cannot be used again + err = a.device.InvalidateUnknownDevice(device) + if err != nil { + return resp, common.StringError(err) } - return false, device, common.StringError(err) -} - -func (a auth) createDevice(userID string, visitor model.FPVisitor) (model.Device, error) { - return a.repos.Device.Create(model.Device{ - UserID: userID, - Fingerprint: visitor.VisitorID, - Type: visitor.Type, - IpAddresses: pq.StringArray{visitor.IPAddress}, - Description: visitor.UserAgent, - LastUsedAt: time.Now(), - ValidatedAt: nil, - }) + return UserCreateResponse{JWT: jwt, User: user}, nil } // GenerateJWT generates a jwt token and a refresh token which is saved on redis -func (a auth) GenerateJWT(m model.Device) (JWT, error) { +func (a auth) GenerateJWT(userId string, m ...model.Device) (JWT, error) { claims := JWTClaims{} refreshToken := uuidWithoutHyphens() t := &JWT{ @@ -164,8 +143,12 @@ func (a auth) GenerateJWT(m model.Device) (JWT, error) { ExpAt: time.Now().Add(time.Minute * 15), } - claims.DeviceId = m.ID - claims.UserId = m.UserID + // set device id if available + if len(m) > 0 { + claims.DeviceId = m[0].ID + } + + claims.UserId = userId claims.ExpiresAt = t.ExpAt.Unix() claims.IssuedAt = t.IssuedAt.Unix() // replace this signing method with RSA or something similar @@ -177,7 +160,7 @@ func (a auth) GenerateJWT(m model.Device) (JWT, error) { t.Token = signed // create and save - refreshObj, err := a.repos.Auth.CreateJWTRefresh(common.ToSha256(refreshToken), m.UserID) + refreshObj, err := a.repos.Auth.CreateJWTRefresh(common.ToSha256(refreshToken), userId) if err != nil { return *t, err } @@ -240,7 +223,7 @@ func (a auth) RefreshToken(refreshToken string, walletAddress string) (UserCreat } // create new jwt - jwt, err := a.GenerateJWT(device) + jwt, err := a.GenerateJWT(userId, device) if err != nil { return resp, common.StringError(err) } @@ -256,6 +239,9 @@ func (a auth) RefreshToken(refreshToken string, walletAddress string) (UserCreat if err != nil { return resp, common.StringError(err) } + + // get email + user.Email = getValidatedEmailOrEmpty(a.repos.Contact, user.ID) resp.User = user return resp, nil @@ -295,3 +281,12 @@ func uuidWithoutHyphens() string { s := uuid.New().String() return strings.Replace(s, "-", "", -1) } + +func getValidatedEmailOrEmpty(contactRepo repository.Contact, userId string) string { + contact, err := contactRepo.GetByUserIdAndStatus(userId, "validated") + if err != nil { + return "" + } + + return contact.Data +} diff --git a/pkg/service/base.go b/pkg/service/base.go index c532551c..ee055faf 100644 --- a/pkg/service/base.go +++ b/pkg/service/base.go @@ -13,4 +13,5 @@ type Services struct { Transaction Transaction User User Verification Verification + Device Device } diff --git a/pkg/service/device.go b/pkg/service/device.go new file mode 100644 index 00000000..de5ca877 --- /dev/null +++ b/pkg/service/device.go @@ -0,0 +1,130 @@ +package service + +import ( + "os" + "time" + + "github.com/String-xyz/string-api/pkg/internal/common" + "github.com/String-xyz/string-api/pkg/model" + "github.com/String-xyz/string-api/pkg/repository" + "github.com/lib/pq" + "github.com/pkg/errors" +) + +type Device interface { + VerifyDevice(encrypted string) error + CreateDeviceIfNeeded(userID, visitorID, requestID string) (model.Device, error) + CreateUnknownDevice(userID string) (model.Device, error) + InvalidateUnknownDevice(device model.Device) error +} + +type device struct { + repos repository.Repositories + fingerprint Fingerprint +} + +func NewDevice(repos repository.Repositories, f Fingerprint) Device { + return &device{repos, f} +} + +func (d device) createDevice(userID string, visitor model.FPVisitor, description string) (model.Device, error) { + return d.repos.Device.Create(model.Device{ + UserID: userID, + Fingerprint: visitor.VisitorID, + Type: visitor.Type, + IpAddresses: pq.StringArray{visitor.IPAddress}, + Description: description, + LastUsedAt: time.Now(), + }) +} + +func (d device) CreateUnknownDevice(userID string) (model.Device, error) { + visitor := model.FPVisitor{ + VisitorID: "unknown", + Type: "unknown", + IPAddress: "unknown", + UserAgent: "unknown", + } + device, err := d.createDevice(userID, visitor, "an unknown device") + return device, common.StringError(err) +} + +func (d device) CreateDeviceIfNeeded(userID, visitorID, requestID string) (model.Device, error) { + if visitorID == "" || requestID == "" { + /* fingerprint is not available, create an unknown device. It should be invalidated on every login */ + device, err := d.getOrCreateUnknownDevice(userID, "unknown") + if err != nil { + return device, common.StringError(err) + } + + if !isDeviceValidated(device) { + device.ValidatedAt = nil + return device, nil + } + + return device, common.StringError(err) + } else { + /* device recognized, create or get the device */ + device, err := d.repos.Device.GetByUserIdAndFingerprint(userID, visitorID) + if err == nil { + return device, err + } + + /* create device only if the error is not found */ + if err == repository.ErrNotFound { + visitor, fpErr := d.fingerprint.GetVisitor(visitorID, requestID) + if fpErr != nil { + return model.Device{}, common.StringError(fpErr) + } + device, dErr := d.createDevice(userID, visitor, "a new device "+visitor.UserAgent+" ") + return device, dErr + } + + return device, common.StringError(err) + } +} + +func (d device) VerifyDevice(encrypted string) error { + key := os.Getenv("STRING_ENCRYPTION_KEY") + received, err := common.Decrypt[DeviceVerification](encrypted, key) + if err != nil { + return common.StringError(err) + } + + now := time.Now() + if now.Unix()-received.Timestamp > (60 * 15) { + return common.StringError(errors.New("link expired")) + } + err = d.repos.Device.Update(received.DeviceID, model.DeviceUpdates{ValidatedAt: &now}) + return err +} + +func (d device) getOrCreateUnknownDevice(userId, visitorId string) (model.Device, error) { + var device model.Device + + device, err := d.repos.Device.GetByUserIdAndFingerprint(userId, "unknown") + if err != nil && err != repository.ErrNotFound { + return device, common.StringError(err) + } + + if device.ID != "" { + return device, nil + } + + // if device is not found, create a new one + device, err = d.CreateUnknownDevice(userId) + return device, common.StringError(err) +} + +func isDeviceValidated(device model.Device) bool { + return device.ValidatedAt != nil && !device.ValidatedAt.IsZero() +} + +func (d device) InvalidateUnknownDevice(device model.Device) error { + if device.Fingerprint != "unknown" { + return nil // only unknown devices can be invalidated + } + + device.ValidatedAt = &time.Time{} // Zero time to set it to nil + return d.repos.Device.Update(device.ID, device) +} diff --git a/pkg/service/user.go b/pkg/service/user.go index f30e53e1..1d205ac1 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -93,12 +93,38 @@ func (u user) Create(request model.WalletSignaturePayloadSigned) (UserCreateResp return resp, common.StringError(err) } - user, device, err := u.createUserData(addr, request.Fingerprint.VisitorID, request.Fingerprint.RequestID) + user, err := u.createUserData(addr) if err != nil { return resp, err } - jwt, err := u.auth.GenerateJWT(device) + var device model.Device + + // create device only if there is a visitor + visitorID := request.Fingerprint.VisitorID + requestID := request.Fingerprint.RequestID + if visitorID != "" && requestID != "" { + visitor, err := u.fingerprint.GetVisitor(visitorID, requestID) + if err == nil { + // if fingerprint successfully retrieved, create device, otherwise continue without device + now := time.Now() + + device, err = u.repos.Device.Create(model.Device{ + Fingerprint: visitorID, + UserID: user.ID, + Type: visitor.Type, + IpAddresses: pq.StringArray{visitor.IPAddress}, + Description: visitor.UserAgent, + LastUsedAt: now, + ValidatedAt: &now, + }) + if err != nil { + return resp, common.StringError(err) + } + } + } + + jwt, err := u.auth.GenerateJWT(user.ID, device) if err != nil { return resp, common.StringError(err) } @@ -109,7 +135,7 @@ func (u user) Create(request model.WalletSignaturePayloadSigned) (UserCreateResp return UserCreateResponse{JWT: jwt, User: user}, nil } -func (u user) createUserData(addr, visitorID, requestID string) (model.User, model.Device, error) { +func (u user) createUserData(addr string) (model.User, error) { tx := u.repos.User.MustBegin() u.repos.Instrument.SetTx(tx) u.repos.Device.SetTx(tx) @@ -121,41 +147,21 @@ func (u user) createUserData(addr, visitorID, requestID string) (model.User, mod user, err := u.repos.User.Create(user) if err != nil { u.repos.User.Rollback() - return user, model.Device{}, common.StringError(err) + return user, common.StringError(err) } // Create a new wallet instrument and associate it with the new user instrument := model.Instrument{Type: "crypto-wallet", Status: "verified", Network: "EVM", PublicKey: addr, UserID: user.ID} instrument, err = u.repos.Instrument.Create(instrument) if err != nil { u.repos.Instrument.Rollback() - return user, model.Device{}, common.StringError(err) - } - - visitor, err := u.fingerprint.GetVisitor(visitorID, requestID) - if err != nil { - u.repos.Instrument.Rollback() - return user, model.Device{}, err // is this intentionally not common.StringError? - } - now := time.Now() - device, err := u.repos.Device.Create(model.Device{ - Fingerprint: visitorID, - UserID: user.ID, - Type: visitor.Type, - IpAddresses: pq.StringArray{visitor.IPAddress}, - Description: visitor.UserAgent, - LastUsedAt: now, - ValidatedAt: &now, - }) - if err != nil { - u.repos.Device.Rollback() - return user, model.Device{}, err // is this intentionally not common.StringError? + return user, common.StringError(err) } if err := u.repos.User.Commit(); err != nil { - return user, model.Device{}, common.StringError(errors.New("error commiting transaction")) + return user, common.StringError(errors.New("error commiting transaction")) } - return user, device, nil + return user, nil } func (u user) Update(userID string, request UserUpdates) (model.User, error) { diff --git a/pkg/service/verification.go b/pkg/service/verification.go index c9751ed1..e91059dd 100644 --- a/pkg/service/verification.go +++ b/pkg/service/verification.go @@ -34,8 +34,7 @@ type Verification interface { // VerifyEmail verifies the provided email and creates a contact VerifyEmail(encrypted string) error - SendDeviceVerification(userID string, deviceID string, deviceDescription string) error - VerifyDevice(encrypted string) error + SendDeviceVerification(userID, email string, deviceID string, deviceDescription string) error } type verification struct { @@ -109,13 +108,8 @@ func (v verification) SendEmailVerification(userID, email string) error { return common.StringError(errors.New("link expired")) } -func (v verification) SendDeviceVerification(userID, deviceID, deviceDescription string) error { - email, err := v.repos.Contact.GetByUserIdAndStatus(userID, "validated") - if err != nil { - log.Err(err).Msg("Error getting a valid email") - return err - } - log.Info().Str("email", email.Data) +func (v verification) SendDeviceVerification(userID, email, deviceID, deviceDescription string) error { + log.Info().Str("email", email) key := os.Getenv("STRING_ENCRYPTION_KEY") code, err := common.Encrypt(DeviceVerification{Timestamp: time.Now().Unix(), DeviceID: deviceID, UserID: userID}, key) if err != nil { @@ -126,11 +120,13 @@ func (v verification) SendDeviceVerification(userID, deviceID, deviceDescription baseURL := common.GetBaseURL() from := mail.NewEmail("String XYZ", "auth@string.xyz") subject := "New Device Login Verification" - to := mail.NewEmail("New Device Login", email.Data) + to := mail.NewEmail("New Device Login", email) link := baseURL + "verification?type=device&token=" + code - htmlContent := fmt.Sprintf(`
We noticed that you attempted to log in from a new device %s. Is this you?
+ + textContent := "We noticed that you attempted to log in from " + deviceDescription + " at " + time.Now().Local().Format(time.RFC1123) + ". Is this you?" + htmlContent := fmt.Sprintf(`
%s
Yes`, - deviceDescription, link) + textContent, link) message := mail.NewSingleEmail(from, subject, to, "", htmlContent) client := sendgrid.NewSendClient(os.Getenv("SENDGRID_API_KEY")) @@ -168,17 +164,3 @@ func (v verification) VerifyEmail(encrypted string) error { return nil } - -func (v verification) VerifyDevice(encrypted string) error { - key := os.Getenv("STRING_ENCRYPTION_KEY") - received, err := common.Decrypt[DeviceVerification](encrypted, key) - if err != nil { - return common.StringError(err) - } - now := time.Now() - if now.Unix()-received.Timestamp > (60 * 15) { - return common.StringError(errors.New("link expired")) - } - err = v.repos.Device.Update(received.DeviceID, model.DeviceUpdates{ValidatedAt: &now}) - return err -}