refactor errors

This commit is contained in:
2024-10-14 13:29:05 -05:00
parent 1e33a81a2c
commit 39d4a1e7f0
20 changed files with 398 additions and 444 deletions

62
core/constants.go Normal file
View File

@@ -0,0 +1,62 @@
package core
import (
"errors"
"net/http"
)
var (
ErrInvalidNKodeLength = errors.New("invalid nKode length")
ErrInvalidNKodeIdx = errors.New("invalid passcode attribute index")
ErrTooFewDistinctSet = errors.New("too few distinct sets")
ErrTooFewDistinctAttributes = errors.New("too few distinct attributes")
ErrEmailAlreadySent = errors.New("email already sent")
ErrClaimExpOrNil = errors.New("claim expired or nil")
ErrInvalidJwt = errors.New("invalid jwt")
ErrInvalidKeypadDimensions = errors.New("keypad dimensions out of range")
ErrUserAlreadyExists = errors.New("user already exists")
ErrSignupSessionDNE = errors.New("signup session does not exist")
ErrUserForCustomerDNE = errors.New("user for customer does not exist")
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
ErrCustomerDne = errors.New("customer dne")
ErrSvgDne = errors.New("svg dne")
ErrStoppingDatabase = errors.New("stopping database")
ErrSqliteTx = errors.New("sqlite begin, exec, query, or commit error. see logs")
ErrEmptySvgTable = errors.New("empty svg_icon table")
ErrKeyIndexOutOfRange = errors.New("one or more keys is out of range")
ErrAttributeIndexOutOfRange = errors.New("attribute index out of range")
ErrInternalValidKeyEntry = errors.New("internal validation error")
ErrUserMaskTooLong = errors.New("user mask length exceeds max nkode length")
ErrInterfaceNotDispersible = errors.New("interface is not dispersible")
ErrIncompleteUserSignupSession = errors.New("incomplete user signup session")
ErrSetConfirmSignupMismatch = errors.New("set and confirm nkode are not the same")
ErrKeypadIsNotDispersible = errors.New("keypad is not dispersible")
)
var HttpErrMap = map[error]int{
ErrInvalidNKodeLength: http.StatusBadRequest,
ErrInvalidNKodeIdx: http.StatusBadRequest,
ErrTooFewDistinctSet: http.StatusBadRequest,
ErrTooFewDistinctAttributes: http.StatusBadRequest,
ErrEmailAlreadySent: http.StatusBadRequest,
ErrClaimExpOrNil: http.StatusForbidden,
ErrInvalidJwt: http.StatusForbidden,
ErrInvalidKeypadDimensions: http.StatusBadRequest,
ErrUserAlreadyExists: http.StatusBadRequest,
ErrSignupSessionDNE: http.StatusBadRequest,
ErrUserForCustomerDNE: http.StatusBadRequest,
ErrRefreshTokenInvalid: http.StatusForbidden,
ErrCustomerDne: http.StatusBadRequest,
ErrSvgDne: http.StatusBadRequest,
ErrStoppingDatabase: http.StatusInternalServerError,
ErrSqliteTx: http.StatusInternalServerError,
ErrEmptySvgTable: http.StatusInternalServerError,
ErrKeyIndexOutOfRange: http.StatusBadRequest,
ErrAttributeIndexOutOfRange: http.StatusInternalServerError,
ErrInternalValidKeyEntry: http.StatusInternalServerError,
ErrUserMaskTooLong: http.StatusInternalServerError,
ErrInterfaceNotDispersible: http.StatusInternalServerError,
ErrIncompleteUserSignupSession: http.StatusBadRequest,
ErrSetConfirmSignupMismatch: http.StatusBadRequest,
ErrKeypadIsNotDispersible: http.StatusInternalServerError,
}

View File

@@ -1,11 +1,8 @@
package core package core
import ( import (
"errors"
"fmt"
"github.com/google/uuid" "github.com/google/uuid"
"go-nkode/hashset" "go-nkode/hashset"
py "go-nkode/py-builtin"
"go-nkode/util" "go-nkode/util"
) )
@@ -31,16 +28,12 @@ func NewCustomer(nkodePolicy NKodePolicy) (*Customer, error) {
func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error { func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error {
nkodeLen := len(passcodeAttrIdx) nkodeLen := len(passcodeAttrIdx)
if nkodeLen < c.NKodePolicy.MinNkodeLen { if nkodeLen < c.NKodePolicy.MinNkodeLen || nkodeLen > c.NKodePolicy.MaxNkodeLen {
return errors.New(fmt.Sprintf("NKode length %d is too short. Minimum nKode length is %d", nkodeLen, c.NKodePolicy.MinNkodeLen)) return ErrInvalidNKodeLength
} }
validIdx := py.All[int](passcodeAttrIdx, func(i int) bool { if validIdx := kp.ValidateAttributeIndices(passcodeAttrIdx); !validIdx {
return i >= 0 && i < kp.TotalAttrs() return ErrInvalidNKodeIdx
})
if !validIdx {
return errors.New(fmt.Sprintf("One or more idx out of range 0-%d in IsValidNKode", kp.TotalAttrs()-1))
} }
passcodeSetVals := make(hashset.Set[uint64]) passcodeSetVals := make(hashset.Set[uint64])
passcodeAttrVals := make(hashset.Set[uint64]) passcodeAttrVals := make(hashset.Set[uint64])
@@ -59,33 +52,32 @@ func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error
} }
if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets { if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets {
return errors.New(fmt.Sprintf("passcode has two few distinct sets min %d, has %d", c.NKodePolicy.DistinctSets, passcodeSetVals.Size())) return ErrTooFewDistinctSet
} }
if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes { if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes {
return errors.New(fmt.Sprintf("passcode has two few distinct attributes min %d, has %d", c.NKodePolicy.DistinctAttributes, passcodeAttrVals.Size())) return ErrTooFewDistinctAttributes
} }
return nil return nil
} }
func (c *Customer) RenewKeys() ([]uint64, []uint64) { func (c *Customer) RenewKeys() ([]uint64, []uint64, error) {
oldAttrs := make([]uint64, len(c.Attributes.AttrVals)) oldAttrs := make([]uint64, len(c.Attributes.AttrVals))
oldSets := make([]uint64, len(c.Attributes.SetVals)) oldSets := make([]uint64, len(c.Attributes.SetVals))
copy(oldAttrs, c.Attributes.AttrVals) copy(oldAttrs, c.Attributes.AttrVals)
copy(oldSets, c.Attributes.SetVals) copy(oldSets, c.Attributes.SetVals)
err := c.Attributes.Renew() if err := c.Attributes.Renew(); err != nil {
if err != nil { return nil, nil, err
panic(err)
} }
attrsXor, err := util.XorLists(oldAttrs, c.Attributes.AttrVals) attrsXor, err := util.XorLists(oldAttrs, c.Attributes.AttrVals)
if err != nil { if err != nil {
panic(err) return nil, nil, err
} }
setXor, err := util.XorLists(oldSets, c.Attributes.SetVals) setXor, err := util.XorLists(oldSets, c.Attributes.SetVals)
if err != nil { if err != nil {
panic(err) return nil, nil, err
} }
return setXor, attrsXor return setXor, attrsXor, nil
} }

View File

