implement jwt claims

This commit is contained in:
2024-09-23 11:18:13 -05:00
parent 2b3abb8fb2
commit f6e9ee7b1a
10 changed files with 175 additions and 26 deletions

View File

@@ -76,6 +76,10 @@ func (db *InMemoryDb) UpdateUserInterface(userId UserId, ui UserInterface) error
return nil return nil
} }
func (db *InMemoryDb) UpdateUserRefreshToken(userId UserId, refreshToken string) error {
return nil
}
func (db *InMemoryDb) Renew(id CustomerId) error { func (db *InMemoryDb) Renew(id CustomerId) error {
customer, exists := db.Customers[id] customer, exists := db.Customers[id]
if !exists { if !exists {
@@ -96,7 +100,7 @@ func (db *InMemoryDb) Renew(id CustomerId) error {
return nil return nil
} }
func (db *InMemoryDb) RefreshUser(user User, passocode []int, customerAttr CustomerAttributes) error { func (db *InMemoryDb) RefreshUserPasscode(user User, passocode []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passocode, customerAttr) err := user.RefreshPasscode(passocode, customerAttr)
if err != nil { if err != nil {
return err return err
@@ -117,7 +121,7 @@ func (db *InMemoryDb) RandomSvgIdxInterface(kp KeypadDimension) (SvgIdInterface,
return svgs, nil return svgs, nil
} }
func (db InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) { func (db *InMemoryDb) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) {
return make([]string, len(idxs)), nil return make([]string, len(idxs)), nil
} }

93
core/jwt_claims.go Normal file
View File

@@ -0,0 +1,93 @@
package core
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"time"
)
type JwtTokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
const (
accessTokenExp = 5 * time.Minute
refreshTokenExp = 30 * 24 * time.Hour
)
var secret = []byte("your-secret-key")
func NewJwtTokens(username string) (JwtTokens, error) {
accessClaims := NewAccessClaim(username)
refreshClaims := jwt.RegisteredClaims{
Subject: username,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExp)),
}
accessJwt, err := EncodeAndSignClaims(accessClaims)
if err != nil {
return JwtTokens{}, err
}
refreshJwt, err := EncodeAndSignClaims(refreshClaims)
if err != nil {
return JwtTokens{}, err
}
return JwtTokens{
AccessToken: accessJwt,
RefreshToken: refreshJwt,
}, nil
}
func NewAccessClaim(username string) jwt.RegisteredClaims {
return jwt.RegisteredClaims{
Subject: username,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenExp)),
}
}
func EncodeAndSignClaims(claims jwt.RegisteredClaims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
}
func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
}
func ClaimExpired(claims jwt.RegisteredClaims) error {
if claims.ExpiresAt == nil {
return errors.New("claim exp is nil")
}
if claims.ExpiresAt.Time.Before(time.Now()) {
return nil
}
return errors.New("claim expired")
}

View File

