diff --git a/core/in_memory_db.go b/core/in_memory_db.go index face3e7..dbd0c76 100644 --- a/core/in_memory_db.go +++ b/core/in_memory_db.go @@ -76,6 +76,10 @@ func (db *InMemoryDb) UpdateUserInterface(userId UserId, ui UserInterface) error return nil } +func (db *InMemoryDb) UpdateUserRefreshToken(userId UserId, refreshToken string) error { + return nil +} + func (db *InMemoryDb) Renew(id CustomerId) error { customer, exists := db.Customers[id] if !exists { @@ -96,7 +100,7 @@ func (db *InMemoryDb) Renew(id CustomerId) error { return nil } -func (db *InMemoryDb) RefreshUser(user User, passocode []int, customerAttr CustomerAttributes) error { +func (db *InMemoryDb) RefreshUserPasscode(user User, passocode []int, customerAttr CustomerAttributes) error { err := user.RefreshPasscode(passocode, customerAttr) if err != nil { return err @@ -117,7 +121,7 @@ func (db *InMemoryDb) RandomSvgIdxInterface(kp KeypadDimension) (SvgIdInterface, return svgs, nil } -func (db InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) { +func (db *InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) { return make([]string, len(idxs)), nil } diff --git a/core/jwt_claims.go b/core/jwt_claims.go new file mode 100644 index 0000000..b3c5089 --- /dev/null +++ b/core/jwt_claims.go @@ -0,0 +1,93 @@ +package core + +import ( + "errors" + "fmt" + "github.com/golang-jwt/jwt/v5" + "time" +) + +type JwtTokens struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +const ( + accessTokenExp = 5 * time.Minute + refreshTokenExp = 30 * 24 * time.Hour +) + +var secret = []byte("your-secret-key") + +func NewJwtTokens(username string) (JwtTokens, error) { + accessClaims := NewAccessClaim(username) + + refreshClaims := jwt.RegisteredClaims{ + Subject: username, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExp)), + } + + accessJwt, err := EncodeAndSignClaims(accessClaims) + if err != nil { + return JwtTokens{}, err + } + refreshJwt, err := EncodeAndSignClaims(refreshClaims) + + if err != nil { + return JwtTokens{}, err + } + return JwtTokens{ + AccessToken: accessJwt, + RefreshToken: refreshJwt, + }, nil +} + +func NewAccessClaim(username string) jwt.RegisteredClaims { + return jwt.RegisteredClaims{ + Subject: username, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenExp)), + } +} + +func EncodeAndSignClaims(claims jwt.RegisteredClaims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(secret) +} + +func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) { + token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return secret, nil + }) + if err != nil { + return nil, fmt.Errorf("error parsing refresh token: %w", err) + } + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok { + return nil, errors.New("unable to parse claims") + } + return claims, nil +} + +func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) { + token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return secret, nil + }) + if err != nil { + return nil, fmt.Errorf("error parsing refresh token: %w", err) + } + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok { + return nil, errors.New("unable to parse claims") + } + return claims, nil +} + +func ClaimExpired(claims jwt.RegisteredClaims) error { + if claims.ExpiresAt == nil { + return errors.New("claim exp is nil") + } + if claims.ExpiresAt.Time.Before(time.Now()) { + return nil + } + return errors.New("claim expired") +} diff --git a/core/nkode_api.go b/core/nkode_api.go index 6ba3ba4..41c343c 100644 --- a/core/nkode_api.go +++ b/core/nkode_api.go @@ -126,27 +126,35 @@ func (n *NKodeAPI) GetLoginInterface(username Username, customerId CustomerId) ( return &resp, nil } -func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (string, error) { +func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (*JwtTokens, error) { customer, err := n.Db.GetCustomer(customerId) if err != nil { - return "", err + return nil, err } user, err := n.Db.GetUser(username, customerId) if err != nil { - return "", errors.New(fmt.Sprintf("user dne %s", username)) + return nil, errors.New(fmt.Sprintf("user dne %s", username)) } passcode, err := ValidKeyEntry(*user, *customer, keySelection) if err != nil { - return "", err + return nil, err } if user.Renew { - err = n.Db.RefreshUser(*user, passcode, customer.Attributes) + err = n.Db.RefreshUserPasscode(*user, passcode, customer.Attributes) if err != nil { - return "", err + return nil, err } } - return "", nil + jwtToken, err := NewJwtTokens(string(user.Username)) + if err != nil { + return nil, err + } + err = n.Db.UpdateUserRefreshToken(user.Id, jwtToken.RefreshToken) + if err != nil { + return nil, err + } + return &jwtToken, nil } func (n *NKodeAPI) RenewAttributes(customerId CustomerId) error { @@ -161,6 +169,21 @@ func (n *NKodeAPI) GetSvgStringInterface(svgId SvgIdInterface) ([]string, error) return n.Db.GetSvgStringInterface(svgId) } -func (n *NKodeAPI) RefreshToken(jwt string) (string, error) { - return "", nil +func (n *NKodeAPI) RefreshToken(username Username, customerId CustomerId, refreshToken string) (string, error) { + user, err := n.Db.GetUser(username, customerId) + if err != nil { + return "", err + } + if user.RefreshToken != refreshToken { + return "", errors.New("refresh token is invalid") + } + refreshClaims, err := ParseRefreshToken(refreshToken) + if err != nil { + return "", err + } + if err = ClaimExpired(*refreshClaims); err != nil { + return "", err + } + newAccessClaims := NewAccessClaim(string(username)) + return EncodeAndSignClaims(newAccessClaims) } diff --git a/core/nkode_handler.go b/core/nkode_handler.go index 0a4d070..a720d7e 100644 --- a/core/nkode_handler.go +++ b/core/nkode_handler.go @@ -286,7 +286,20 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { log.Println(err) return } - _, err = h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection) + jwtTokens, err := h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection) + if err != nil { + internalServerErrorHandler(w) + log.Println(err) + return + } + + respBytes, err := json.Marshal(jwtTokens) + if err != nil { + internalServerErrorHandler(w) + log.Println(err) + return + } + _, err = w.Write(respBytes) if err != nil { internalServerErrorHandler(w) log.Println(err) diff --git a/core/sqlite_db.go b/core/sqlite_db.go index 2e87c17..1265343 100644 --- a/core/sqlite_db.go +++ b/core/sqlite_db.go @@ -73,8 +73,8 @@ func (d *SqliteDB) WriteNewUser(u User) error { } }() insertUser := ` -INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) -VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) +INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) +VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` var renew int if u.Renew { @@ -82,7 +82,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) } else { renew = 0 } - _, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)) + _, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)) if err != nil { return err @@ -136,7 +136,7 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { func (d *SqliteDB) GetUser(username Username, customerId CustomerId) (*User, error) { userSelect := ` -SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user +SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user WHERE user.username = ? AND user.customer_id = ? ` rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String()) @@ -145,6 +145,7 @@ WHERE user.username = ? AND user.customer_id = ? } var id string var renewVal int + var refreshToken string var code string var mask string var attrsPerKey int @@ -158,7 +159,7 @@ WHERE user.username = ? AND user.customer_id = ? var idxInterface []byte var svgIdInterface []byte - err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) + err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) if rows.Next() { return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId)) } @@ -200,7 +201,8 @@ WHERE user.username = ? AND user.customer_id = ? SvgId: util.ByteArrToIntArr(svgIdInterface), Kp: nil, }, - Renew: renew, + Renew: renew, + RefreshToken: refreshToken, } user.Interface.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp @@ -217,6 +219,15 @@ UPDATE user SET idx_interface = ? WHERE id = ? return err } +func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error { + updateUserRefreshToken := ` +UPDATE user SET refresh_token = ? WHERE id = ? +` + _, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String()) + + return err +} + func (d *SqliteDB) Renew(id CustomerId) error { customer, err := d.GetCustomer(id) if err != nil { @@ -276,7 +287,7 @@ COMMIT; return err } -func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr CustomerAttributes) error { +func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error { err := user.RefreshPasscode(passcodeIdx, customerAttr) if err != nil { return err @@ -284,7 +295,7 @@ func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr Custom updateUser := ` UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; ` - _, err = d.db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()) + _, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()) return err } diff --git a/core/type.go b/core/type.go index 1981cef..5721afc 100644 --- a/core/type.go +++ b/core/type.go @@ -96,8 +96,9 @@ type DbAccessor interface { WriteNewCustomer(Customer) error WriteNewUser(User) error UpdateUserInterface(UserId, UserInterface) error + UpdateUserRefreshToken(UserId, string) error Renew(CustomerId) error - RefreshUser(User, []int, CustomerAttributes) error + RefreshUserPasscode(User, []int, CustomerAttributes) error RandomSvgInterface(KeypadDimension) ([]string, error) RandomSvgIdxInterface(KeypadDimension) (SvgIdInterface, error) GetSvgStringInterface(SvgIdInterface) ([]string, error) diff --git a/core/user.go b/core/user.go index 47751a9..4ffbe75 100644 --- a/core/user.go +++ b/core/user.go @@ -16,6 +16,7 @@ type User struct { CipherKeys UserCipherKeys Interface UserInterface Renew bool + RefreshToken string } func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) { diff --git a/main.go b/main.go index de239d2..be28228 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( ) func main() { - //db := nkode.NewInMemoryDb() db := core.NewSqliteDB("nkode.db") defer db.CloseDb() nkodeApi := core.NewNKodeAPI(db) diff --git a/main_test.go b/main_test.go index d16352b..6201cf6 100644 --- a/main_test.go +++ b/main_test.go @@ -71,9 +71,12 @@ func TestApi(t *testing.T) { Username: username, KeySelection: loginKeySelection, } - - testApiPost(t, base+core.Login, loginBody, nil) - + var jwtTokens core.JwtTokens + testApiPost(t, base+core.Login, loginBody, &jwtTokens) + refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken) + assert.Equal(t, refreshClaims.Subject, string(username)) + accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken) + assert.Equal(t, accessClaims.Subject, string(username)) renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} testApiPost(t, base+core.RenewAttributes, renewBody, nil) @@ -85,7 +88,7 @@ func TestApi(t *testing.T) { KeySelection: loginKeySelection, } - testApiPost(t, base+core.Login, loginBody, nil) + testApiPost(t, base+core.Login, loginBody, &jwtTokens) var randomSvgInterfaceResp core.RandomSvgInterfaceResp testApiGet(t, base+core.RandomSvgInterface, &randomSvgInterfaceResp) diff --git a/sqlite-init/sqlite_init.go b/sqlite-init/sqlite_init.go index 40e74d4..3e0dc51 100644 --- a/sqlite-init/sqlite_init.go +++ b/sqlite-init/sqlite_init.go @@ -162,6 +162,7 @@ CREATE TABLE IF NOT EXISTS user ( id TEXT NOT NULL PRIMARY KEY, username TEXT NOT NULL, renew INT NOT NULL, + refresh_token TEXT, customer_id TEXT NOT NULL, -- Enciphered Passcode