refactor nkode-core
This commit is contained in:
123
pkg/nkode-core/security/jwt_claims.go
Normal file
123
pkg/nkode-core/security/jwt_claims.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"go-nkode/config"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthenticationTokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type ResetNKodeClaims struct {
|
||||
Reset bool `json:"reset"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
const (
|
||||
accessTokenExp = 5 * time.Minute
|
||||
refreshTokenExp = 30 * 24 * time.Hour
|
||||
resetNKodeTokenExp = 5 * time.Minute
|
||||
)
|
||||
|
||||
var secret = getJwtSecret()
|
||||
|
||||
func getJwtSecret() []byte {
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Fatal("No JWT_SECRET found")
|
||||
}
|
||||
|
||||
jwtBytes, err := ParseHexString(jwtSecret)
|
||||
if err != nil {
|
||||
log.Fatalf("error parsing jwt secret %v", err)
|
||||
}
|
||||
return jwtBytes
|
||||
}
|
||||
|
||||
func NewAuthenticationTokens(username string, customerId uuid.UUID) (AuthenticationTokens, error) {
|
||||
accessClaims := NewAccessClaim(username, customerId)
|
||||
|
||||
refreshClaims := jwt.RegisteredClaims{
|
||||
Subject: username,
|
||||
Issuer: customerId.String(),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExp)),
|
||||
}
|
||||
|
||||
accessJwt, err := EncodeAndSignClaims(accessClaims)
|
||||
if err != nil {
|
||||
return AuthenticationTokens{}, err
|
||||
}
|
||||
refreshJwt, err := EncodeAndSignClaims(refreshClaims)
|
||||
|
||||
if err != nil {
|
||||
return AuthenticationTokens{}, err
|
||||
}
|
||||
return AuthenticationTokens{
|
||||
AccessToken: accessJwt,
|
||||
RefreshToken: refreshJwt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewAccessClaim(username string, customerId uuid.UUID) jwt.RegisteredClaims {
|
||||
return jwt.RegisteredClaims{
|
||||
Subject: username,
|
||||
Issuer: customerId.String(),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokenExp)),
|
||||
}
|
||||
}
|
||||
|
||||
func EncodeAndSignClaims(claims jwt.Claims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
func ParseRegisteredClaimToken(token string) (*jwt.RegisteredClaims, error) {
|
||||
return parseJwt[*jwt.RegisteredClaims](token, &jwt.RegisteredClaims{})
|
||||
}
|
||||
|
||||
func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) {
|
||||
return parseJwt[*ResetNKodeClaims](resetNKodeToken, &ResetNKodeClaims{})
|
||||
}
|
||||
|
||||
func parseJwt[T *ResetNKodeClaims | *jwt.RegisteredClaims](tokenStr string, claim jwt.Claims) (T, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, claim, func(token *jwt.Token) (interface{}, error) {
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("error parsing refresh token: %v", err)
|
||||
return nil, config.ErrInvalidJwt
|
||||
}
|
||||
claims, ok := token.Claims.(T)
|
||||
if !ok {
|
||||
return nil, config.ErrInvalidJwt
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func ClaimExpired(claims jwt.RegisteredClaims) error {
|
||||
if claims.ExpiresAt == nil {
|
||||
return config.ErrClaimExpOrNil
|
||||
}
|
||||
if claims.ExpiresAt.Time.After(time.Now()) {
|
||||
return nil
|
||||
}
|
||||
return config.ErrClaimExpOrNil
|
||||
}
|
||||
|
||||
func ResetNKodeToken(userEmail string, customerId uuid.UUID) (string, error) {
|
||||
resetClaims := ResetNKodeClaims{
|
||||
true,
|
||||
jwt.RegisteredClaims{
|
||||
Subject: userEmail,
|
||||
Issuer: customerId.String(),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(resetNKodeTokenExp)),
|
||||
},
|
||||
}
|
||||
return EncodeAndSignClaims(resetClaims)
|
||||
}
|
||||
28
pkg/nkode-core/security/jwt_claims_test.go
Normal file
28
pkg/nkode-core/security/jwt_claims_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJwtClaims(t *testing.T) {
|
||||
email := "testing@example.com"
|
||||
customerId := uuid.New()
|
||||
authTokens, err := NewAuthenticationTokens(email, customerId)
|
||||
assert.NoError(t, err)
|
||||
accessToken, err := ParseRegisteredClaimToken(authTokens.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken.Subject, email)
|
||||
assert.NoError(t, ClaimExpired(*accessToken))
|
||||
refreshToken, err := ParseRegisteredClaimToken(authTokens.RefreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, refreshToken.Subject, email)
|
||||
assert.NoError(t, ClaimExpired(*refreshToken))
|
||||
resetNKode, err := ResetNKodeToken(email, customerId)
|
||||
assert.NoError(t, err)
|
||||
resetToken, err := ParseRestNKodeToken(resetNKode)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, resetToken.Reset)
|
||||
assert.Equal(t, resetToken.Subject, email)
|
||||
}
|
||||
289
pkg/nkode-core/security/util.go
Normal file
289
pkg/nkode-core/security/util.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"github.com/DonovanKelly/sugar-n-spice/set"
|
||||
"log"
|
||||
"math/big"
|
||||
r "math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFisherYatesShuffle = errors.New("unable to shuffle array")
|
||||
ErrRandomBytes = errors.New("random bytes error")
|
||||
ErrRandNonRepeatingUint64 = errors.New("list length must be less than 2^32")
|
||||
ErrParseHexString = errors.New("parse hex string error")
|
||||
ErrMatrixTranspose = errors.New("matrix cannot be nil or empty")
|
||||
ErrListToMatrixNotDivisible = errors.New("list to matrix not possible")
|
||||
ErrElementNotInArray = errors.New("element not in array")
|
||||
ErrDecodeBase64Str = errors.New("decode base64 err")
|
||||
ErrRandNonRepeatingInt = errors.New("list length must be less than 2^31")
|
||||
ErrXorLengthMismatch = errors.New("xor length mismatch")
|
||||
)
|
||||
|
||||
func fisherYatesShuffle[T any](b *[]T) error {
|
||||
for i := len(*b) - 1; i > 0; i-- {
|
||||
bigJ, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
||||
if err != nil {
|
||||
log.Print("fisher yates shuffle error: ", err)
|
||||
return ErrFisherYatesShuffle
|
||||
}
|
||||
j := bigJ.Int64()
|
||||
(*b)[i], (*b)[j] = (*b)[j], (*b)[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FisherYatesShuffle[T any](b *[]T) error {
|
||||
return fisherYatesShuffle(b)
|
||||
}
|
||||
|
||||
func RandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
log.Print("error in random bytes: ", err)
|
||||
return nil, ErrRandomBytes
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func RandomPermutation(n int) ([]int, error) {
|
||||
perm := IdentityArray(n)
|
||||
err := fisherYatesShuffle(&perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return perm, nil
|
||||
}
|
||||
|
||||
func GenerateRandomUInt64() (uint64, error) {
|
||||
randBytes, err := RandomBytes(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
val := binary.LittleEndian.Uint64(randBytes)
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func GenerateRandomInt() (int, error) {
|
||||
randBytes, err := RandomBytes(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
val := int(binary.LittleEndian.Uint64(randBytes) & 0x7FFFFFFFFFFFFFFF) // Ensure it's positive
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) {
|
||||
if listLen > int(1)<<32 {
|
||||
return nil, ErrRandNonRepeatingUint64
|
||||
}
|
||||
listSet := make(set.Set[uint64])
|
||||
for {
|
||||
if listSet.Size() == listLen {
|
||||
break
|
||||
}
|
||||
randNum, err := GenerateRandomUInt64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listSet.Add(randNum)
|
||||
}
|
||||
|
||||
data := listSet.ToSlice()
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) {
|
||||
if listLen > int(1)<<31 {
|
||||
return nil, ErrRandNonRepeatingInt
|
||||
}
|
||||
listSet := make(set.Set[int])
|
||||
for {
|
||||
if listSet.Size() == listLen {
|
||||
break
|
||||
}
|
||||
randNum, err := GenerateRandomInt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listSet.Add(randNum)
|
||||
}
|
||||
|
||||
data := listSet.ToSlice()
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) {
|
||||
if len(l0) != len(l1) {
|
||||
log.Printf("list len mismatch %d, %d", len(l0), len(l1))
|
||||
return nil, ErrXorLengthMismatch
|
||||
}
|
||||
|
||||
xorList := make([]uint64, len(l0))
|
||||
for idx := 0; idx < len(l0); idx++ {
|
||||
xorList[idx] = l0[idx] ^ l1[idx]
|
||||
}
|
||||
return xorList, nil
|
||||
}
|
||||
|
||||
func EncodeBase64Str(data []uint64) string {
|
||||
dataBytes := Uint64ArrToByteArr(data)
|
||||
encoded := base64.StdEncoding.EncodeToString(dataBytes)
|
||||
return encoded
|
||||
}
|
||||
|
||||
func DecodeBase64Str(encoded string) ([]uint64, error) {
|
||||
dataBytes, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
log.Print("error decoding base64 str: ", err)
|
||||
return nil, ErrDecodeBase64Str
|
||||
}
|
||||
data := ByteArrToUint64Arr(dataBytes)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func Uint64ArrToByteArr(intArr []uint64) []byte {
|
||||
byteArr := make([]byte, len(intArr)*8)
|
||||
for idx, val := range intArr {
|
||||
startIdx := idx * 8
|
||||
endIdx := (idx + 1) * 8
|
||||
binary.LittleEndian.PutUint64(byteArr[startIdx:endIdx], val)
|
||||
}
|
||||
return byteArr
|
||||
}
|
||||
|
||||
func IntArrToByteArr(intArr []int) []byte {
|
||||
byteArr := make([]byte, len(intArr)*4)
|
||||
for idx, val := range intArr {
|
||||
uval := uint32(val)
|
||||
startIdx := idx * 4
|
||||
endIdx := (idx + 1) * 4
|
||||
binary.LittleEndian.PutUint32(byteArr[startIdx:endIdx], uval)
|
||||
}
|
||||
return byteArr
|
||||
}
|
||||
|
||||
func ByteArrToUint64Arr(byteArr []byte) []uint64 {
|
||||
intArr := make([]uint64, len(byteArr)/8)
|
||||
for idx := 0; idx < len(intArr); idx++ {
|
||||
startIdx := idx * 8
|
||||
endIdx := (idx + 1) * 8
|
||||
intArr[idx] = binary.LittleEndian.Uint64(byteArr[startIdx:endIdx])
|
||||
}
|
||||
return intArr
|
||||
}
|
||||
|
||||
func ByteArrToIntArr(byteArr []byte) []int {
|
||||
intArr := make([]int, len(byteArr)/4)
|
||||
for idx := 0; idx < len(intArr); idx++ {
|
||||
startIdx := idx * 4
|
||||
endIdx := (idx + 1) * 4
|
||||
uval := binary.LittleEndian.Uint32(byteArr[startIdx:endIdx])
|
||||
intArr[idx] = int(uval)
|
||||
}
|
||||
return intArr
|
||||
}
|
||||
|
||||
func IndexOf[T uint64 | int](arr []T, el T) (int, error) {
|
||||
for idx, val := range arr {
|
||||
if val == el {
|
||||
return idx, nil
|
||||
}
|
||||
}
|
||||
return -1, ErrElementNotInArray
|
||||
}
|
||||
|
||||
func IdentityArray(arrLen int) []int {
|
||||
identityArr := make([]int, arrLen)
|
||||
|
||||
for idx := range identityArr {
|
||||
identityArr[idx] = idx
|
||||
}
|
||||
return identityArr
|
||||
}
|
||||
|
||||
func ListToMatrix(listArr []int, numbCols int) ([][]int, error) {
|
||||
if len(listArr)%numbCols != 0 {
|
||||
log.Printf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols)
|
||||
return nil, ErrListToMatrixNotDivisible
|
||||
}
|
||||
numbRows := len(listArr) / numbCols
|
||||
matrix := make([][]int, numbRows)
|
||||
for idx := range matrix {
|
||||
startIdx := idx * numbCols
|
||||
endIdx := (idx + 1) * numbCols
|
||||
matrix[idx] = listArr[startIdx:endIdx]
|
||||
}
|
||||
return matrix, nil
|
||||
}
|
||||
|
||||
func MatrixTranspose(matrix [][]int) ([][]int, error) {
|
||||
if matrix == nil || len(matrix) == 0 {
|
||||
log.Print("can't transpose nil or zero len matrix")
|
||||
return nil, ErrMatrixTranspose
|
||||
}
|
||||
|
||||
rows := len(matrix)
|
||||
cols := len((matrix)[0])
|
||||
|
||||
// Check if the matrix is not rectangular
|
||||
for _, row := range matrix {
|
||||
if len(row) != cols {
|
||||
log.Print("all rows must have the same number of columns")
|
||||
return nil, ErrMatrixTranspose
|
||||
}
|
||||
}
|
||||
|
||||
transposed := make([][]int, cols)
|
||||
for i := range transposed {
|
||||
transposed[i] = make([]int, rows)
|
||||
}
|
||||
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < cols; j++ {
|
||||
transposed[j][i] = (matrix)[i][j]
|
||||
}
|
||||
}
|
||||
|
||||
return transposed, nil
|
||||
}
|
||||
|
||||
func MatrixToList(matrix [][]int) []int {
|
||||
var flat []int
|
||||
for _, row := range matrix {
|
||||
flat = append(flat, row...)
|
||||
}
|
||||
return flat
|
||||
}
|
||||
|
||||
func Choice[T any](items []T) T {
|
||||
r.Seed(time.Now().UnixNano()) // Seed the random number generator
|
||||
return items[r.Intn(len(items))]
|
||||
}
|
||||
|
||||
// GenerateRandomString creates a random string of a specified length.
|
||||
func GenerateRandomString(length int) string {
|
||||
charset := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
|
||||
b := make([]rune, length)
|
||||
for i := range b {
|
||||
b[i] = Choice[rune](charset)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func ParseHexString(hexStr string) ([]byte, error) {
|
||||
// Decode the hex string into bytes
|
||||
bytes, err := hex.DecodeString(hexStr)
|
||||
if err != nil {
|
||||
log.Print("parse hex string err: ", err)
|
||||
return nil, ErrParseHexString
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
62
pkg/nkode-core/security/util_test.go
Normal file
62
pkg/nkode-core/security/util_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateRandomNonRepeatingUint64(t *testing.T) {
|
||||
arrLen := 100000
|
||||
randNumbs, err := GenerateRandomNonRepeatingUint64(arrLen)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(randNumbs), arrLen)
|
||||
}
|
||||
|
||||
func TestGenerateRandomNonRepeatingInt(t *testing.T) {
|
||||
arrLen := 100000
|
||||
randNumbs, err := GenerateRandomNonRepeatingInt(arrLen)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(randNumbs), arrLen)
|
||||
}
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
testArr := []uint64{1, 2, 3, 4, 5, 6}
|
||||
testEncode := EncodeBase64Str(testArr)
|
||||
testDecode, err := DecodeBase64Str(testEncode)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(testArr), len(testDecode))
|
||||
for idx, val := range testArr {
|
||||
assert.Equal(t, val, testDecode[idx])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatrixTranspose(t *testing.T) {
|
||||
matrix := [][]int{
|
||||
{0, 1, 2},
|
||||
{3, 4, 5},
|
||||
}
|
||||
expectedMatrixT := [][]int{
|
||||
{0, 3},
|
||||
{1, 4},
|
||||
{2, 5},
|
||||
}
|
||||
expectedFlatMat := MatrixToList(expectedMatrixT)
|
||||
matrixT, err := MatrixTranspose(matrix)
|
||||
assert.NoError(t, err)
|
||||
flatMat := MatrixToList(matrixT)
|
||||
|
||||
assert.Equal(t, len(flatMat), len(expectedFlatMat))
|
||||
for idx := range flatMat {
|
||||
assert.Equal(t, expectedFlatMat[idx], flatMat[idx])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntToByteAndBack(t *testing.T) {
|
||||
origIntArr := []int{1, 2, 3, 4, 5}
|
||||
byteArr := IntArrToByteArr(origIntArr)
|
||||
intArr := ByteArrToIntArr(byteArr)
|
||||
|
||||
assert.ElementsMatch(t, origIntArr, intArr)
|
||||
}
|
||||
Reference in New Issue
Block a user