@@ -126,27 +126,35 @@ func (n *NKodeAPI) GetLoginInterface(username Username, customerId CustomerId) (
return &resp, nil return &resp, nil
} }
func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (string, error) { func (n *NKodeAPI) Login(customerId CustomerId, username Username, keySelection KeySelection) (*JwtTokens, error) {
customer, err := n.Db.GetCustomer(customerId) customer, err := n.Db.GetCustomer(customerId)
if err != nil { if err != nil {
return "", err return nil, err
} }
user, err := n.Db.GetUser(username, customerId) user, err := n.Db.GetUser(username, customerId)
if err != nil { if err != nil {
return "", errors.New(fmt.Sprintf("user dne %s", username)) return nil, errors.New(fmt.Sprintf("user dne %s", username))
} }
passcode, err := ValidKeyEntry(*user, *customer, keySelection) passcode, err := ValidKeyEntry(*user, *customer, keySelection)
if err != nil { if err != nil {
return "", err return nil, err
} }
if user.Renew { if user.Renew {
err = n.Db.RefreshUser(*user, passcode, customer.Attributes) err = n.Db.RefreshUserPasscode(*user, passcode, customer.Attributes)
if err != nil { if err != nil {
return "", err return nil, err
} }
} }
return "", nil jwtToken, err := NewJwtTokens(string(user.Username))
if err != nil {
return nil, err
}
err = n.Db.UpdateUserRefreshToken(user.Id, jwtToken.RefreshToken)
if err != nil {
return nil, err
}
return &jwtToken, nil
} }
func (n *NKodeAPI) RenewAttributes(customerId CustomerId) error { func (n *NKodeAPI) RenewAttributes(customerId CustomerId) error {
@@ -161,6 +169,21 @@ func (n *NKodeAPI) GetSvgStringInterface(svgId SvgIdInterface) ([]string, error)
return n.Db.GetSvgStringInterface(svgId) return n.Db.GetSvgStringInterface(svgId)
} }
func (n *NKodeAPI) RefreshToken(jwt string) (string, error) { func (n *NKodeAPI) RefreshToken(username Username, customerId CustomerId, refreshToken string) (string, error) {
return "", nil user, err := n.Db.GetUser(username, customerId)
if err != nil {
return "", err
}
if user.RefreshToken != refreshToken {
return "", errors.New("refresh token is invalid")
}
refreshClaims, err := ParseRefreshToken(refreshToken)
if err != nil {
return "", err
}
if err = ClaimExpired(*refreshClaims); err != nil {
return "", err
}
newAccessClaims := NewAccessClaim(string(username))
return EncodeAndSignClaims(newAccessClaims)
} }

View File

@@ -286,7 +286,20 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
log.Println(err) log.Println(err)
return return
} }
_, err = h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection) jwtTokens, err := h.Api.Login(CustomerId(customerId), loginPost.Username, loginPost.KeySelection)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
respBytes, err := json.Marshal(jwtTokens)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil { if err != nil {
internalServerErrorHandler(w) internalServerErrorHandler(w)
log.Println(err) log.Println(err)

View File

@@ -73,8 +73,8 @@ func (d *SqliteDB) WriteNewUser(u User) error {
} }
}() }()
insertUser := ` insertUser := `
INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
` `
var renew int var renew int
if u.Renew { if u.Renew {
@@ -82,7 +82,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
} else { } else {
renew = 0 renew = 0
} }
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)) _, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
if err != nil { if err != nil {
return err return err
@@ -136,7 +136,7 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
func (d *SqliteDB) GetUser(username Username, customerId CustomerId) (*User, error) { func (d *SqliteDB) GetUser(username Username, customerId CustomerId) (*User, error) {
userSelect := ` userSelect := `
SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
WHERE user.username = ? AND user.customer_id = ? WHERE user.username = ? AND user.customer_id = ?
` `
rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String()) rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String())
@@ -145,6 +145,7 @@ WHERE user.username = ? AND user.customer_id = ?
} }
var id string var id string
var renewVal int var renewVal int
var refreshToken string
var code string var code string
var mask string var mask string
var attrsPerKey int var attrsPerKey int
@@ -158,7 +159,7 @@ WHERE user.username = ? AND user.customer_id = ?
var idxInterface []byte var idxInterface []byte
var svgIdInterface []byte var svgIdInterface []byte
err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
if rows.Next() { if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId)) return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId))
} }
@@ -200,7 +201,8 @@ WHERE user.username = ? AND user.customer_id = ?
SvgId: util.ByteArrToIntArr(svgIdInterface), SvgId: util.ByteArrToIntArr(svgIdInterface),
Kp: nil, Kp: nil,
}, },
Renew: renew, Renew: renew,
RefreshToken: refreshToken,
} }
user.Interface.Kp = &user.Kp user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp
@@ -217,6 +219,15 @@ UPDATE user SET idx_interface = ? WHERE id = ?
return err return err
} }
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
updateUserRefreshToken := `
UPDATE user SET refresh_token = ? WHERE id = ?
`
_, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String())
return err
}
func (d *SqliteDB) Renew(id CustomerId) error { func (d *SqliteDB) Renew(id CustomerId) error {
customer, err := d.GetCustomer(id) customer, err := d.GetCustomer(id)
if err != nil { if err != nil {
@@ -276,7 +287,7 @@ COMMIT;
return err return err
} }
func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr CustomerAttributes) error { func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr) err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil { if err != nil {
return err return err
@@ -284,7 +295,7 @@ func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr Custom
updateUser := ` updateUser := `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
` `
_, err = d.db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()) _, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
return err return err
} }