@@ -1,9 +1,8 @@
package core package core
import ( import (
"errors"
"fmt"
"go-nkode/util" "go-nkode/util"
"log"
) )
type CustomerAttributes struct { type CustomerAttributes struct {
@@ -12,13 +11,15 @@ type CustomerAttributes struct {
} }
func NewCustomerAttributes() (*CustomerAttributes, error) { func NewCustomerAttributes() (*CustomerAttributes, error) {
attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
if errAttr != nil { if err != nil {
return nil, errAttr log.Print("unable to generate attribute vals: ", err)
return nil, err
} }
setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
if errSet != nil { if err != nil {
return nil, errSet log.Print("unable to generate set vals: ", err)
return nil, err
} }
customerAttrs := CustomerAttributes{ customerAttrs := CustomerAttributes{
@@ -36,37 +37,33 @@ func NewCustomerAttributesFromBytes(attrBytes []byte, setBytes []byte) CustomerA
} }
func (c *CustomerAttributes) Renew() error { func (c *CustomerAttributes) Renew() error {
attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs())
if errAttr != nil { if err != nil {
return errAttr return err
} }
setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey)
if errSet != nil { if err != nil {
return errSet return err
} }
c.AttrVals = attrVals c.AttrVals = attrVals
c.SetVals = setVals c.SetVals = setVals
return nil return nil
} }
func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) int { func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) (int, error) {
// TODO: should this be mapped instead? // TODO: should this be mapped instead?
return util.IndexOf[uint64](c.AttrVals, attrVal) return util.IndexOf[uint64](c.AttrVals, attrVal)
} }
func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) { func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) {
// TODO: should this be mapped instead? // TODO: should this be mapped instead?
idx := util.IndexOf[uint64](c.SetVals, setVal) return util.IndexOf[uint64](c.SetVals, setVal)
if idx == -1 {
return -1, errors.New(fmt.Sprintf("Set Val %d is invalid", setVal))
}
return idx, nil
} }
func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad KeypadDimension) (uint64, error) { func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad KeypadDimension) (uint64, error) {
indexOfAttr := c.IndexOfAttr(attrVal) indexOfAttr, err := c.IndexOfAttr(attrVal)
if indexOfAttr == -1 { if err != nil {
return 0, errors.New(fmt.Sprintf("No attribute %d", attrVal)) return 0, err
} }
setIdx := indexOfAttr % userKeypad.AttrsPerKey setIdx := indexOfAttr % userKeypad.AttrsPerKey
return c.SetVals[setIdx], nil return c.SetVals[setIdx], nil

View File

@@ -2,7 +2,6 @@ package core
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/config"
@@ -51,9 +50,9 @@ func NewSESClient() SESClient {
} }
func (s *SESClient) SendEmail(email Email) error { func (s *SESClient) SendEmail(email Email) error {
if _, exists := s.ResetCache.Get(email.Recipient); exists { if _, exists := s.ResetCache.Get(email.Recipient); exists {
return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject) log.Printf("email already sent to %s with subject %s", email.Recipient, email.Subject)
return ErrEmailAlreadySent
} }
// Load AWS configuration // Load AWS configuration
@@ -61,7 +60,7 @@ func (s *SESClient) SendEmail(email Email) error {
if err != nil { if err != nil {
errMsg := fmt.Sprintf("unable to load SDK config, %v", err) errMsg := fmt.Sprintf("unable to load SDK config, %v", err)
log.Print(errMsg) log.Print(errMsg)
return errors.New(errMsg) return err
} }
// Create an SES client // Create an SES client
@@ -93,7 +92,9 @@ func (s *SESClient) SendEmail(email Email) error {
resp, err := sesClient.SendEmail(context.TODO(), input) resp, err := sesClient.SendEmail(context.TODO(), input)
if err != nil { if err != nil {
s.ResetCache.Delete(email.Recipient) s.ResetCache.Delete(email.Recipient)
return fmt.Errorf("failed to send email, %v", err) errMsg := fmt.Sprintf("failed to send email, %v", err)
log.Print(errMsg)
return err
} }
// Output the message ID of the sent email // Output the message ID of the sent email

View File

@@ -89,9 +89,11 @@ func (db *InMemoryDb) Renew(id CustomerId) error {
if !exists { if !exists {
return errors.New(fmt.Sprintf("customer %s does not exist", id)) return errors.New(fmt.Sprintf("customer %s does not exist", id))
} }
setXor, attrsXor := customer.RenewKeys() setXor, attrsXor, err := customer.RenewKeys()
if err != nil {
return err
}
db.Customers[id] = customer db.Customers[id] = customer
var err error
for _, user := range db.Users { for _, user := range db.Users {
if user.CustomerId == id { if user.CustomerId == id {
err = user.RenewKeys(setXor, attrsXor) err = user.RenewKeys(setXor, attrsXor)

View File

@@ -1,8 +1,6 @@
package core package core
import ( import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"go-nkode/util" "go-nkode/util"
"log" "log"
@@ -78,42 +76,37 @@ func EncodeAndSignClaims(claims jwt.Claims) (string, error) {
return token.SignedString(secret) return token.SignedString(secret)
} }
func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) { func ParseRegisteredClaimToken(token string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { return parseJwt[*jwt.RegisteredClaims](token, &jwt.RegisteredClaims{})
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
} }
func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) { func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) {
token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { return parseJwt[*ResetNKodeClaims](resetNKodeToken, &ResetNKodeClaims{})
}
func parseJwt[T *ResetNKodeClaims | *jwt.RegisteredClaims](tokenStr string, claim jwt.Claims) (T, error) {
token, err := jwt.ParseWithClaims(tokenStr, claim, func(token *jwt.Token) (interface{}, error) {
return secret, nil return secret, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err) log.Printf("error parsing refresh token: %v", err)
return nil, ErrInvalidJwt
} }
claims, ok := token.Claims.(*jwt.RegisteredClaims) claims, ok := token.Claims.(T)
if !ok { if !ok {
return nil, errors.New("unable to parse claims") return nil, ErrInvalidJwt
} }
return claims, nil return claims, nil
} }
func ClaimExpired(claims jwt.RegisteredClaims) error { func ClaimExpired(claims jwt.RegisteredClaims) error {
if claims.ExpiresAt == nil { if claims.ExpiresAt == nil {
return errors.New("claim exp is nil") return ErrClaimExpOrNil
} }
if claims.ExpiresAt.Time.After(time.Now()) { if claims.ExpiresAt.Time.After(time.Now()) {
return nil return nil
} }
return errors.New("claim expired") return ErrClaimExpOrNil
} }
func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error) { func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error) {
@@ -127,17 +120,3 @@ func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error)
} }
return EncodeAndSignClaims(resetClaims) return EncodeAndSignClaims(resetClaims)
} }
func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) {
token, err := jwt.ParseWithClaims(resetNKodeToken, &ResetNKodeClaims{}, func(token *jwt.Token) (interface{}, error) {
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("error parsing refresh token: %w", err)
}
claims, ok := token.Claims.(*ResetNKodeClaims)
if !ok {
return nil, errors.New("unable to parse claims")
}
return claims, nil
}

View File

@@ -11,11 +11,11 @@ func TestJwtClaims(t *testing.T) {
customerId := CustomerId(uuid.New()) customerId := CustomerId(uuid.New())
authTokens, err := NewAuthenticationTokens(email, customerId) authTokens, err := NewAuthenticationTokens(email, customerId)
assert.NoError(t, err) assert.NoError(t, err)
accessToken, err := ParseAccessToken(authTokens.AccessToken) accessToken, err := ParseRegisteredClaimToken(authTokens.AccessToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessToken.Subject, email) assert.Equal(t, accessToken.Subject, email)
assert.NoError(t, ClaimExpired(*accessToken)) assert.NoError(t, ClaimExpired(*accessToken))
refreshToken, err := ParseRefreshToken(authTokens.RefreshToken) refreshToken, err := ParseRegisteredClaimToken(authTokens.RefreshToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, refreshToken.Subject, email) assert.Equal(t, refreshToken.Subject, email)
assert.NoError(t, ClaimExpired(*refreshToken)) assert.NoError(t, ClaimExpired(*refreshToken))

View File

@@ -1,6 +1,8 @@
package core package core
import "errors" import (
py "go-nkode/py-builtin"
)
type KeypadDimension struct { type KeypadDimension struct {
AttrsPerKey int `json:"attrs_per_key"` AttrsPerKey int `json:"attrs_per_key"`
@@ -17,11 +19,23 @@ func (kp *KeypadDimension) IsDispersable() bool {
func (kp *KeypadDimension) IsValidKeypadDimension() error { func (kp *KeypadDimension) IsValidKeypadDimension() error {
if KeypadMin.AttrsPerKey > kp.AttrsPerKey || KeypadMax.AttrsPerKey < kp.AttrsPerKey || KeypadMin.NumbOfKeys > kp.NumbOfKeys || KeypadMax.NumbOfKeys < kp.NumbOfKeys { if KeypadMin.AttrsPerKey > kp.AttrsPerKey || KeypadMax.AttrsPerKey < kp.AttrsPerKey || KeypadMin.NumbOfKeys > kp.NumbOfKeys || KeypadMax.NumbOfKeys < kp.NumbOfKeys {
return errors.New("keypad dimensions out of range") return ErrInvalidKeypadDimensions
} }
return nil return nil
} }
func (kp *KeypadDimension) ValidKeySelections(selectedKeys []int) bool {
return py.All[int](selectedKeys, func(idx int) bool {
return 0 <= idx && idx < kp.NumbOfKeys
})
}
func (kp *KeypadDimension) ValidateAttributeIndices(attrIndicies []int) bool {
return py.All[int](attrIndicies, func(i int) bool {
return i >= 0 && i < kp.TotalAttrs()
})
}
var ( var (
KeypadMax = KeypadDimension{ KeypadMax = KeypadDimension{
AttrsPerKey: 16, AttrsPerKey: 16,

View File

@@ -1,9 +1,9 @@
package core package core
import ( import (
"errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"log"
"os" "os"
) )
@@ -43,7 +43,8 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId
return nil, err return nil, err
} }
if user != nil && !reset { if user != nil && !reset {
return nil, fmt.Errorf("user %s already exists", string(userEmail)) log.Printf("user %s already exists", string(userEmail))
return nil, ErrUserAlreadyExists
} }
svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp)
if err != nil { if err != nil {
@@ -74,7 +75,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec
} }
session, exists := n.SignupSessions[sessionId] session, exists := n.SignupSessions[sessionId]
if !exists { if !exists {
return nil, errors.New(fmt.Sprintf("session id does not exist %s", sessionId)) log.Printf("session id does not exist %s", sessionId)
return nil, ErrSignupSessionDNE
} }
confirmInterface, err := session.SetUserNKode(keySelection) confirmInterface, err := session.SetUserNKode(keySelection)
if err != nil { if err != nil {
@@ -87,7 +89,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec
func (n *NKodeAPI) ConfirmNKode(customerId CustomerId, sessionId SessionId, keySelection KeySelection) error { func (n *NKodeAPI) ConfirmNKode(customerId CustomerId, sessionId SessionId, keySelection KeySelection) error {
session, exists := n.SignupSessions[sessionId] session, exists := n.SignupSessions[sessionId]
if !exists { if !exists {
return errors.New(fmt.Sprintf("session id does not exist %s", sessionId)) log.Printf("session id does not exist %s", sessionId)
return ErrSignupSessionDNE
} }
customer, err := n.Db.GetCustomer(customerId) customer, err := n.Db.GetCustomer(customerId)
if err != nil { if err != nil {
@@ -120,7 +123,8 @@ func (n *NKodeAPI) GetLoginInterface(userEmail UserEmail, customerId CustomerId)
return nil, err return nil, err
} }
if user == nil { if user == nil {
return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) log.Printf("user %s for customer %s dne", userEmail, customerId)
return nil, ErrUserForCustomerDNE
} }
err = user.Interface.PartialInterfaceShuffle() err = user.Interface.PartialInterfaceShuffle()
if err != nil { if err != nil {
@@ -153,7 +157,8 @@ func (n *NKodeAPI) Login(customerId CustomerId, userEmail UserEmail, keySelectio
return nil, err return nil, err
} }
if user == nil { if user == nil {
return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) log.Printf("user %s for customer %s dne", userEmail, customerId)
return nil, ErrUserForCustomerDNE
} }
passcode, err := ValidKeyEntry(*user, *customer, keySelection) passcode, err := ValidKeyEntry(*user, *customer, keySelection)
if err != nil { if err != nil {
@@ -195,12 +200,13 @@ func (n *NKodeAPI) RefreshToken(userEmail UserEmail, customerId CustomerId, refr
return "", err return "", err
} }
if user == nil { if user == nil {
return "", errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) log.Printf("user %s for customer %s dne", userEmail, customerId)
return "", ErrUserForCustomerDNE
} }
if user.RefreshToken != refreshToken { if user.RefreshToken != refreshToken {
return "", errors.New("refresh token is invalid") return "", ErrRefreshTokenInvalid
} }
refreshClaims, err := ParseRefreshToken(refreshToken) refreshClaims, err := ParseRegisteredClaimToken(refreshToken)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -222,7 +228,7 @@ func (n *NKodeAPI) ResetNKode(userEmail UserEmail, customerId CustomerId) error
nkodeResetJwt, err := ResetNKodeToken(userEmail, customerId) nkodeResetJwt, err := ResetNKodeToken(userEmail, customerId)
if err != nil { if err != nil {
return errors.New(fmt.Sprintf("unable to load SDK config, %v", err)) return err
} }
frontendHost := os.Getenv("FRONTEND_HOST") frontendHost := os.Getenv("FRONTEND_HOST")
if frontendHost == "" { if frontendHost == "" {

View File

@@ -26,6 +26,12 @@ const (
ResetNKode = "/reset-nkode" ResetNKode = "/reset-nkode"
) )
const (
malformedCustomerId = "malformed customer id"
malformedUserEmail = "malformed user email"
malformedSessionId = "malformed session id"
)
func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case CreateNewCustomer: case CreateNewCustomer:
@@ -56,226 +62,148 @@ func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.Request) {
log.Print("create new customer")
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
return return
} }
var customerPost NewCustomerPost var customerPost NewCustomerPost
err := decodeJson(w, r, &customerPost) if err := decodeJson(w, r, &customerPost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := h.Api.CreateNewCustomer(customerPost.NKodePolicy, nil) customerId, err := h.Api.CreateNewCustomer(customerPost.NKodePolicy, nil)
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return return
} }
respBody := CreateNewCustomerResp{ respBody := CreateNewCustomerResp{
CustomerId: uuid.UUID(*customerId).String(), CustomerId: uuid.UUID(*customerId).String(),
} }
respBytes, err := json.Marshal(respBody) marshalAndWriteBytes(w, respBody)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter, r *http.Request) {
log.Print("signup/reset interface")
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
return return
} }
var signupResetPost GenerateSignupRestInterfacePost var signupResetPost GenerateSignupRestInterfacePost
err := decodeJson(w, r, &signupResetPost) if err := decodeJson(w, r, &signupResetPost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
kp := KeypadDimension{ kp := KeypadDimension{
AttrsPerKey: signupResetPost.AttrsPerKey, AttrsPerKey: signupResetPost.AttrsPerKey,
NumbOfKeys: signupResetPost.NumbOfKeys, NumbOfKeys: signupResetPost.NumbOfKeys,
} }
err = kp.IsValidKeypadDimension() if err := kp.IsValidKeypadDimension(); err != nil {
if err != nil { badRequest(w, "invalid keypad dimensions")
keypadSizeOutOfRange(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(signupResetPost.CustomerId) customerId, err := uuid.Parse(signupResetPost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
userEmail, err := ParseEmail(signupResetPost.UserEmail) userEmail, err := ParseEmail(signupResetPost.UserEmail)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedUserEmail)
log.Println(err)
return return
} }
resp, err := h.Api.GenerateSignupResetInterface(userEmail, CustomerId(customerId), kp, signupResetPost.Reset) resp, err := h.Api.GenerateSignupResetInterface(userEmail, CustomerId(customerId), kp, signupResetPost.Reset)
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return
}
respBytes, err := json.Marshal(resp)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
_, err = w.Write(respBytes) marshalAndWriteBytes(w, resp)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) {
log.Print("set nkode")
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
return return
} }
var setNKodePost SetNKodePost var setNKodePost SetNKodePost
err := decodeJson(w, r, &setNKodePost) if err := decodeJson(w, r, &setNKodePost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(setNKodePost.CustomerId) customerId, err := uuid.Parse(setNKodePost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
sessionId, err := uuid.Parse(setNKodePost.SessionId) sessionId, err := uuid.Parse(setNKodePost.SessionId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedSessionId)
log.Println(err)
return return
} }
confirmInterface, err := h.Api.SetNKode(CustomerId(customerId), SessionId(sessionId), setNKodePost.KeySelection) confirmInterface, err := h.Api.SetNKode(CustomerId(customerId), SessionId(sessionId), setNKodePost.KeySelection)
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return return
} }
respBody := SetNKodeResp{UserInterface: confirmInterface} respBody := SetNKodeResp{UserInterface: confirmInterface}
marshalAndWriteBytes(w, respBody)
respBytes, err := json.Marshal(respBody)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Request) {
log.Print("confirm nkode")
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
return return
} }
var confirmNKodePost ConfirmNKodePost var confirmNKodePost ConfirmNKodePost
err := decodeJson(w, r, &confirmNKodePost) if err := decodeJson(w, r, &confirmNKodePost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(confirmNKodePost.CustomerId) customerId, err := uuid.Parse(confirmNKodePost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
sessionId, err := uuid.Parse(confirmNKodePost.SessionId) sessionId, err := uuid.Parse(confirmNKodePost.SessionId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedSessionId)
log.Println(err)
return return
} }
err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection) if err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection); err != nil {
if err != nil { handleError(w, err)
internalServerErrorHandler(w)
log.Println(err)
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
return return
} }
var loginInterfacePost GetLoginInterfacePost var loginInterfacePost GetLoginInterfacePost
err := decodeJson(w, r, &loginInterfacePost) if err := decodeJson(w, r, &loginInterfacePost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(loginInterfacePost.CustomerId) customerId, err := uuid.Parse(loginInterfacePost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
userEmail, err := ParseEmail(loginInterfacePost.UserEmail) userEmail, err := ParseEmail(loginInterfacePost.UserEmail)
if err != nil {
badRequest(w, malformedUserEmail)
}
loginInterface, err := h.Api.GetLoginInterface(userEmail, CustomerId(customerId)) loginInterface, err := h.Api.GetLoginInterface(userEmail, CustomerId(customerId))
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return return
} }
respBytes, err := json.Marshal(loginInterface) marshalAndWriteBytes(w, loginInterface)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
@@ -284,40 +212,26 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var loginPost LoginPost var loginPost LoginPost
err := decodeJson(w, r, &loginPost) if err := decodeJson(w, r, &loginPost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(loginPost.CustomerId) customerId, err := uuid.Parse(loginPost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
userEmail, err := ParseEmail(loginPost.UserEmail) userEmail, err := ParseEmail(loginPost.UserEmail)
if err != nil {
badRequest(w, malformedUserEmail)
return
}
jwtTokens, err := h.Api.Login(CustomerId(customerId), userEmail, loginPost.KeySelection) jwtTokens, err := h.Api.Login(CustomerId(customerId), userEmail, loginPost.KeySelection)
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return return
} }
respBytes, err := json.Marshal(jwtTokens) marshalAndWriteBytes(w, jwtTokens)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Request) {
@@ -326,23 +240,16 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req
return return
} }
var renewAttributesPost RenewAttributesPost var renewAttributesPost RenewAttributesPost
err := decodeJson(w, r, &renewAttributesPost) if err := decodeJson(w, r, &renewAttributesPost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return return
} }
customerId, err := uuid.Parse(renewAttributesPost.CustomerId) customerId, err := uuid.Parse(renewAttributesPost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
err = h.Api.RenewAttributes(CustomerId(customerId)) if err = h.Api.RenewAttributes(CustomerId(customerId)); err != nil {
if err != nil { handleError(w, err)
internalServerErrorHandler(w)
log.Println(err)
return return
} }
@@ -355,26 +262,11 @@ func (h *NKodeHandler) RandomSvgInterfaceHandler(w http.ResponseWriter, r *http.
} }
svgs, err := h.Api.RandomSvgInterface() svgs, err := h.Api.RandomSvgInterface()
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err)
return return
} }
respBody := RandomSvgInterfaceResp{Svgs: svgs} respBody := RandomSvgInterfaceResp{Svgs: svgs}
respBytes, err := json.Marshal(respBody) marshalAndWriteBytes(w, respBody)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
@@ -383,74 +275,54 @@ func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Reques
} }
refreshToken, err := getBearerToken(r) refreshToken, err := getBearerToken(r)
if err != nil { if err != nil {
internalServerErrorHandler(w) forbidden(w)
log.Println(err)
return return
} }
refreshClaims, err := ParseRefreshToken(refreshToken) refreshClaims, err := ParseRegisteredClaimToken(refreshToken)
customerId, err := uuid.Parse(refreshClaims.Issuer) customerId, err := uuid.Parse(refreshClaims.Issuer)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
userEmail, err := ParseEmail(refreshClaims.Subject) userEmail, err := ParseEmail(refreshClaims.Subject)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedUserEmail)
log.Println(err) log.Println(err)
return return
} }
accessToken, err := h.Api.RefreshToken(userEmail, CustomerId(customerId), refreshToken) accessToken, err := h.Api.RefreshToken(userEmail, CustomerId(customerId), refreshToken)
if err != nil { if err != nil {
internalServerErrorHandler(w) handleError(w, err)
log.Println(err) log.Println(err)
return return
} }
respBytes, err := json.Marshal(RefreshTokenResp{AccessToken: accessToken}) marshalAndWriteBytes(w, RefreshTokenResp{AccessToken: accessToken})
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerErrorHandler(w)
log.Println(err)
return
}
w.WriteHeader(http.StatusOK)
} }
func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) { func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
methodNotAllowed(w) methodNotAllowed(w)
} }
log.Print("Resetting email")
var resetNKodePost ResetNKodePost var resetNKodePost ResetNKodePost
err := decodeJson(w, r, &resetNKodePost) if err := decodeJson(w, r, &resetNKodePost); err != nil {
if err != nil {
internalServerErrorHandler(w)
log.Println("error decoding reset nkode post: ", err)
return return
} }
customerId, err := uuid.Parse(resetNKodePost.CustomerId) customerId, err := uuid.Parse(resetNKodePost.CustomerId)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedCustomerId)
log.Println(err)
return return
} }
userEmail, err := ParseEmail(resetNKodePost.UserEmail) userEmail, err := ParseEmail(resetNKodePost.UserEmail)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, malformedUserEmail)
log.Println(err)
return return
} }
err = h.Api.ResetNKode(userEmail, CustomerId(customerId))
if err != nil { if err = h.Api.ResetNKode(userEmail, CustomerId(customerId)); err != nil {
internalServerErrorHandler(w) internalServerError(w)
log.Println(err) log.Println(err)
return return
} }
@@ -459,43 +331,90 @@ func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) {
func decodeJson(w http.ResponseWriter, r *http.Request, post any) error { func decodeJson(w http.ResponseWriter, r *http.Request, post any) error {
if r.Body == nil { if r.Body == nil {
invalidJson(w) badRequest(w, "unable to parse body")
return errors.New("invalid json") log.Println("error decoding json: body is nil")
return errors.New("body is nil")
} }
err := json.NewDecoder(r.Body).Decode(&post) err := json.NewDecoder(r.Body).Decode(&post)
if err != nil { if err != nil {
internalServerErrorHandler(w) badRequest(w, "unable to parse body")
log.Println("error decoding json: ", err)
return err return err
} }
return nil return nil
} }
func internalServerErrorHandler(w http.ResponseWriter) { func internalServerError(w http.ResponseWriter) {
log.Print("500 internal server error")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 Internal Server Error")) w.Write([]byte("500 Internal Server Error"))
} }
func badRequest(w http.ResponseWriter, msg string) {
log.Print("bad request: ", msg)
w.WriteHeader(http.StatusBadRequest)
if msg == "" {
w.Write([]byte("400 Bad Request"))
} else {
w.Write([]byte(msg))
}
}
func methodNotAllowed(w http.ResponseWriter) { func methodNotAllowed(w http.ResponseWriter) {
log.Print("405 method not allowed")
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("405 method not allowed")) w.Write([]byte("405 method not allowed"))
} }
func keypadSizeOutOfRange(w http.ResponseWriter) { func forbidden(w http.ResponseWriter) {
w.WriteHeader(http.StatusBadRequest) log.Print("403 forbidden")
w.Write([]byte("invalid keypad dimensions")) w.WriteHeader(http.StatusForbidden)
w.Write([]byte("403 Forbidden"))
} }
func invalidJson(w http.ResponseWriter) { func handleError(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusBadRequest) log.Print("handling error: ", err)
w.Write([]byte("invalid json")) statusCode, exists := HttpErrMap[err]
if !exists {
internalServerError(w)
return
}
switch statusCode {
case http.StatusBadRequest:
badRequest(w, err.Error())
case http.StatusForbidden:
forbidden(w)
case http.StatusInternalServerError:
internalServerError(w)
default:
log.Print("unknown error: ", err)
internalServerError(w)
}
} }
func getBearerToken(r *http.Request) (string, error) { func getBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
// Check if the Authorization header is present and starts with "Bearer " // Check if the Authorization header is present and starts with "Bearer "
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
return "", errors.New("authorization header missing or invalid") return "", errors.New("forbidden")
} }
token := strings.TrimPrefix(authHeader, "Bearer ") token := strings.TrimPrefix(authHeader, "Bearer ")
return token, nil return token, nil
} }
func marshalAndWriteBytes(w http.ResponseWriter, data any) {
respBytes, err := json.Marshal(data)
if err != nil {
internalServerError(w)
log.Println(err)
return
}
_, err = w.Write(respBytes)
if err != nil {
internalServerError(w)
log.Println(err)
return
}
}

