implement jwt claims

This commit is contained in:
2024-09-23 11:18:13 -05:00
parent 2b3abb8fb2
commit f6e9ee7b1a
10 changed files with 175 additions and 26 deletions

View File

@@ -76,6 +76,10 @@ func (db *InMemoryDb) UpdateUserInterface(userId UserId, ui UserInterface) error
return nil
}
func (db *InMemoryDb) UpdateUserRefreshToken(userId UserId, refreshToken string) error {
return nil
}
func (db *InMemoryDb) Renew(id CustomerId) error {
customer, exists := db.Customers[id]
if !exists {
@@ -96,7 +100,7 @@ func (db *InMemoryDb) Renew(id CustomerId) error {
return nil
}
func (db *InMemoryDb) RefreshUser(user User, passocode []int, customerAttr CustomerAttributes) error {
func (db *InMemoryDb) RefreshUserPasscode(user User, passocode []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passocode, customerAttr)
if err != nil {
return err
@@ -117,7 +121,7 @@ func (db *InMemoryDb) RandomSvgIdxInterface(kp KeypadDimension) (SvgIdInterface,
return svgs, nil
}
func (db InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) {
func (db *InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) {
return make([]string, len(idxs)), nil
}

93
core/jwt_claims.go Normal file
View File

@@ -0,0 +1,93 @@
package core
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"time"
)
type JwtTokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
const (
accessTokenExp = 5 * time.Minute
refreshTokenExp = 30 * 24 * time.Hour
)
var secret = []byte("your-secret-key")
func NewJwtTokens(username string) (JwtTokens, error) {
accessClaims := NewAccessClaim(username)
refreshClaims := jwt.RegisteredClaims{
Subject: username,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExp)),
}
accessJwt, err := EncodeAndSignClaims(accessClaims)
if err != nil {
return JwtTokens{}, err
}
refreshJwt, err := EncodeAndSignClaims(refreshClaims)
if err != nil {
return JwtTokens{}, err
}
return JwtTokens{
AccessToken: accessJwt,
RefreshToken: refreshJwt,
}, nil
}
func NewAccessClaim(username string) jwt.RegisteredClaims {
return jwt.RegisteredClaims{
Subject: username,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenExp)),
}
}
func EncodeAndSignClaims(claims jwt.RegisteredClaims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
}
func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
}
func ClaimExpired(claims jwt.RegisteredClaims) error {
if claims.ExpiresAt == nil {
return errors.New("claim exp is nil")
}
if claims.ExpiresAt.Time.Before(time.Now()) {
return nil
}
return errors.New("claim expired")
}

View File

@@ -126,27 +126,35 @@ func (n *NKodeAPI) GetLoginInterface(username Username, customerId CustomerId) (
return &resp, nil
}
func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (string, error) {
func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (*JwtTokens, error) {
customer, err := n.Db.GetCustomer(customerId)
if err != nil {
return "", err
return nil, err
}
user, err := n.Db.GetUser(username, customerId)
if err != nil {
return "", errors.New(fmt.Sprintf("user dne %s", username))
return nil, errors.New(fmt.Sprintf("user dne %s", username))
}
passcode, err := ValidKeyEntry(*user, *customer, keySelection)
if err != nil {
return "", err
return nil, err
}
if user.Renew {
err = n.Db.RefreshUser(*user, passcode, customer.Attributes)
err = n.Db.RefreshUserPasscode(*user, passcode, customer.Attributes)
if err != nil {
return "", err
return nil, err
}
}
return "", nil
jwtToken, err := NewJwtTokens(string(user.Username))
if err != nil {
return nil, err
}
err = n.Db.UpdateUserRefreshToken(user.Id, jwtToken.RefreshToken)
if err != nil {
return nil, err
}
return &jwtToken, nil
}
func (n *NKodeAPI) RenewAttributes(customerId CustomerId) error {
@@ -161,6 +169,21 @@ func (n *NKodeAPI) GetSvgStringInterface(svgId SvgIdInterface) ([]string, error)
return n.Db.GetSvgStringInterface(svgId)
}
func (n *NKodeAPI) RefreshToken(jwt string) (string, error) {
return "", nil
func (n *NKodeAPI) RefreshToken(username Username, customerId CustomerId, refreshToken string) (string, error) {
user, err := n.Db.GetUser(username, customerId)
if err != nil {
return "", err
}
if user.RefreshToken != refreshToken {
return "", errors.New("refresh token is invalid")
}
refreshClaims, err := ParseRefreshToken(refreshToken)
if err != nil {
return "", err
}
if err = ClaimExpired(*refreshClaims); err != nil {
return "", err
}
newAccessClaims := NewAccessClaim(string(username))
return EncodeAndSignClaims(newAccessClaims)
}

View File

@@ -286,7 +286,20 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
log.Println(err)
return
}
_, err = h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection)
jwtTokens, err := h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
respBytes, err := json.Marshal(jwtTokens)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)

View File

@@ -73,8 +73,8 @@ func (d *SqliteDB) WriteNewUser(u User) error {
}
}()
insertUser := `
INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
var renew int
if u.Renew {
@@ -82,7 +82,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
} else {
renew = 0
}
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
if err != nil {
return err
@@ -136,7 +136,7 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
func (d *SqliteDB) GetUser(username Username, customerId CustomerId) (*User, error) {
userSelect := `
SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
WHERE user.username = ? AND user.customer_id = ?
`
rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String())
@@ -145,6 +145,7 @@ WHERE user.username = ? AND user.customer_id = ?
}
var id string
var renewVal int
var refreshToken string
var code string
var mask string
var attrsPerKey int
@@ -158,7 +159,7 @@ WHERE user.username = ? AND user.customer_id = ?
var idxInterface []byte
var svgIdInterface []byte
err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId))
}
@@ -200,7 +201,8 @@ WHERE user.username = ? AND user.customer_id = ?
SvgId: util.ByteArrToIntArr(svgIdInterface),
Kp: nil,
},
Renew: renew,
Renew: renew,
RefreshToken: refreshToken,
}
user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp
@@ -217,6 +219,15 @@ UPDATE user SET idx_interface = ? WHERE id = ?
return err
}
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
updateUserRefreshToken := `
UPDATE user SET refresh_token = ? WHERE id = ?
`
_, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String())
return err
}
func (d *SqliteDB) Renew(id CustomerId) error {
customer, err := d.GetCustomer(id)
if err != nil {
@@ -276,7 +287,7 @@ COMMIT;
return err
}
func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil {
return err
@@ -284,7 +295,7 @@ func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr Custom
updateUser := `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
`
_, err = d.db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
_, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
return err
}

View File

@@ -96,8 +96,9 @@ type DbAccessor interface {
WriteNewCustomer(Customer) error
WriteNewUser(User) error
UpdateUserInterface(UserId, UserInterface) error
UpdateUserRefreshToken(UserId, string) error
Renew(CustomerId) error
RefreshUser(User, []int, CustomerAttributes) error
RefreshUserPasscode(User, []int, CustomerAttributes) error
RandomSvgInterface(KeypadDimension) ([]string, error)
RandomSvgIdxInterface(KeypadDimension) (SvgIdInterface, error)
GetSvgStringInterface(SvgIdInterface) ([]string, error)

View File

@@ -16,6 +16,7 @@ type User struct {
CipherKeys UserCipherKeys
Interface UserInterface
Renew bool
RefreshToken string
}
func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) {