View File

@@ -96,8 +96,9 @@ type DbAccessor interface {
WriteNewCustomer(Customer) error WriteNewCustomer(Customer) error
WriteNewUser(User) error WriteNewUser(User) error
UpdateUserInterface(UserId, UserInterface) error UpdateUserInterface(UserId, UserInterface) error
UpdateUserRefreshToken(UserId, string) error
Renew(CustomerId) error Renew(CustomerId) error
RefreshUser(User, []int, CustomerAttributes) error RefreshUserPasscode(User, []int, CustomerAttributes) error
RandomSvgInterface(KeypadDimension) ([]string, error) RandomSvgInterface(KeypadDimension) ([]string, error)
RandomSvgIdxInterface(KeypadDimension) (SvgIdInterface, error) RandomSvgIdxInterface(KeypadDimension) (SvgIdInterface, error)
GetSvgStringInterface(SvgIdInterface) ([]string, error) GetSvgStringInterface(SvgIdInterface) ([]string, error)

View File

@@ -16,6 +16,7 @@ type User struct {
CipherKeys UserCipherKeys CipherKeys UserCipherKeys
Interface UserInterface Interface UserInterface
Renew bool Renew bool
RefreshToken string
} }
func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) { func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) {

View File

@@ -9,7 +9,6 @@ import (
) )
func main() { func main() {
//db := nkode.NewInMemoryDb()
db := core.NewSqliteDB("nkode.db") db := core.NewSqliteDB("nkode.db")
defer db.CloseDb() defer db.CloseDb()
nkodeApi := core.NewNKodeAPI(db) nkodeApi := core.NewNKodeAPI(db)

View File

@@ -71,9 +71,12 @@ func TestApi(t *testing.T) {
Username: username, Username: username,
KeySelection: loginKeySelection, KeySelection: loginKeySelection,
} }
var jwtTokens core.JwtTokens
testApiPost(t, base+core.Login, loginBody, nil) testApiPost(t, base+core.Login, loginBody, &jwtTokens)
refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken)
assert.Equal(t, refreshClaims.Subject, string(username))
accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken)
assert.Equal(t, accessClaims.Subject, string(username))
renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId}
testApiPost(t, base+core.RenewAttributes, renewBody, nil) testApiPost(t, base+core.RenewAttributes, renewBody, nil)
@@ -85,7 +88,7 @@ func TestApi(t *testing.T) {
KeySelection: loginKeySelection, KeySelection: loginKeySelection,
} }
testApiPost(t, base+core.Login, loginBody, nil) testApiPost(t, base+core.Login, loginBody, &jwtTokens)
var randomSvgInterfaceResp core.RandomSvgInterfaceResp var randomSvgInterfaceResp core.RandomSvgInterfaceResp
testApiGet(t, base+core.RandomSvgInterface, &randomSvgInterfaceResp) testApiGet(t, base+core.RandomSvgInterface, &randomSvgInterfaceResp)

View File

@@ -162,6 +162,7 @@ CREATE TABLE IF NOT EXISTS user (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
username TEXT NOT NULL, username TEXT NOT NULL,
renew INT NOT NULL, renew INT NOT NULL,
refresh_token TEXT,
customer_id TEXT NOT NULL, customer_id TEXT NOT NULL,
-- Enciphered Passcode -- Enciphered Passcode