refactor errors
This commit is contained in:
62
core/constants.go
Normal file
62
core/constants.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidNKodeLength = errors.New("invalid nKode length")
|
||||
ErrInvalidNKodeIdx = errors.New("invalid passcode attribute index")
|
||||
ErrTooFewDistinctSet = errors.New("too few distinct sets")
|
||||
ErrTooFewDistinctAttributes = errors.New("too few distinct attributes")
|
||||
ErrEmailAlreadySent = errors.New("email already sent")
|
||||
ErrClaimExpOrNil = errors.New("claim expired or nil")
|
||||
ErrInvalidJwt = errors.New("invalid jwt")
|
||||
ErrInvalidKeypadDimensions = errors.New("keypad dimensions out of range")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrSignupSessionDNE = errors.New("signup session does not exist")
|
||||
ErrUserForCustomerDNE = errors.New("user for customer does not exist")
|
||||
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
|
||||
ErrCustomerDne = errors.New("customer dne")
|
||||
ErrSvgDne = errors.New("svg dne")
|
||||
ErrStoppingDatabase = errors.New("stopping database")
|
||||
ErrSqliteTx = errors.New("sqlite begin, exec, query, or commit error. see logs")
|
||||
ErrEmptySvgTable = errors.New("empty svg_icon table")
|
||||
ErrKeyIndexOutOfRange = errors.New("one or more keys is out of range")
|
||||
ErrAttributeIndexOutOfRange = errors.New("attribute index out of range")
|
||||
ErrInternalValidKeyEntry = errors.New("internal validation error")
|
||||
ErrUserMaskTooLong = errors.New("user mask length exceeds max nkode length")
|
||||
ErrInterfaceNotDispersible = errors.New("interface is not dispersible")
|
||||
ErrIncompleteUserSignupSession = errors.New("incomplete user signup session")
|
||||
ErrSetConfirmSignupMismatch = errors.New("set and confirm nkode are not the same")
|
||||
ErrKeypadIsNotDispersible = errors.New("keypad is not dispersible")
|
||||
)
|
||||
|
||||
var HttpErrMap = map[error]int{
|
||||
ErrInvalidNKodeLength: http.StatusBadRequest,
|
||||
ErrInvalidNKodeIdx: http.StatusBadRequest,
|
||||
ErrTooFewDistinctSet: http.StatusBadRequest,
|
||||
ErrTooFewDistinctAttributes: http.StatusBadRequest,
|
||||
ErrEmailAlreadySent: http.StatusBadRequest,
|
||||
ErrClaimExpOrNil: http.StatusForbidden,
|
||||
ErrInvalidJwt: http.StatusForbidden,
|
||||
ErrInvalidKeypadDimensions: http.StatusBadRequest,
|
||||
ErrUserAlreadyExists: http.StatusBadRequest,
|
||||
ErrSignupSessionDNE: http.StatusBadRequest,
|
||||
ErrUserForCustomerDNE: http.StatusBadRequest,
|
||||
ErrRefreshTokenInvalid: http.StatusForbidden,
|
||||
ErrCustomerDne: http.StatusBadRequest,
|
||||
ErrSvgDne: http.StatusBadRequest,
|
||||
ErrStoppingDatabase: http.StatusInternalServerError,
|
||||
ErrSqliteTx: http.StatusInternalServerError,
|
||||
ErrEmptySvgTable: http.StatusInternalServerError,
|
||||
ErrKeyIndexOutOfRange: http.StatusBadRequest,
|
||||
ErrAttributeIndexOutOfRange: http.StatusInternalServerError,
|
||||
ErrInternalValidKeyEntry: http.StatusInternalServerError,
|
||||
ErrUserMaskTooLong: http.StatusInternalServerError,
|
||||
ErrInterfaceNotDispersible: http.StatusInternalServerError,
|
||||
ErrIncompleteUserSignupSession: http.StatusBadRequest,
|
||||
ErrSetConfirmSignupMismatch: http.StatusBadRequest,
|
||||
ErrKeypadIsNotDispersible: http.StatusInternalServerError,
|
||||
}
|
||||
@@ -1,11 +1,8 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"go-nkode/hashset"
|
||||
py "go-nkode/py-builtin"
|
||||
"go-nkode/util"
|
||||
)
|
||||
|
||||
@@ -31,16 +28,12 @@ func NewCustomer(nkodePolicy NKodePolicy) (*Customer, error) {
|
||||
|
||||
func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error {
|
||||
nkodeLen := len(passcodeAttrIdx)
|
||||
if nkodeLen < c.NKodePolicy.MinNkodeLen {
|
||||
return errors.New(fmt.Sprintf("NKode length %d is too short. Minimum nKode length is %d", nkodeLen, c.NKodePolicy.MinNkodeLen))
|
||||
if nkodeLen < c.NKodePolicy.MinNkodeLen || nkodeLen > c.NKodePolicy.MaxNkodeLen {
|
||||
return ErrInvalidNKodeLength
|
||||
}
|
||||
|
||||
validIdx := py.All[int](passcodeAttrIdx, func(i int) bool {
|
||||
return i >= 0 && i < kp.TotalAttrs()
|
||||
})
|
||||
|
||||
if !validIdx {
|
||||
return errors.New(fmt.Sprintf("One or more idx out of range 0-%d in IsValidNKode", kp.TotalAttrs()-1))
|
||||
if validIdx := kp.ValidateAttributeIndices(passcodeAttrIdx); !validIdx {
|
||||
return ErrInvalidNKodeIdx
|
||||
}
|
||||
passcodeSetVals := make(hashset.Set[uint64])
|
||||
passcodeAttrVals := make(hashset.Set[uint64])
|
||||
@@ -59,33 +52,32 @@ func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error
|
||||
}
|
||||
|
||||
if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets {
|
||||
return errors.New(fmt.Sprintf("passcode has two few distinct sets min %d, has %d", c.NKodePolicy.DistinctSets, passcodeSetVals.Size()))
|
||||
return ErrTooFewDistinctSet
|
||||
}
|
||||
|
||||
if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes {
|
||||
return errors.New(fmt.Sprintf("passcode has two few distinct attributes min %d, has %d", c.NKodePolicy.DistinctAttributes, passcodeAttrVals.Size()))
|
||||
return ErrTooFewDistinctAttributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Customer) RenewKeys() ([]uint64, []uint64) {
|
||||
func (c *Customer) RenewKeys() ([]uint64, []uint64, error) {
|
||||
oldAttrs := make([]uint64, len(c.Attributes.AttrVals))
|
||||
oldSets := make([]uint64, len(c.Attributes.SetVals))
|
||||
|
||||
copy(oldAttrs, c.Attributes.AttrVals)
|
||||
copy(oldSets, c.Attributes.SetVals)
|
||||
|
||||
err := c.Attributes.Renew()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
if err := c.Attributes.Renew(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
attrsXor, err := util.XorLists(oldAttrs, c.Attributes.AttrVals)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
setXor, err := util.XorLists(oldSets, c.Attributes.SetVals)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
return setXor, attrsXor
|
||||
return setXor, attrsXor, nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
)
|
||||
|
||||
type CustomerAttributes struct {
|
||||
@@ -12,13 +11,15 @@ type CustomerAttributes struct {
|
||||
}
|
||||
|
||||
func NewCustomerAttributes() (*CustomerAttributes, error) {
|
||||
attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
|
||||
if errAttr != nil {
|
||||
return nil, errAttr
|
||||
attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
|
||||
if err != nil {
|
||||
log.Print("unable to generate attribute vals: ", err)
|
||||
return nil, err
|
||||
}
|
||||
setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
|
||||
if errSet != nil {
|
||||
return nil, errSet
|
||||
setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
|
||||
if err != nil {
|
||||
log.Print("unable to generate set vals: ", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
customerAttrs := CustomerAttributes{
|
||||
@@ -36,37 +37,33 @@ func NewCustomerAttributesFromBytes(attrBytes []byte, setBytes []byte) CustomerA
|
||||
}
|
||||
|
||||
func (c *CustomerAttributes) Renew() error {
|
||||
attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
|
||||
if errAttr != nil {
|
||||
return errAttr
|
||||
attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
|
||||
if errSet != nil {
|
||||
return errSet
|
||||
setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.AttrVals = attrVals
|
||||
c.SetVals = setVals
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) int {
|
||||
func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) (int, error) {
|
||||
// TODO: should this be mapped instead?
|
||||
return util.IndexOf[uint64](c.AttrVals, attrVal)
|
||||
}
|
||||
|
||||
func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) {
|
||||
// TODO: should this be mapped instead?
|
||||
idx := util.IndexOf[uint64](c.SetVals, setVal)
|
||||
if idx == -1 {
|
||||
return -1, errors.New(fmt.Sprintf("Set Val %d is invalid", setVal))
|
||||
}
|
||||
return idx, nil
|
||||
return util.IndexOf[uint64](c.SetVals, setVal)
|
||||
}
|
||||
|
||||
func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad KeypadDimension) (uint64, error) {
|
||||
indexOfAttr := c.IndexOfAttr(attrVal)
|
||||
if indexOfAttr == -1 {
|
||||
return 0, errors.New(fmt.Sprintf("No attribute %d", attrVal))
|
||||
indexOfAttr, err := c.IndexOfAttr(attrVal)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
setIdx := indexOfAttr % userKeypad.AttrsPerKey
|
||||
return c.SetVals[setIdx], nil
|
||||
|
||||
@@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
@@ -51,9 +50,9 @@ func NewSESClient() SESClient {
|
||||
}
|
||||
|
||||
func (s *SESClient) SendEmail(email Email) error {
|
||||
|
||||
if _, exists := s.ResetCache.Get(email.Recipient); exists {
|
||||
return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject)
|
||||
log.Printf("email already sent to %s with subject %s", email.Recipient, email.Subject)
|
||||
return ErrEmailAlreadySent
|
||||
}
|
||||
|
||||
// Load AWS configuration
|
||||
@@ -61,7 +60,7 @@ func (s *SESClient) SendEmail(email Email) error {
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("unable to load SDK config, %v", err)
|
||||
log.Print(errMsg)
|
||||
return errors.New(errMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an SES client
|
||||
@@ -93,7 +92,9 @@ func (s *SESClient) SendEmail(email Email) error {
|
||||
resp, err := sesClient.SendEmail(context.TODO(), input)
|
||||
if err != nil {
|
||||
s.ResetCache.Delete(email.Recipient)
|
||||
return fmt.Errorf("failed to send email, %v", err)
|
||||
errMsg := fmt.Sprintf("failed to send email, %v", err)
|
||||
log.Print(errMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
// Output the message ID of the sent email
|
||||
|
||||
@@ -89,9 +89,11 @@ func (db *InMemoryDb) Renew(id CustomerId) error {
|
||||
if !exists {
|
||||
return errors.New(fmt.Sprintf("customer %s does not exist", id))
|
||||
}
|
||||
setXor, attrsXor := customer.RenewKeys()
|
||||
setXor, attrsXor, err := customer.RenewKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.Customers[id] = customer
|
||||
var err error
|
||||
for _, user := range db.Users {
|
||||
if user.CustomerId == id {
|
||||
err = user.RenewKeys(setXor, attrsXor)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
@@ -78,42 +76,37 @@ func EncodeAndSignClaims(claims jwt.Claims) (string, error) {
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing refresh token: %w", err)
|
||||
}
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to parse claims")
|
||||
}
|
||||
return claims, nil
|
||||
func ParseRegisteredClaimToken(token string) (*jwt.RegisteredClaims, error) {
|
||||
return parseJwt[*jwt.RegisteredClaims](token, &jwt.RegisteredClaims{})
|
||||
}
|
||||
|
||||
func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) {
|
||||
return parseJwt[*ResetNKodeClaims](resetNKodeToken, &ResetNKodeClaims{})
|
||||
}
|
||||
|
||||
func parseJwt[T *ResetNKodeClaims | *jwt.RegisteredClaims](tokenStr string, claim jwt.Claims) (T, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, claim, func(token *jwt.Token) (interface{}, error) {
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing refresh token: %w", err)
|
||||
log.Printf("error parsing refresh token: %v", err)
|
||||
return nil, ErrInvalidJwt
|
||||
}
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
claims, ok := token.Claims.(T)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to parse claims")
|
||||
return nil, ErrInvalidJwt
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func ClaimExpired(claims jwt.RegisteredClaims) error {
|
||||
if claims.ExpiresAt == nil {
|
||||
return errors.New("claim exp is nil")
|
||||
return ErrClaimExpOrNil
|
||||
}
|
||||
if claims.ExpiresAt.Time.After(time.Now()) {
|
||||
return nil
|
||||
}
|
||||
return errors.New("claim expired")
|
||||
return ErrClaimExpOrNil
|
||||
}
|
||||
|
||||
func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error) {
|
||||
@@ -127,17 +120,3 @@ func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error)
|
||||
}
|
||||
return EncodeAndSignClaims(resetClaims)
|
||||
}
|
||||
|
||||
func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(resetNKodeToken, &ResetNKodeClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing refresh token: %w", err)
|
||||
}
|
||||
claims, ok := token.Claims.(*ResetNKodeClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to parse claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ func TestJwtClaims(t *testing.T) {
|
||||
customerId := CustomerId(uuid.New())
|
||||
authTokens, err := NewAuthenticationTokens(email, customerId)
|
||||
assert.NoError(t, err)
|
||||
accessToken, err := ParseAccessToken(authTokens.AccessToken)
|
||||
accessToken, err := ParseRegisteredClaimToken(authTokens.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken.Subject, email)
|
||||
assert.NoError(t, ClaimExpired(*accessToken))
|
||||
refreshToken, err := ParseRefreshToken(authTokens.RefreshToken)
|
||||
refreshToken, err := ParseRegisteredClaimToken(authTokens.RefreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, refreshToken.Subject, email)
|
||||
assert.NoError(t, ClaimExpired(*refreshToken))
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package core
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
py "go-nkode/py-builtin"
|
||||
)
|
||||
|
||||
type KeypadDimension struct {
|
||||
AttrsPerKey int `json:"attrs_per_key"`
|
||||
@@ -17,11 +19,23 @@ func (kp *KeypadDimension) IsDispersable() bool {
|
||||
|
||||
func (kp *KeypadDimension) IsValidKeypadDimension() error {
|
||||
if KeypadMin.AttrsPerKey > kp.AttrsPerKey || KeypadMax.AttrsPerKey < kp.AttrsPerKey || KeypadMin.NumbOfKeys > kp.NumbOfKeys || KeypadMax.NumbOfKeys < kp.NumbOfKeys {
|
||||
return errors.New("keypad dimensions out of range")
|
||||
return ErrInvalidKeypadDimensions
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (kp *KeypadDimension) ValidKeySelections(selectedKeys []int) bool {
|
||||
return py.All[int](selectedKeys, func(idx int) bool {
|
||||
return 0 <= idx && idx < kp.NumbOfKeys
|
||||
})
|
||||
}
|
||||
|
||||
func (kp *KeypadDimension) ValidateAttributeIndices(attrIndicies []int) bool {
|
||||
return py.All[int](attrIndicies, func(i int) bool {
|
||||
return i >= 0 && i < kp.TotalAttrs()
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
KeypadMax = KeypadDimension{
|
||||
AttrsPerKey: 16,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,8 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId
|
||||
return nil, err
|
||||
}
|
||||
if user != nil && !reset {
|
||||
return nil, fmt.Errorf("user %s already exists", string(userEmail))
|
||||
log.Printf("user %s already exists", string(userEmail))
|
||||
return nil, ErrUserAlreadyExists
|
||||
}
|
||||
svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp)
|
||||
if err != nil {
|
||||
@@ -74,7 +75,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec
|
||||
}
|
||||
session, exists := n.SignupSessions[sessionId]
|
||||
if !exists {
|
||||
return nil, errors.New(fmt.Sprintf("session id does not exist %s", sessionId))
|
||||
log.Printf("session id does not exist %s", sessionId)
|
||||
return nil, ErrSignupSessionDNE
|
||||
}
|
||||
confirmInterface, err := session.SetUserNKode(keySelection)
|
||||
if err != nil {
|
||||
@@ -87,7 +89,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec
|
||||
func (n *NKodeAPI) ConfirmNKode(customerId CustomerId, sessionId SessionId, keySelection KeySelection) error {
|
||||
session, exists := n.SignupSessions[sessionId]
|
||||
if !exists {
|
||||
return errors.New(fmt.Sprintf("session id does not exist %s", sessionId))
|
||||
log.Printf("session id does not exist %s", sessionId)
|
||||
return ErrSignupSessionDNE
|
||||
}
|
||||
customer, err := n.Db.GetCustomer(customerId)
|
||||
if err != nil {
|
||||
@@ -120,7 +123,8 @@ func (n *NKodeAPI) GetLoginInterface(userEmail UserEmail, customerId CustomerId)
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId))
|
||||
log.Printf("user %s for customer %s dne", userEmail, customerId)
|
||||
return nil, ErrUserForCustomerDNE
|
||||
}
|
||||
err = user.Interface.PartialInterfaceShuffle()
|
||||
if err != nil {
|
||||
@@ -153,7 +157,8 @@ func (n *NKodeAPI) Login(customerId CustomerId, userEmail UserEmail, keySelectio
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId))
|
||||
log.Printf("user %s for customer %s dne", userEmail, customerId)
|
||||
return nil, ErrUserForCustomerDNE
|
||||
}
|
||||
passcode, err := ValidKeyEntry(*user, *customer, keySelection)
|
||||
if err != nil {
|
||||
@@ -195,12 +200,13 @@ func (n *NKodeAPI) RefreshToken(userEmail UserEmail, customerId CustomerId, refr
|
||||
return "", err
|
||||
}
|
||||
if user == nil {
|
||||
return "", errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId))
|
||||
log.Printf("user %s for customer %s dne", userEmail, customerId)
|
||||
return "", ErrUserForCustomerDNE
|
||||
}
|
||||
if user.RefreshToken != refreshToken {
|
||||
return "", errors.New("refresh token is invalid")
|
||||
return "", ErrRefreshTokenInvalid
|
||||
}
|
||||
refreshClaims, err := ParseRefreshToken(refreshToken)
|
||||
refreshClaims, err := ParseRegisteredClaimToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -222,7 +228,7 @@ func (n *NKodeAPI) ResetNKode(userEmail UserEmail, customerId CustomerId) error
|
||||
|
||||
nkodeResetJwt, err := ResetNKodeToken(userEmail, customerId)
|
||||
if err != nil {
|
||||
return errors.New(fmt.Sprintf("unable to load SDK config, %v", err))
|
||||
return err
|
||||
}
|
||||
frontendHost := os.Getenv("FRONTEND_HOST")
|
||||
if frontendHost == "" {
|
||||
|
||||
@@ -26,6 +26,12 @@ const (
|
||||
ResetNKode = "/reset-nkode"
|
||||
)
|
||||
|
||||
const (
|
||||
malformedCustomerId = "malformed customer id"
|
||||
malformedUserEmail = "malformed user email"
|
||||
malformedSessionId = "malformed session id"
|
||||
)
|
||||
|
||||
func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case CreateNewCustomer:
|
||||
@@ -56,226 +62,148 @@ func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print("create new customer")
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var customerPost NewCustomerPost
|
||||
err := decodeJson(w, r, &customerPost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &customerPost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := h.Api.CreateNewCustomer(customerPost.NKodePolicy, nil)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
respBody := CreateNewCustomerResp{
|
||||
CustomerId: uuid.UUID(*customerId).String(),
|
||||
}
|
||||
respBytes, err := json.Marshal(respBody)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, respBody)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print("signup/reset interface")
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
|
||||
var signupResetPost GenerateSignupRestInterfacePost
|
||||
err := decodeJson(w, r, &signupResetPost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &signupResetPost); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
kp := KeypadDimension{
|
||||
AttrsPerKey: signupResetPost.AttrsPerKey,
|
||||
NumbOfKeys: signupResetPost.NumbOfKeys,
|
||||
}
|
||||
err = kp.IsValidKeypadDimension()
|
||||
if err != nil {
|
||||
keypadSizeOutOfRange(w)
|
||||
log.Println(err)
|
||||
if err := kp.IsValidKeypadDimension(); err != nil {
|
||||
badRequest(w, "invalid keypad dimensions")
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(signupResetPost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
userEmail, err := ParseEmail(signupResetPost.UserEmail)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedUserEmail)
|
||||
return
|
||||
}
|
||||
resp, err := h.Api.GenerateSignupResetInterface(userEmail, CustomerId(customerId), kp, signupResetPost.Reset)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
respBytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, resp)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print("set nkode")
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var setNKodePost SetNKodePost
|
||||
err := decodeJson(w, r, &setNKodePost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &setNKodePost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(setNKodePost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
sessionId, err := uuid.Parse(setNKodePost.SessionId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedSessionId)
|
||||
return
|
||||
}
|
||||
confirmInterface, err := h.Api.SetNKode(CustomerId(customerId), SessionId(sessionId), setNKodePost.KeySelection)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
respBody := SetNKodeResp{UserInterface: confirmInterface}
|
||||
|
||||
respBytes, err := json.Marshal(respBody)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, respBody)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Print("confirm nkode")
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
|
||||
var confirmNKodePost ConfirmNKodePost
|
||||
err := decodeJson(w, r, &confirmNKodePost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &confirmNKodePost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(confirmNKodePost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
sessionId, err := uuid.Parse(confirmNKodePost.SessionId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedSessionId)
|
||||
return
|
||||
}
|
||||
err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection); err != nil {
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var loginInterfacePost GetLoginInterfacePost
|
||||
err := decodeJson(w, r, &loginInterfacePost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &loginInterfacePost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(loginInterfacePost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
userEmail, err := ParseEmail(loginInterfacePost.UserEmail)
|
||||
if err != nil {
|
||||
badRequest(w, malformedUserEmail)
|
||||
}
|
||||
loginInterface, err := h.Api.GetLoginInterface(userEmail, CustomerId(customerId))
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(loginInterface)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
marshalAndWriteBytes(w, loginInterface)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -284,40 +212,26 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
var loginPost LoginPost
|
||||
err := decodeJson(w, r, &loginPost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &loginPost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(loginPost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
userEmail, err := ParseEmail(loginPost.UserEmail)
|
||||
if err != nil {
|
||||
badRequest(w, malformedUserEmail)
|
||||
return
|
||||
}
|
||||
jwtTokens, err := h.Api.Login(CustomerId(customerId), userEmail, loginPost.KeySelection)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(jwtTokens)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, jwtTokens)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -326,23 +240,16 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
var renewAttributesPost RenewAttributesPost
|
||||
err := decodeJson(w, r, &renewAttributesPost)
|
||||
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err := decodeJson(w, r, &renewAttributesPost); err != nil {
|
||||
return
|
||||
}
|
||||
customerId, err := uuid.Parse(renewAttributesPost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
err = h.Api.RenewAttributes(CustomerId(customerId))
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
if err = h.Api.RenewAttributes(CustomerId(customerId)); err != nil {
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -355,26 +262,11 @@ func (h *NKodeHandler) RandomSvgInterfaceHandler(w http.ResponseWriter, r *http.
|
||||
}
|
||||
svgs, err := h.Api.RandomSvgInterface()
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
handleError(w, err)
|
||||
return
|
||||
}
|
||||
respBody := RandomSvgInterfaceResp{Svgs: svgs}
|
||||
respBytes, err := json.Marshal(respBody)
|
||||
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, respBody)
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -383,74 +275,54 @@ func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
refreshToken, err := getBearerToken(r)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
forbidden(w)
|
||||
return
|
||||
}
|
||||
refreshClaims, err := ParseRefreshToken(refreshToken)
|
||||
refreshClaims, err := ParseRegisteredClaimToken(refreshToken)
|
||||
customerId, err := uuid.Parse(refreshClaims.Issuer)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
userEmail, err := ParseEmail(refreshClaims.Subject)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
badRequest(w, malformedUserEmail)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
accessToken, err := h.Api.RefreshToken(userEmail, CustomerId(customerId), refreshToken)
|
||||
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
handleError(w, err)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
respBytes, err := json.Marshal(RefreshTokenResp{AccessToken: accessToken})
|
||||
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
marshalAndWriteBytes(w, RefreshTokenResp{AccessToken: accessToken})
|
||||
}
|
||||
|
||||
func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
methodNotAllowed(w)
|
||||
}
|
||||
log.Print("Resetting email")
|
||||
var resetNKodePost ResetNKodePost
|
||||
err := decodeJson(w, r, &resetNKodePost)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println("error decoding reset nkode post: ", err)
|
||||
if err := decodeJson(w, r, &resetNKodePost); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
customerId, err := uuid.Parse(resetNKodePost.CustomerId)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedCustomerId)
|
||||
return
|
||||
}
|
||||
|
||||
userEmail, err := ParseEmail(resetNKodePost.UserEmail)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
log.Println(err)
|
||||
badRequest(w, malformedUserEmail)
|
||||
return
|
||||
}
|
||||
err = h.Api.ResetNKode(userEmail, CustomerId(customerId))
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
|
||||
if err = h.Api.ResetNKode(userEmail, CustomerId(customerId)); err != nil {
|
||||
internalServerError(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
@@ -459,43 +331,90 @@ func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func decodeJson(w http.ResponseWriter, r *http.Request, post any) error {
|
||||
if r.Body == nil {
|
||||
invalidJson(w)
|
||||
return errors.New("invalid json")
|
||||
badRequest(w, "unable to parse body")
|
||||
log.Println("error decoding json: body is nil")
|
||||
return errors.New("body is nil")
|
||||
}
|
||||
err := json.NewDecoder(r.Body).Decode(&post)
|
||||
if err != nil {
|
||||
internalServerErrorHandler(w)
|
||||
badRequest(w, "unable to parse body")
|
||||
log.Println("error decoding json: ", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func internalServerErrorHandler(w http.ResponseWriter) {
|
||||
func internalServerError(w http.ResponseWriter) {
|
||||
log.Print("500 internal server error")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("500 Internal Server Error"))
|
||||
}
|
||||
|
||||
func badRequest(w http.ResponseWriter, msg string) {
|
||||
log.Print("bad request: ", msg)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
if msg == "" {
|
||||
w.Write([]byte("400 Bad Request"))
|
||||
} else {
|
||||
w.Write([]byte(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func methodNotAllowed(w http.ResponseWriter) {
|
||||
log.Print("405 method not allowed")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte("405 method not allowed"))
|
||||
}
|
||||
|
||||
func keypadSizeOutOfRange(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("invalid keypad dimensions"))
|
||||
func forbidden(w http.ResponseWriter) {
|
||||
log.Print("403 forbidden")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("403 Forbidden"))
|
||||
}
|
||||
|
||||
func invalidJson(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("invalid json"))
|
||||
func handleError(w http.ResponseWriter, err error) {
|
||||
log.Print("handling error: ", err)
|
||||
statusCode, exists := HttpErrMap[err]
|
||||
if !exists {
|
||||
internalServerError(w)
|
||||
return
|
||||
}
|
||||
switch statusCode {
|
||||
case http.StatusBadRequest:
|
||||
badRequest(w, err.Error())
|
||||
case http.StatusForbidden:
|
||||
forbidden(w)
|
||||
case http.StatusInternalServerError:
|
||||
internalServerError(w)
|
||||
default:
|
||||
log.Print("unknown error: ", err)
|
||||
internalServerError(w)
|
||||
}
|
||||
}
|
||||
|
||||
func getBearerToken(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
// Check if the Authorization header is present and starts with "Bearer "
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return "", errors.New("authorization header missing or invalid")
|
||||
return "", errors.New("forbidden")
|
||||
}
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func marshalAndWriteBytes(w http.ResponseWriter, data any) {
|
||||
respBytes, err := json.Marshal(data)
|
||||
|
||||
if err != nil {
|
||||
internalServerError(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
_, err = w.Write(respBytes)
|
||||
|
||||
if err != nil {
|
||||
internalServerError(w)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type NKodePolicy struct {
|
||||
MaxNkodeLen int `json:"max_nkode_len"`
|
||||
MinNkodeLen int `json:"min_nkode_len"`
|
||||
@@ -24,12 +20,10 @@ func NewDefaultNKodePolicy() NKodePolicy {
|
||||
}
|
||||
}
|
||||
|
||||
var InvalidNKodeLen = errors.New("invalid nkode length")
|
||||
|
||||
func (p *NKodePolicy) ValidLength(nkodeLen int) error {
|
||||
|
||||
if nkodeLen < p.MinNkodeLen || nkodeLen > p.MaxNkodeLen {
|
||||
return InvalidNKodeLen
|
||||
return ErrInvalidNKodeLength
|
||||
}
|
||||
// TODO: validate Max > Min
|
||||
// Validate lockout
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
)
|
||||
|
||||
type NKodeSecrets struct {
|
||||
JwtSecret []byte `json:"jwt_secret"`
|
||||
}
|
||||
|
||||
func ReadSecrets(filePath string) (NKodeSecrets, error) {
|
||||
// Initialize an empty NKodeSecrets struct
|
||||
var secrets NKodeSecrets
|
||||
|
||||
// Read the contents of the file
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return secrets, fmt.Errorf("error reading secrets file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal JSON data into the NKodeSecrets struct
|
||||
err = json.Unmarshal(data, &secrets)
|
||||
if err != nil {
|
||||
return secrets, fmt.Errorf("error unmarshaling secrets: %w", err)
|
||||
}
|
||||
|
||||
return secrets, nil
|
||||
}
|
||||
|
||||
func GetJwtSecret(filePath string) []byte {
|
||||
secrets, err := ReadSecrets(filePath)
|
||||
if err != nil {
|
||||
log.Fatal("can't read secrets: ", err)
|
||||
}
|
||||
if secrets.JwtSecret == nil {
|
||||
log.Fatal("wt secret is nil")
|
||||
}
|
||||
return secrets.JwtSecret
|
||||
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
||||
@@ -125,7 +124,10 @@ func (d *SqliteDB) Renew(id CustomerId) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setXor, attrXor := customer.RenewKeys()
|
||||
setXor, attrXor, err := customer.RenewKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
|
||||
// TODO: replace with tx
|
||||
renewQuery := `
|
||||
@@ -208,9 +210,13 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
|
||||
}()
|
||||
selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?`
|
||||
rows, err := tx.Query(selectCustomer, uuid.UUID(id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err()))
|
||||
log.Printf("no new row for customer %s with err %s", id, rows.Err())
|
||||
return nil, ErrCustomerDne
|
||||
}
|
||||
|
||||
var maxNKodeLen int
|
||||
@@ -225,10 +231,6 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
return nil, errors.New(fmt.Sprintf("too many rows for customer %s", id))
|
||||
}
|
||||
customer := Customer{
|
||||
Id: id,
|
||||
NKodePolicy: NKodePolicy{
|
||||
@@ -241,9 +243,8 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
|
||||
},
|
||||
Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues),
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read customer won't commit %w", err)
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &customer, nil
|
||||
}
|
||||
@@ -278,9 +279,6 @@ WHERE user.username = ? AND user.customer_id = ?
|
||||
var svgIdInterface []byte
|
||||
|
||||
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
|
||||
if rows.Next() {
|
||||
return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId))
|
||||
}
|
||||
|
||||
userId, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
@@ -324,8 +322,7 @@ WHERE user.username = ? AND user.customer_id = ?
|
||||
}
|
||||
user.Interface.Kp = &user.Kp
|
||||
user.CipherKeys.Kp = &user.Kp
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
@@ -360,15 +357,14 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
if !rows.Next() {
|
||||
return nil, errors.New(fmt.Sprintf("id not found: %d", id))
|
||||
log.Printf("id not found: %d", id)
|
||||
return nil, ErrSvgDne
|
||||
}
|
||||
err = rows.Scan(&svgs[idx])
|
||||
if err != nil {
|
||||
if err = rows.Scan(&svgs[idx]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return svgs, nil
|
||||
@@ -383,16 +379,14 @@ func (d *SqliteDB) writeToDb(query string, args []any) error {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write won't roll back %+v", err))
|
||||
log.Fatalf("fatal error: write won't roll back %+v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
_, err = tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
if _, err = tx.Exec(query, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -400,7 +394,7 @@ func (d *SqliteDB) writeToDb(query string, args []any) error {
|
||||
|
||||
func (d *SqliteDB) addWriteTx(query string, args []any) error {
|
||||
if d.stop {
|
||||
return errors.New("stopping database")
|
||||
return ErrStoppingDatabase
|
||||
}
|
||||
errChan := make(chan error)
|
||||
writeTx := WriteTx{
|
||||
@@ -416,31 +410,37 @@ func (d *SqliteDB) addWriteTx(query string, args []any) error {
|
||||
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Print(err)
|
||||
return nil, ErrSqliteTx
|
||||
}
|
||||
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Print(err)
|
||||
return nil, ErrSqliteTx
|
||||
}
|
||||
var tableLen int
|
||||
if !rows.Next() {
|
||||
return nil, errors.New("empty svg_icon table")
|
||||
return nil, ErrEmptySvgTable
|
||||
}
|
||||
|
||||
if err = rows.Scan(&tableLen); err != nil {
|
||||
return nil, err
|
||||
log.Print(err)
|
||||
return nil, ErrSqliteTx
|
||||
}
|
||||
|
||||
perm, err := util.RandomPermutation(tableLen)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for idx := range perm {
|
||||
perm[idx] += 1
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
log.Print(err)
|
||||
return nil, ErrSqliteTx
|
||||
}
|
||||
|
||||
return perm[:count], nil
|
||||
}
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize KeypadDimension) ([]int, error) {
|
||||
selectedKeys := make([]int, len(passcodeIdxs))
|
||||
for idx := range passcodeIdxs {
|
||||
attrIdx := util.IndexOf[int](interfaceUser, passcodeIdxs[idx])
|
||||
if attrIdx == -1 {
|
||||
return nil, errors.New(fmt.Sprintf("index: %d out of range 0-%d", passcodeIdxs[idx], keypadSize.TotalAttrs()-1))
|
||||
attrIdx, err := util.IndexOf[int](interfaceUser, passcodeIdxs[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyNumb := attrIdx / keypadSize.AttrsPerKey
|
||||
if keyNumb >= keypadSize.NumbOfKeys {
|
||||
|
||||
29
core/user.go
29
core/user.go
@@ -1,10 +1,9 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/google/uuid"
|
||||
"go-nkode/py-builtin"
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
@@ -61,31 +60,27 @@ func (u *User) GetLoginInterface() ([]int, error) {
|
||||
return u.Interface.IdxInterface, nil
|
||||
}
|
||||
|
||||
var KeyIndexOutOfRange = errors.New("one or more keys is out of range")
|
||||
|
||||
func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, error) {
|
||||
validKeys := py_builtin.All[int](selectedKeys, func(idx int) bool {
|
||||
return 0 <= idx && idx < user.Kp.NumbOfKeys
|
||||
})
|
||||
if !validKeys {
|
||||
panic(KeyIndexOutOfRange)
|
||||
if validKeys := user.Kp.ValidKeySelections(selectedKeys); !validKeys {
|
||||
|
||||
return nil, ErrKeyIndexOutOfRange
|
||||
}
|
||||
|
||||
var err error
|
||||
passcodeLen := len(selectedKeys)
|
||||
err = customer.NKodePolicy.ValidLength(passcodeLen)
|
||||
if err != nil {
|
||||
if err := customer.NKodePolicy.ValidLength(passcodeLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setVals, err := customer.Attributes.SetValsForKp(user.Kp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Printf("fatal error in validate key entry;invalid user keypad dimensions for user %s with error %v", user.Email, err)
|
||||
return nil, ErrInternalValidKeyEntry
|
||||
}
|
||||
|
||||
passcodeSetVals, err := user.DecipherMask(setVals, passcodeLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Printf("fatal error in validate key entry;something when wrong deciphering mask;user email %s; error %v", user.Email, err)
|
||||
return nil, ErrInternalValidKeyEntry
|
||||
}
|
||||
presumedAttrIdxVals := make([]int, passcodeLen)
|
||||
|
||||
@@ -93,11 +88,13 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err
|
||||
keyNumb := selectedKeys[idx]
|
||||
setIdx, err := customer.Attributes.IndexOfSet(passcodeSetVals[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Printf("fatal error in validate key entry;something when wrong getting the IndexOfSet;user email %s; error %v", user.Email, err)
|
||||
return nil, ErrInternalValidKeyEntry
|
||||
}
|
||||
selectedAttrIdx, err := user.Interface.GetAttrIdxByKeyNumbSetIdx(setIdx, keyNumb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Printf("fatal error in validate key entry;something when wrong getting the GetAttrIdxByKeyNumbSetIdx;user email %s; error %v", user.Email, err)
|
||||
return nil, ErrInternalValidKeyEntry
|
||||
}
|
||||
presumedAttrIdxVals[idx] = selectedAttrIdx
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"go-nkode/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
@@ -49,7 +48,7 @@ func NewUserCipherKeys(kp *KeypadDimension, setVals []uint64, maxNKodeLen int) (
|
||||
|
||||
func (u *UserCipherKeys) PadUserMask(userMask []uint64, setVals []uint64) ([]uint64, error) {
|
||||
if len(userMask) > u.MaxNKodeLen {
|
||||
return nil, errors.New("user mask length exceeds max nkode length")
|
||||
return nil, ErrUserMaskTooLong
|
||||
}
|
||||
paddedUserMask := make([]uint64, len(userMask))
|
||||
copy(paddedUserMask, userMask)
|
||||
@@ -153,7 +152,10 @@ func (u *UserCipherKeys) DecipherMask(mask string, setVals []uint64, passcodeLen
|
||||
|
||||
passcodeSet := make([]uint64, passcodeLen)
|
||||
for idx, setCipher := range decipheredMask[:passcodeLen] {
|
||||
setIdx := util.IndexOf(setKeyRandComponent, setCipher)
|
||||
setIdx, err := util.IndexOf(setKeyRandComponent, setCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passcodeSet[idx] = setVals[setIdx]
|
||||
}
|
||||
return passcodeSet, nil
|
||||
@@ -175,6 +177,9 @@ func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs Cust
|
||||
}
|
||||
}
|
||||
mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.Kp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encipheredCode := EncipheredNKode{
|
||||
Code: code,
|
||||
Mask: mask,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go-nkode/hashset"
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
)
|
||||
|
||||
type UserInterface struct {
|
||||
@@ -70,7 +69,7 @@ func (u *UserInterface) SetViewMatrix() ([][]int, error) {
|
||||
|
||||
func (u *UserInterface) DisperseInterface() error {
|
||||
if !u.Kp.IsDispersable() {
|
||||
panic("interface is not dispersable")
|
||||
return ErrInterfaceNotDispersible
|
||||
}
|
||||
|
||||
err := u.shuffleKeys()
|
||||
@@ -180,11 +179,13 @@ func (u *UserInterface) PartialInterfaceShuffle() error {
|
||||
|
||||
func (u *UserInterface) GetAttrIdxByKeyNumbSetIdx(setIdx int, keyNumb int) (int, error) {
|
||||
if keyNumb < 0 || u.Kp.NumbOfKeys <= keyNumb {
|
||||
return -1, errors.New(fmt.Sprintf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys))
|
||||
log.Printf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys)
|
||||
return -1, ErrKeyIndexOutOfRange
|
||||
}
|
||||
|
||||
if setIdx < 0 || u.Kp.AttrsPerKey <= setIdx {
|
||||
return -1, errors.New(fmt.Sprintf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey))
|
||||
log.Printf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey)
|
||||
return -1, ErrAttributeIndexOutOfRange
|
||||
}
|
||||
keypadView, err := u.InterfaceMatrix()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"go-nkode/hashset"
|
||||
py "go-nkode/py-builtin"
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
)
|
||||
|
||||
type UserSignSession struct {
|
||||
@@ -52,27 +51,33 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e
|
||||
})
|
||||
|
||||
if !validEntry {
|
||||
return nil, errors.New(fmt.Sprintf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys))
|
||||
log.Printf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys)
|
||||
return nil, ErrKeyIndexOutOfRange
|
||||
}
|
||||
|
||||
if s.SetIdxInterface == nil {
|
||||
return nil, errors.New("signup session set interface is nil")
|
||||
log.Print("signup session set interface is nil")
|
||||
return nil, ErrIncompleteUserSignupSession
|
||||
}
|
||||
|
||||
if s.ConfirmIdxInterface == nil {
|
||||
return nil, errors.New("signup session confirm interface is nil")
|
||||
log.Print("signup session confirm interface is nil")
|
||||
return nil, ErrIncompleteUserSignupSession
|
||||
}
|
||||
|
||||
if s.SetKeySelection == nil {
|
||||
return nil, errors.New("signup session set key entry is nil")
|
||||
log.Print("signup session set key entry is nil")
|
||||
return nil, ErrIncompleteUserSignupSession
|
||||
}
|
||||
|
||||
if s.UserEmail == "" {
|
||||
return nil, errors.New("signup session username is nil")
|
||||
log.Print("signup session username is nil")
|
||||
return nil, ErrIncompleteUserSignupSession
|
||||
}
|
||||
|
||||
if len(confirmKeyEntry) != len(s.SetKeySelection) {
|
||||
return nil, errors.New(fmt.Sprintf("confirm and set key entry lenght mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection)))
|
||||
log.Printf("confirm and set key entry length mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection))
|
||||
return nil, ErrSetConfirmSignupMismatch
|
||||
}
|
||||
|
||||
passcodeLen := len(confirmKeyEntry)
|
||||
@@ -88,10 +93,12 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e
|
||||
confirmKey := hashset.NewSetFromSlice[int](confirmKeyVals[idx])
|
||||
intersection := setKey.Intersect(confirmKey)
|
||||
if intersection.Size() < 1 {
|
||||
return nil, errors.New(fmt.Sprintf("set and confirm do not intersect at index %d", idx))
|
||||
log.Printf("set and confirm do not intersect at index %d", idx)
|
||||
return nil, ErrSetConfirmSignupMismatch
|
||||
}
|
||||
if intersection.Size() > 1 {
|
||||
return nil, errors.New(fmt.Sprintf("set and confirm intersect at more than one point at index %d", idx))
|
||||
log.Printf("set and confirm intersect at more than one point at index %d", idx)
|
||||
return nil, ErrSetConfirmSignupMismatch
|
||||
}
|
||||
intersectionSlice := intersection.ToSlice()
|
||||
passcode[idx] = intersectionSlice[0]
|
||||
@@ -104,7 +111,8 @@ func (s *UserSignSession) SetUserNKode(keySelection KeySelection) (IdxInterface,
|
||||
return 0 <= i && i < s.Kp.NumbOfKeys
|
||||
})
|
||||
if !validKeySelection {
|
||||
return nil, errors.New(fmt.Sprintf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1))
|
||||
log.Printf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1)
|
||||
return nil, ErrKeyIndexOutOfRange
|
||||
}
|
||||
|
||||
s.SetKeySelection = keySelection
|
||||
@@ -134,7 +142,7 @@ func (s *UserSignSession) getSelectedKeyVals(keySelections KeySelection, userInt
|
||||
|
||||
func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, error) {
|
||||
if kp.IsDispersable() {
|
||||
return nil, errors.New("keypad is dispersable, can't use signupInterface")
|
||||
return nil, ErrKeypadIsNotDispersible
|
||||
}
|
||||
err := baseUserInterface.RandomShuffle()
|
||||
if err != nil {
|
||||
|
||||
@@ -78,9 +78,9 @@ func TestApi(t *testing.T) {
|
||||
|
||||
var jwtTokens core.AuthenticationTokens
|
||||
testApiPost(t, base+core.Login, loginBody, &jwtTokens)
|
||||
refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken)
|
||||
refreshClaims, err := core.ParseRegisteredClaimToken(jwtTokens.RefreshToken)
|
||||
assert.Equal(t, refreshClaims.Subject, userEmail)
|
||||
accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken)
|
||||
accessClaims, err := core.ParseRegisteredClaimToken(jwtTokens.AccessToken)
|
||||
assert.Equal(t, accessClaims.Subject, userEmail)
|
||||
renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId}
|
||||
testApiPost(t, base+core.RenewAttributes, renewBody, nil)
|
||||
@@ -102,7 +102,7 @@ func TestApi(t *testing.T) {
|
||||
var refreshTokenResp core.RefreshTokenResp
|
||||
|
||||
testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken)
|
||||
accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken)
|
||||
accessClaims, err = core.ParseRegisteredClaimToken(refreshTokenResp.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessClaims.Subject, userEmail)
|
||||
}
|
||||
|
||||
50
util/util.go
50
util/util.go
@@ -6,8 +6,8 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go-nkode/hashset"
|
||||
"log"
|
||||
"math/big"
|
||||
r "math/rand"
|
||||
"time"
|
||||
@@ -17,12 +17,25 @@ type ShuffleTypes interface {
|
||||
[]int | int | []uint64 | uint64
|
||||
}
|
||||
|
||||
// fisherYatesShuffle shuffles a slice of bytes in place using the Fisher-Yates algorithm.
|
||||
var (
|
||||
ErrFisherYatesShuffle = errors.New("unable to shuffle array")
|
||||
ErrRandomBytes = errors.New("random bytes error")
|
||||
ErrRandNonRepeatingUint64 = errors.New("list length must be less than 2^32")
|
||||
ErrParseHexString = errors.New("parse hex string error")
|
||||
ErrMatrixTranspose = errors.New("matrix cannot be nil or empty")
|
||||
ErrListToMatrixNotDivisible = errors.New("list to matrix not possible")
|
||||
ErrElementNotInArray = errors.New("element not in array")
|
||||
ErrDecodeBase64Str = errors.New("decode base64 err")
|
||||
ErrRandNonRepeatingInt = errors.New("list length must be less than 2^31")
|
||||
ErrXorLengthMismatch = errors.New("xor length mismatch")
|
||||
)
|
||||
|
||||
func fisherYatesShuffle[T ShuffleTypes](b *[]T) error {
|
||||
for i := len(*b) - 1; i > 0; i-- {
|
||||
bigJ, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
||||
if err != nil {
|
||||
return err
|
||||
log.Print("fisher yates shuffle error: ", err)
|
||||
return ErrFisherYatesShuffle
|
||||
}
|
||||
j := bigJ.Int64()
|
||||
(*b)[i], (*b)[j] = (*b)[j], (*b)[i]
|
||||
@@ -38,7 +51,8 @@ func RandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Print("error in random bytes: ", err)
|
||||
return nil, ErrRandomBytes
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
@@ -72,7 +86,7 @@ func GenerateRandomInt() (int, error) {
|
||||
|
||||
func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) {
|
||||
if listLen > int(1)<<32 {
|
||||
return nil, errors.New("list length must be less than 2^32")
|
||||
return nil, ErrRandNonRepeatingUint64
|
||||
}
|
||||
listSet := make(hashset.Set[uint64])
|
||||
for {
|
||||
@@ -92,7 +106,7 @@ func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) {
|
||||
|
||||
func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) {
|
||||
if listLen > int(1)<<31 {
|
||||
return nil, errors.New("list length must be less than 2^31")
|
||||
return nil, ErrRandNonRepeatingInt
|
||||
}
|
||||
listSet := make(hashset.Set[int])
|
||||
for {
|
||||
@@ -112,7 +126,8 @@ func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) {
|
||||
|
||||
func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) {
|
||||
if len(l0) != len(l1) {
|
||||
return nil, errors.New(fmt.Sprintf("list len mismatch %d, %d", len(l0), len(l1)))
|
||||
log.Printf("list len mismatch %d, %d", len(l0), len(l1))
|
||||
return nil, ErrXorLengthMismatch
|
||||
}
|
||||
|
||||
xorList := make([]uint64, len(l0))
|
||||
@@ -131,7 +146,8 @@ func EncodeBase64Str(data []uint64) string {
|
||||
func DecodeBase64Str(encoded string) ([]uint64, error) {
|
||||
dataBytes, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Print("error decoding base64 str: ", err)
|
||||
return nil, ErrDecodeBase64Str
|
||||
}
|
||||
data := ByteArrToUint64Arr(dataBytes)
|
||||
return data, nil
|
||||
@@ -179,13 +195,13 @@ func ByteArrToIntArr(byteArr []byte) []int {
|
||||
return intArr
|
||||
}
|
||||
|
||||
func IndexOf[T uint64 | int](arr []T, el T) int {
|
||||
func IndexOf[T uint64 | int](arr []T, el T) (int, error) {
|
||||
for idx, val := range arr {
|
||||
if val == el {
|
||||
return idx
|
||||
return idx, nil
|
||||
}
|
||||
}
|
||||
return -1
|
||||
return -1, ErrElementNotInArray
|
||||
}
|
||||
|
||||
func IdentityArray(arrLen int) []int {
|
||||
@@ -199,7 +215,8 @@ func IdentityArray(arrLen int) []int {
|
||||
|
||||
func ListToMatrix(listArr []int, numbCols int) ([][]int, error) {
|
||||
if len(listArr)%numbCols != 0 {
|
||||
panic(fmt.Sprintf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols))
|
||||
log.Printf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols)
|
||||
return nil, ErrListToMatrixNotDivisible
|
||||
}
|
||||
numbRows := len(listArr) / numbCols
|
||||
matrix := make([][]int, numbRows)
|
||||
@@ -213,7 +230,8 @@ func ListToMatrix(listArr []int, numbCols int) ([][]int, error) {
|
||||
|
||||
func MatrixTranspose(matrix [][]int) ([][]int, error) {
|
||||
if matrix == nil || len(matrix) == 0 {
|
||||
return nil, fmt.Errorf("matrix cannot be nil or empty")
|
||||
log.Print("can't transpose nil or zero len matrix")
|
||||
return nil, ErrMatrixTranspose
|
||||
}
|
||||
|
||||
rows := len(matrix)
|
||||
@@ -222,7 +240,8 @@ func MatrixTranspose(matrix [][]int) ([][]int, error) {
|
||||
// Check if the matrix is not rectangular
|
||||
for _, row := range matrix {
|
||||
if len(row) != cols {
|
||||
return nil, fmt.Errorf("all rows must have the same number of columns")
|
||||
log.Print("all rows must have the same number of columns")
|
||||
return nil, ErrMatrixTranspose
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,7 +286,8 @@ func ParseHexString(hexStr string) ([]byte, error) {
|
||||
// Decode the hex string into bytes
|
||||
bytes, err := hex.DecodeString(hexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Print("parse hex string err: ", err)
|
||||
return nil, ErrParseHexString
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user