View File

@@ -1,9 +1,5 @@
package core package core
import (
"errors"
)
type NKodePolicy struct { type NKodePolicy struct {
MaxNkodeLen int `json:"max_nkode_len"` MaxNkodeLen int `json:"max_nkode_len"`
MinNkodeLen int `json:"min_nkode_len"` MinNkodeLen int `json:"min_nkode_len"`
@@ -24,12 +20,10 @@ func NewDefaultNKodePolicy() NKodePolicy {
} }
} }
var InvalidNKodeLen = errors.New("invalid nkode length")
func (p *NKodePolicy) ValidLength(nkodeLen int) error { func (p *NKodePolicy) ValidLength(nkodeLen int) error {
if nkodeLen < p.MinNkodeLen || nkodeLen > p.MaxNkodeLen { if nkodeLen < p.MinNkodeLen || nkodeLen > p.MaxNkodeLen {
return InvalidNKodeLen return ErrInvalidNKodeLength
} }
// TODO: validate Max > Min // TODO: validate Max > Min
// Validate lockout // Validate lockout

View File

@@ -1,43 +0,0 @@
package core
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
)
type NKodeSecrets struct {
JwtSecret []byte `json:"jwt_secret"`
}
func ReadSecrets(filePath string) (NKodeSecrets, error) {
// Initialize an empty NKodeSecrets struct
var secrets NKodeSecrets
// Read the contents of the file
data, err := ioutil.ReadFile(filePath)
if err != nil {
return secrets, fmt.Errorf("error reading secrets file: %w", err)
}
// Unmarshal JSON data into the NKodeSecrets struct
err = json.Unmarshal(data, &secrets)
if err != nil {
return secrets, fmt.Errorf("error unmarshaling secrets: %w", err)
}
return secrets, nil
}
func GetJwtSecret(filePath string) []byte {
secrets, err := ReadSecrets(filePath)
if err != nil {
log.Fatal("can't read secrets: ", err)
}
if secrets.JwtSecret == nil {
log.Fatal("wt secret is nil")
}
return secrets.JwtSecret
}

View File

@@ -2,7 +2,6 @@ package core
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
@@ -125,7 +124,10 @@ func (d *SqliteDB) Renew(id CustomerId) error {
if err != nil { if err != nil {
return err return err
} }
setXor, attrXor := customer.RenewKeys() setXor, attrXor, err := customer.RenewKeys()
if err != nil {
return err
}
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx // TODO: replace with tx
renewQuery := ` renewQuery := `
@@ -208,9 +210,13 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
}() }()
selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?` selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?`
rows, err := tx.Query(selectCustomer, uuid.UUID(id)) rows, err := tx.Query(selectCustomer, uuid.UUID(id))
if err != nil {
return nil, err
}
if !rows.Next() { if !rows.Next() {
return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err())) log.Printf("no new row for customer %s with err %s", id, rows.Err())
return nil, ErrCustomerDne
} }
var maxNKodeLen int var maxNKodeLen int
@@ -225,10 +231,6 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for customer %s", id))
}
customer := Customer{ customer := Customer{
Id: id, Id: id,
NKodePolicy: NKodePolicy{ NKodePolicy: NKodePolicy{
@@ -241,9 +243,8 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
}, },
Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues), Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues),
} }
err = tx.Commit() if err = tx.Commit(); err != nil {
if err != nil { return nil, err
return nil, fmt.Errorf("read customer won't commit %w", err)
} }
return &customer, nil return &customer, nil
} }
@@ -278,9 +279,6 @@ WHERE user.username = ? AND user.customer_id = ?
var svgIdInterface []byte var svgIdInterface []byte
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId))
}
userId, err := uuid.Parse(id) userId, err := uuid.Parse(id)
if err != nil { if err != nil {
@@ -324,8 +322,7 @@ WHERE user.username = ? AND user.customer_id = ?
} }
user.Interface.Kp = &user.Kp user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp
err = tx.Commit() if err = tx.Commit(); err != nil {
if err != nil {
return nil, err return nil, err
} }
return &user, nil return &user, nil
@@ -360,15 +357,14 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
return nil, err return nil, err
} }
if !rows.Next() { if !rows.Next() {
return nil, errors.New(fmt.Sprintf("id not found: %d", id)) log.Printf("id not found: %d", id)
return nil, ErrSvgDne
} }
err = rows.Scan(&svgs[idx]) if err = rows.Scan(&svgs[idx]); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
err = tx.Commit() if err = tx.Commit(); err != nil {
if err != nil {
return nil, err return nil, err
} }
return svgs, nil return svgs, nil
@@ -383,16 +379,14 @@ func (d *SqliteDB) writeToDb(query string, args []any) error {
if err != nil { if err != nil {
err = tx.Rollback() err = tx.Rollback()
if err != nil { if err != nil {
log.Fatal(fmt.Sprintf("Write won't roll back %+v", err)) log.Fatalf("fatal error: write won't roll back %+v", err)
} }
} }
}() }()
_, err = tx.Exec(query, args...) if _, err = tx.Exec(query, args...); err != nil {
if err != nil {
return err return err
} }
err = tx.Commit() if err = tx.Commit(); err != nil {
if err != nil {
return err return err
} }
return nil return nil
@@ -400,7 +394,7 @@ func (d *SqliteDB) writeToDb(query string, args []any) error {
func (d *SqliteDB) addWriteTx(query string, args []any) error { func (d *SqliteDB) addWriteTx(query string, args []any) error {
if d.stop { if d.stop {
return errors.New("stopping database") return ErrStoppingDatabase
} }
errChan := make(chan error) errChan := make(chan error)
writeTx := WriteTx{ writeTx := WriteTx{
@@ -416,31 +410,37 @@ func (d *SqliteDB) addWriteTx(query string, args []any) error {
func (d *SqliteDB) getRandomIds(count int) ([]int, error) { func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
tx, err := d.db.Begin() tx, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err log.Print(err)
return nil, ErrSqliteTx
} }
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
if err != nil { if err != nil {
return nil, err log.Print(err)
return nil, ErrSqliteTx
} }
var tableLen int var tableLen int
if !rows.Next() { if !rows.Next() {
return nil, errors.New("empty svg_icon table") return nil, ErrEmptySvgTable
} }
if err = rows.Scan(&tableLen); err != nil { if err = rows.Scan(&tableLen); err != nil {
return nil, err log.Print(err)
return nil, ErrSqliteTx
} }
perm, err := util.RandomPermutation(tableLen) perm, err := util.RandomPermutation(tableLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for idx := range perm { for idx := range perm {
perm[idx] += 1 perm[idx] += 1
} }
if err = tx.Commit(); err != nil { if err = tx.Commit(); err != nil {
return nil, err log.Print(err)
return nil, ErrSqliteTx
} }
return perm[:count], nil return perm[:count], nil
} }

View File

@@ -9,9 +9,9 @@ import (
func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize KeypadDimension) ([]int, error) { func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize KeypadDimension) ([]int, error) {
selectedKeys := make([]int, len(passcodeIdxs)) selectedKeys := make([]int, len(passcodeIdxs))
for idx := range passcodeIdxs { for idx := range passcodeIdxs {
attrIdx := util.IndexOf[int](interfaceUser, passcodeIdxs[idx]) attrIdx, err := util.IndexOf[int](interfaceUser, passcodeIdxs[idx])
if attrIdx == -1 { if err != nil {
return nil, errors.New(fmt.Sprintf("index: %d out of range 0-%d", passcodeIdxs[idx], keypadSize.TotalAttrs()-1)) return nil, err
} }
keyNumb := attrIdx / keypadSize.AttrsPerKey keyNumb := attrIdx / keypadSize.AttrsPerKey
if keyNumb >= keypadSize.NumbOfKeys { if keyNumb >= keypadSize.NumbOfKeys {

View File

@@ -1,10 +1,9 @@
package core package core
import ( import (
"errors"
"github.com/google/uuid" "github.com/google/uuid"
"go-nkode/py-builtin"
"go-nkode/util" "go-nkode/util"
"log"
) )
type User struct { type User struct {
@@ -61,31 +60,27 @@ func (u *User) GetLoginInterface() ([]int, error) {
return u.Interface.IdxInterface, nil return u.Interface.IdxInterface, nil
} }
var KeyIndexOutOfRange = errors.New("one or more keys is out of range")
func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, error) { func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, error) {
validKeys := py_builtin.All[int](selectedKeys, func(idx int) bool { if validKeys := user.Kp.ValidKeySelections(selectedKeys); !validKeys {
return 0 <= idx && idx < user.Kp.NumbOfKeys
}) return nil, ErrKeyIndexOutOfRange
if !validKeys {
panic(KeyIndexOutOfRange)
} }
var err error
passcodeLen := len(selectedKeys) passcodeLen := len(selectedKeys)
err = customer.NKodePolicy.ValidLength(passcodeLen) if err := customer.NKodePolicy.ValidLength(passcodeLen); err != nil {
if err != nil {
return nil, err return nil, err
} }
setVals, err := customer.Attributes.SetValsForKp(user.Kp) setVals, err := customer.Attributes.SetValsForKp(user.Kp)
if err != nil { if err != nil {
return nil, err log.Printf("fatal error in validate key entry;invalid user keypad dimensions for user %s with error %v", user.Email, err)
return nil, ErrInternalValidKeyEntry
} }
passcodeSetVals, err := user.DecipherMask(setVals, passcodeLen) passcodeSetVals, err := user.DecipherMask(setVals, passcodeLen)
if err != nil { if err != nil {
return nil, err log.Printf("fatal error in validate key entry;something when wrong deciphering mask;user email %s; error %v", user.Email, err)
return nil, ErrInternalValidKeyEntry
} }
presumedAttrIdxVals := make([]int, passcodeLen) presumedAttrIdxVals := make([]int, passcodeLen)
@@ -93,11 +88,13 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err
keyNumb := selectedKeys[idx] keyNumb := selectedKeys[idx]
setIdx, err := customer.Attributes.IndexOfSet(passcodeSetVals[idx]) setIdx, err := customer.Attributes.IndexOfSet(passcodeSetVals[idx])
if err != nil { if err != nil {
return nil, err log.Printf("fatal error in validate key entry;something when wrong getting the IndexOfSet;user email %s; error %v", user.Email, err)
return nil, ErrInternalValidKeyEntry
} }
selectedAttrIdx, err := user.Interface.GetAttrIdxByKeyNumbSetIdx(setIdx, keyNumb) selectedAttrIdx, err := user.Interface.GetAttrIdxByKeyNumbSetIdx(setIdx, keyNumb)
if err != nil { if err != nil {
return nil, err log.Printf("fatal error in validate key entry;something when wrong getting the GetAttrIdxByKeyNumbSetIdx;user email %s; error %v", user.Email, err)
return nil, ErrInternalValidKeyEntry
} }
presumedAttrIdxVals[idx] = selectedAttrIdx presumedAttrIdxVals[idx] = selectedAttrIdx
} }

View File

@@ -2,7 +2,6 @@ package core
import ( import (
"crypto/sha256" "crypto/sha256"
"errors"
"go-nkode/util" "go-nkode/util"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -49,7 +48,7 @@ func NewUserCipherKeys(kp *KeypadDimension, setVals []uint64, maxNKodeLen int) (
func (u *UserCipherKeys) PadUserMask(userMask []uint64, setVals []uint64) ([]uint64, error) { func (u *UserCipherKeys) PadUserMask(userMask []uint64, setVals []uint64) ([]uint64, error) {
if len(userMask) > u.MaxNKodeLen { if len(userMask) > u.MaxNKodeLen {
return nil, errors.New("user mask length exceeds max nkode length") return nil, ErrUserMaskTooLong
} }
paddedUserMask := make([]uint64, len(userMask)) paddedUserMask := make([]uint64, len(userMask))
copy(paddedUserMask, userMask) copy(paddedUserMask, userMask)
@@ -153,7 +152,10 @@ func (u *UserCipherKeys) DecipherMask(mask string, setVals []uint64, passcodeLen
passcodeSet := make([]uint64, passcodeLen) passcodeSet := make([]uint64, passcodeLen)
for idx, setCipher := range decipheredMask[:passcodeLen] { for idx, setCipher := range decipheredMask[:passcodeLen] {
setIdx := util.IndexOf(setKeyRandComponent, setCipher) setIdx, err := util.IndexOf(setKeyRandComponent, setCipher)
if err != nil {
return nil, err
}
passcodeSet[idx] = setVals[setIdx] passcodeSet[idx] = setVals[setIdx]
} }
return passcodeSet, nil return passcodeSet, nil
@@ -175,6 +177,9 @@ func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs Cust
} }
} }
mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.Kp) mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.Kp)
if err != nil {
return nil, err
}
encipheredCode := EncipheredNKode{ encipheredCode := EncipheredNKode{
Code: code, Code: code,
Mask: mask, Mask: mask,

View File

@@ -1,10 +1,9 @@
package core package core
import ( import (
"errors"
"fmt"
"go-nkode/hashset" "go-nkode/hashset"
"go-nkode/util" "go-nkode/util"
"log"
) )
type UserInterface struct { type UserInterface struct {
@@ -70,7 +69,7 @@ func (u *UserInterface) SetViewMatrix() ([][]int, error) {
func (u *UserInterface) DisperseInterface() error { func (u *UserInterface) DisperseInterface() error {
if !u.Kp.IsDispersable() { if !u.Kp.IsDispersable() {
panic("interface is not dispersable") return ErrInterfaceNotDispersible
} }
err := u.shuffleKeys() err := u.shuffleKeys()
@@ -180,11 +179,13 @@ func (u *UserInterface) PartialInterfaceShuffle() error {
func (u *UserInterface) GetAttrIdxByKeyNumbSetIdx(setIdx int, keyNumb int) (int, error) { func (u *UserInterface) GetAttrIdxByKeyNumbSetIdx(setIdx int, keyNumb int) (int, error) {
if keyNumb < 0 || u.Kp.NumbOfKeys <= keyNumb { if keyNumb < 0 || u.Kp.NumbOfKeys <= keyNumb {
return -1, errors.New(fmt.Sprintf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys)) log.Printf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys)
return -1, ErrKeyIndexOutOfRange
} }
if setIdx < 0 || u.Kp.AttrsPerKey <= setIdx { if setIdx < 0 || u.Kp.AttrsPerKey <= setIdx {
return -1, errors.New(fmt.Sprintf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey)) log.Printf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey)
return -1, ErrAttributeIndexOutOfRange
} }
keypadView, err := u.InterfaceMatrix() keypadView, err := u.InterfaceMatrix()
if err != nil { if err != nil {

View File

@@ -1,12 +1,11 @@
package core package core
import ( import (
"errors"
"fmt"
"github.com/google/uuid" "github.com/google/uuid"
"go-nkode/hashset" "go-nkode/hashset"
py "go-nkode/py-builtin" py "go-nkode/py-builtin"
"go-nkode/util" "go-nkode/util"
"log"
) )
type UserSignSession struct { type UserSignSession struct {
@@ -52,27 +51,33 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e
}) })
if !validEntry { if !validEntry {
return nil, errors.New(fmt.Sprintf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys)) log.Printf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys)
return nil, ErrKeyIndexOutOfRange
} }
if s.SetIdxInterface == nil { if s.SetIdxInterface == nil {
return nil, errors.New("signup session set interface is nil") log.Print("signup session set interface is nil")
return nil, ErrIncompleteUserSignupSession
} }
if s.ConfirmIdxInterface == nil { if s.ConfirmIdxInterface == nil {
return nil, errors.New("signup session confirm interface is nil") log.Print("signup session confirm interface is nil")
return nil, ErrIncompleteUserSignupSession
} }
if s.SetKeySelection == nil { if s.SetKeySelection == nil {
return nil, errors.New("signup session set key entry is nil") log.Print("signup session set key entry is nil")
return nil, ErrIncompleteUserSignupSession
} }
if s.UserEmail == "" { if s.UserEmail == "" {
return nil, errors.New("signup session username is nil") log.Print("signup session username is nil")
return nil, ErrIncompleteUserSignupSession
} }
if len(confirmKeyEntry) != len(s.SetKeySelection) { if len(confirmKeyEntry) != len(s.SetKeySelection) {
return nil, errors.New(fmt.Sprintf("confirm and set key entry lenght mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection))) log.Printf("confirm and set key entry length mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection))
return nil, ErrSetConfirmSignupMismatch
} }
passcodeLen := len(confirmKeyEntry) passcodeLen := len(confirmKeyEntry)
@@ -88,10 +93,12 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e
confirmKey := hashset.NewSetFromSlice[int](confirmKeyVals[idx]) confirmKey := hashset.NewSetFromSlice[int](confirmKeyVals[idx])
intersection := setKey.Intersect(confirmKey) intersection := setKey.Intersect(confirmKey)
if intersection.Size() < 1 { if intersection.Size() < 1 {
return nil, errors.New(fmt.Sprintf("set and confirm do not intersect at index %d", idx)) log.Printf("set and confirm do not intersect at index %d", idx)
return nil, ErrSetConfirmSignupMismatch
} }
if intersection.Size() > 1 { if intersection.Size() > 1 {
return nil, errors.New(fmt.Sprintf("set and confirm intersect at more than one point at index %d", idx)) log.Printf("set and confirm intersect at more than one point at index %d", idx)
return nil, ErrSetConfirmSignupMismatch
} }
intersectionSlice := intersection.ToSlice() intersectionSlice := intersection.ToSlice()
passcode[idx] = intersectionSlice[0] passcode[idx] = intersectionSlice[0]
@@ -104,7 +111,8 @@ func (s *UserSignSession) SetUserNKode(keySelection KeySelection) (IdxInterface,
return 0 <= i && i < s.Kp.NumbOfKeys return 0 <= i && i < s.Kp.NumbOfKeys
}) })
if !validKeySelection { if !validKeySelection {
return nil, errors.New(fmt.Sprintf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1)) log.Printf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1)
return nil, ErrKeyIndexOutOfRange
} }
s.SetKeySelection = keySelection s.SetKeySelection = keySelection
@@ -134,7 +142,7 @@ func (s *UserSignSession) getSelectedKeyVals(keySelections KeySelection, userInt
func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, error) { func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, error) {
if kp.IsDispersable() { if kp.IsDispersable() {
return nil, errors.New("keypad is dispersable, can't use signupInterface") return nil, ErrKeypadIsNotDispersible
} }
err := baseUserInterface.RandomShuffle() err := baseUserInterface.RandomShuffle()
if err != nil { if err != nil {

View File

@@ -78,9 +78,9 @@ func TestApi(t *testing.T) {
var jwtTokens core.AuthenticationTokens var jwtTokens core.AuthenticationTokens
testApiPost(t, base+core.Login, loginBody, &jwtTokens) testApiPost(t, base+core.Login, loginBody, &jwtTokens)
refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken) refreshClaims, err := core.ParseRegisteredClaimToken(jwtTokens.RefreshToken)
assert.Equal(t, refreshClaims.Subject, userEmail) assert.Equal(t, refreshClaims.Subject, userEmail)
accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken) accessClaims, err := core.ParseRegisteredClaimToken(jwtTokens.AccessToken)
assert.Equal(t, accessClaims.Subject, userEmail) assert.Equal(t, accessClaims.Subject, userEmail)
renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId}
testApiPost(t, base+core.RenewAttributes, renewBody, nil) testApiPost(t, base+core.RenewAttributes, renewBody, nil)
@@ -102,7 +102,7 @@ func TestApi(t *testing.T) {
var refreshTokenResp core.RefreshTokenResp var refreshTokenResp core.RefreshTokenResp
testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken) testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken)
accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken) accessClaims, err = core.ParseRegisteredClaimToken(refreshTokenResp.AccessToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessClaims.Subject, userEmail) assert.Equal(t, accessClaims.Subject, userEmail)
} }

View File

@@ -6,8 +6,8 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"go-nkode/hashset" "go-nkode/hashset"
"log"
"math/big" "math/big"
r "math/rand" r "math/rand"
"time" "time"
@@ -17,12 +17,25 @@ type ShuffleTypes interface {
[]int | int | []uint64 | uint64 []int | int | []uint64 | uint64
} }
// fisherYatesShuffle shuffles a slice of bytes in place using the Fisher-Yates algorithm. var (
ErrFisherYatesShuffle = errors.New("unable to shuffle array")
ErrRandomBytes = errors.New("random bytes error")
ErrRandNonRepeatingUint64 = errors.New("list length must be less than 2^32")
ErrParseHexString = errors.New("parse hex string error")
ErrMatrixTranspose = errors.New("matrix cannot be nil or empty")
ErrListToMatrixNotDivisible = errors.New("list to matrix not possible")
ErrElementNotInArray = errors.New("element not in array")
ErrDecodeBase64Str = errors.New("decode base64 err")
ErrRandNonRepeatingInt = errors.New("list length must be less than 2^31")
ErrXorLengthMismatch = errors.New("xor length mismatch")
)
func fisherYatesShuffle[T ShuffleTypes](b *[]T) error { func fisherYatesShuffle[T ShuffleTypes](b *[]T) error {
for i := len(*b) - 1; i > 0; i-- { for i := len(*b) - 1; i > 0; i-- {
bigJ, err := rand.Int(rand.Reader, big.NewInt(int64(i+1))) bigJ, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
if err != nil { if err != nil {
return err log.Print("fisher yates shuffle error: ", err)
return ErrFisherYatesShuffle
} }
j := bigJ.Int64() j := bigJ.Int64()
(*b)[i], (*b)[j] = (*b)[j], (*b)[i] (*b)[i], (*b)[j] = (*b)[j], (*b)[i]
@@ -38,7 +51,8 @@ func RandomBytes(n int) ([]byte, error) {
b := make([]byte, n) b := make([]byte, n)
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
return nil, err log.Print("error in random bytes: ", err)
return nil, ErrRandomBytes
} }
return b, nil return b, nil
} }
@@ -72,7 +86,7 @@ func GenerateRandomInt() (int, error) {
func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) { func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) {
if listLen > int(1)<<32 { if listLen > int(1)<<32 {
return nil, errors.New("list length must be less than 2^32") return nil, ErrRandNonRepeatingUint64
} }
listSet := make(hashset.Set[uint64]) listSet := make(hashset.Set[uint64])
for { for {
@@ -92,7 +106,7 @@ func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) {
func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) { func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) {
if listLen > int(1)<<31 { if listLen > int(1)<<31 {
return nil, errors.New("list length must be less than 2^31") return nil, ErrRandNonRepeatingInt
} }
listSet := make(hashset.Set[int]) listSet := make(hashset.Set[int])
for { for {
@@ -112,7 +126,8 @@ func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) {
func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) { func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) {
if len(l0) != len(l1) { if len(l0) != len(l1) {
return nil, errors.New(fmt.Sprintf("list len mismatch %d, %d", len(l0), len(l1))) log.Printf("list len mismatch %d, %d", len(l0), len(l1))
return nil, ErrXorLengthMismatch
} }
xorList := make([]uint64, len(l0)) xorList := make([]uint64, len(l0))
@@ -131,7 +146,8 @@ func EncodeBase64Str(data []uint64) string {
func DecodeBase64Str(encoded string) ([]uint64, error) { func DecodeBase64Str(encoded string) ([]uint64, error) {
dataBytes, err := base64.StdEncoding.DecodeString(encoded) dataBytes, err := base64.StdEncoding.DecodeString(encoded)
if err != nil { if err != nil {
return nil, err log.Print("error decoding base64 str: ", err)
return nil, ErrDecodeBase64Str
} }
data := ByteArrToUint64Arr(dataBytes) data := ByteArrToUint64Arr(dataBytes)
return data, nil return data, nil
@@ -179,13 +195,13 @@ func ByteArrToIntArr(byteArr []byte) []int {
return intArr return intArr
} }
func IndexOf[T uint64 | int](arr []T, el T) int { func IndexOf[T uint64 | int](arr []T, el T) (int, error) {
for idx, val := range arr { for idx, val := range arr {
if val == el { if val == el {
return idx return idx, nil
} }
} }
return -1 return -1, ErrElementNotInArray
} }
func IdentityArray(arrLen int) []int { func IdentityArray(arrLen int) []int {
@@ -199,7 +215,8 @@ func IdentityArray(arrLen int) []int {
func ListToMatrix(listArr []int, numbCols int) ([][]int, error) { func ListToMatrix(listArr []int, numbCols int) ([][]int, error) {
if len(listArr)%numbCols != 0 { if len(listArr)%numbCols != 0 {
panic(fmt.Sprintf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols)) log.Printf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols)
return nil, ErrListToMatrixNotDivisible
} }
numbRows := len(listArr) / numbCols numbRows := len(listArr) / numbCols
matrix := make([][]int, numbRows) matrix := make([][]int, numbRows)
@@ -213,7 +230,8 @@ func ListToMatrix(listArr []int, numbCols int) ([][]int, error) {
func MatrixTranspose(matrix [][]int) ([][]int, error) { func MatrixTranspose(matrix [][]int) ([][]int, error) {
if matrix == nil || len(matrix) == 0 { if matrix == nil || len(matrix) == 0 {
return nil, fmt.Errorf("matrix cannot be nil or empty") log.Print("can't transpose nil or zero len matrix")
return nil, ErrMatrixTranspose
} }
rows := len(matrix) rows := len(matrix)
@@ -222,7 +240,8 @@ func MatrixTranspose(matrix [][]int) ([][]int, error) {
// Check if the matrix is not rectangular // Check if the matrix is not rectangular
for _, row := range matrix { for _, row := range matrix {
if len(row) != cols { if len(row) != cols {
return nil, fmt.Errorf("all rows must have the same number of columns") log.Print("all rows must have the same number of columns")
return nil, ErrMatrixTranspose
} }
} }
@@ -267,7 +286,8 @@ func ParseHexString(hexStr string) ([]byte, error) {
// Decode the hex string into bytes // Decode the hex string into bytes
bytes, err := hex.DecodeString(hexStr) bytes, err := hex.DecodeString(hexStr)
if err != nil { if err != nil {
return nil, err log.Print("parse hex string err: ", err)
return nil, ErrParseHexString
} }
return bytes, nil return bytes, nil
} }