diff --git a/core/constants.go b/core/constants.go new file mode 100644 index 0000000..4de8ea8 --- /dev/null +++ b/core/constants.go @@ -0,0 +1,62 @@ +package core + +import ( + "errors" + "net/http" +) + +var ( + ErrInvalidNKodeLength = errors.New("invalid nKode length") + ErrInvalidNKodeIdx = errors.New("invalid passcode attribute index") + ErrTooFewDistinctSet = errors.New("too few distinct sets") + ErrTooFewDistinctAttributes = errors.New("too few distinct attributes") + ErrEmailAlreadySent = errors.New("email already sent") + ErrClaimExpOrNil = errors.New("claim expired or nil") + ErrInvalidJwt = errors.New("invalid jwt") + ErrInvalidKeypadDimensions = errors.New("keypad dimensions out of range") + ErrUserAlreadyExists = errors.New("user already exists") + ErrSignupSessionDNE = errors.New("signup session does not exist") + ErrUserForCustomerDNE = errors.New("user for customer does not exist") + ErrRefreshTokenInvalid = errors.New("refresh token invalid") + ErrCustomerDne = errors.New("customer dne") + ErrSvgDne = errors.New("svg dne") + ErrStoppingDatabase = errors.New("stopping database") + ErrSqliteTx = errors.New("sqlite begin, exec, query, or commit error. see logs") + ErrEmptySvgTable = errors.New("empty svg_icon table") + ErrKeyIndexOutOfRange = errors.New("one or more keys is out of range") + ErrAttributeIndexOutOfRange = errors.New("attribute index out of range") + ErrInternalValidKeyEntry = errors.New("internal validation error") + ErrUserMaskTooLong = errors.New("user mask length exceeds max nkode length") + ErrInterfaceNotDispersible = errors.New("interface is not dispersible") + ErrIncompleteUserSignupSession = errors.New("incomplete user signup session") + ErrSetConfirmSignupMismatch = errors.New("set and confirm nkode are not the same") + ErrKeypadIsNotDispersible = errors.New("keypad is not dispersible") +) + +var HttpErrMap = map[error]int{ + ErrInvalidNKodeLength: http.StatusBadRequest, + ErrInvalidNKodeIdx: http.StatusBadRequest, + ErrTooFewDistinctSet: http.StatusBadRequest, + ErrTooFewDistinctAttributes: http.StatusBadRequest, + ErrEmailAlreadySent: http.StatusBadRequest, + ErrClaimExpOrNil: http.StatusForbidden, + ErrInvalidJwt: http.StatusForbidden, + ErrInvalidKeypadDimensions: http.StatusBadRequest, + ErrUserAlreadyExists: http.StatusBadRequest, + ErrSignupSessionDNE: http.StatusBadRequest, + ErrUserForCustomerDNE: http.StatusBadRequest, + ErrRefreshTokenInvalid: http.StatusForbidden, + ErrCustomerDne: http.StatusBadRequest, + ErrSvgDne: http.StatusBadRequest, + ErrStoppingDatabase: http.StatusInternalServerError, + ErrSqliteTx: http.StatusInternalServerError, + ErrEmptySvgTable: http.StatusInternalServerError, + ErrKeyIndexOutOfRange: http.StatusBadRequest, + ErrAttributeIndexOutOfRange: http.StatusInternalServerError, + ErrInternalValidKeyEntry: http.StatusInternalServerError, + ErrUserMaskTooLong: http.StatusInternalServerError, + ErrInterfaceNotDispersible: http.StatusInternalServerError, + ErrIncompleteUserSignupSession: http.StatusBadRequest, + ErrSetConfirmSignupMismatch: http.StatusBadRequest, + ErrKeypadIsNotDispersible: http.StatusInternalServerError, +} diff --git a/core/customer.go b/core/customer.go index 005f63d..2cc3c28 100644 --- a/core/customer.go +++ b/core/customer.go @@ -1,11 +1,8 @@ package core import ( - "errors" - "fmt" "github.com/google/uuid" "go-nkode/hashset" - py "go-nkode/py-builtin" "go-nkode/util" ) @@ -31,16 +28,12 @@ func NewCustomer(nkodePolicy NKodePolicy) (*Customer, error) { func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error { nkodeLen := len(passcodeAttrIdx) - if nkodeLen < c.NKodePolicy.MinNkodeLen { - return errors.New(fmt.Sprintf("NKode length %d is too short. Minimum nKode length is %d", nkodeLen, c.NKodePolicy.MinNkodeLen)) + if nkodeLen < c.NKodePolicy.MinNkodeLen || nkodeLen > c.NKodePolicy.MaxNkodeLen { + return ErrInvalidNKodeLength } - validIdx := py.All[int](passcodeAttrIdx, func(i int) bool { - return i >= 0 && i < kp.TotalAttrs() - }) - - if !validIdx { - return errors.New(fmt.Sprintf("One or more idx out of range 0-%d in IsValidNKode", kp.TotalAttrs()-1)) + if validIdx := kp.ValidateAttributeIndices(passcodeAttrIdx); !validIdx { + return ErrInvalidNKodeIdx } passcodeSetVals := make(hashset.Set[uint64]) passcodeAttrVals := make(hashset.Set[uint64]) @@ -59,33 +52,32 @@ func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error } if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets { - return errors.New(fmt.Sprintf("passcode has two few distinct sets min %d, has %d", c.NKodePolicy.DistinctSets, passcodeSetVals.Size())) + return ErrTooFewDistinctSet } if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes { - return errors.New(fmt.Sprintf("passcode has two few distinct attributes min %d, has %d", c.NKodePolicy.DistinctAttributes, passcodeAttrVals.Size())) + return ErrTooFewDistinctAttributes } return nil } -func (c *Customer) RenewKeys() ([]uint64, []uint64) { +func (c *Customer) RenewKeys() ([]uint64, []uint64, error) { oldAttrs := make([]uint64, len(c.Attributes.AttrVals)) oldSets := make([]uint64, len(c.Attributes.SetVals)) copy(oldAttrs, c.Attributes.AttrVals) copy(oldSets, c.Attributes.SetVals) - err := c.Attributes.Renew() - if err != nil { - panic(err) + if err := c.Attributes.Renew(); err != nil { + return nil, nil, err } attrsXor, err := util.XorLists(oldAttrs, c.Attributes.AttrVals) if err != nil { - panic(err) + return nil, nil, err } setXor, err := util.XorLists(oldSets, c.Attributes.SetVals) if err != nil { - panic(err) + return nil, nil, err } - return setXor, attrsXor + return setXor, attrsXor, nil } diff --git a/core/customer_attributes.go b/core/customer_attributes.go index 518e2ce..e25b5d2 100644 --- a/core/customer_attributes.go +++ b/core/customer_attributes.go @@ -1,9 +1,8 @@ package core import ( - "errors" - "fmt" "go-nkode/util" + "log" ) type CustomerAttributes struct { @@ -12,13 +11,15 @@ type CustomerAttributes struct { } func NewCustomerAttributes() (*CustomerAttributes, error) { - attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) - if errAttr != nil { - return nil, errAttr + attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) + if err != nil { + log.Print("unable to generate attribute vals: ", err) + return nil, err } - setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) - if errSet != nil { - return nil, errSet + setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) + if err != nil { + log.Print("unable to generate set vals: ", err) + return nil, err } customerAttrs := CustomerAttributes{ @@ -36,37 +37,33 @@ func NewCustomerAttributesFromBytes(attrBytes []byte, setBytes []byte) CustomerA } func (c *CustomerAttributes) Renew() error { - attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) - if errAttr != nil { - return errAttr + attrVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) + if err != nil { + return err } - setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) - if errSet != nil { - return errSet + setVals, err := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) + if err != nil { + return err } c.AttrVals = attrVals c.SetVals = setVals return nil } -func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) int { +func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) (int, error) { // TODO: should this be mapped instead? return util.IndexOf[uint64](c.AttrVals, attrVal) } func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) { // TODO: should this be mapped instead? - idx := util.IndexOf[uint64](c.SetVals, setVal) - if idx == -1 { - return -1, errors.New(fmt.Sprintf("Set Val %d is invalid", setVal)) - } - return idx, nil + return util.IndexOf[uint64](c.SetVals, setVal) } func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad KeypadDimension) (uint64, error) { - indexOfAttr := c.IndexOfAttr(attrVal) - if indexOfAttr == -1 { - return 0, errors.New(fmt.Sprintf("No attribute %d", attrVal)) + indexOfAttr, err := c.IndexOfAttr(attrVal) + if err != nil { + return 0, err } setIdx := indexOfAttr % userKeypad.AttrsPerKey return c.SetVals[setIdx], nil diff --git a/core/email_queue.go b/core/email_queue.go index 57fc116..76a2765 100644 --- a/core/email_queue.go +++ b/core/email_queue.go @@ -2,7 +2,6 @@ package core import ( "context" - "errors" "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -51,9 +50,9 @@ func NewSESClient() SESClient { } func (s *SESClient) SendEmail(email Email) error { - if _, exists := s.ResetCache.Get(email.Recipient); exists { - return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject) + log.Printf("email already sent to %s with subject %s", email.Recipient, email.Subject) + return ErrEmailAlreadySent } // Load AWS configuration @@ -61,7 +60,7 @@ func (s *SESClient) SendEmail(email Email) error { if err != nil { errMsg := fmt.Sprintf("unable to load SDK config, %v", err) log.Print(errMsg) - return errors.New(errMsg) + return err } // Create an SES client @@ -93,7 +92,9 @@ func (s *SESClient) SendEmail(email Email) error { resp, err := sesClient.SendEmail(context.TODO(), input) if err != nil { s.ResetCache.Delete(email.Recipient) - return fmt.Errorf("failed to send email, %v", err) + errMsg := fmt.Sprintf("failed to send email, %v", err) + log.Print(errMsg) + return err } // Output the message ID of the sent email diff --git a/core/in_memory_db.go b/core/in_memory_db.go index 1f71989..6d1c8f9 100644 --- a/core/in_memory_db.go +++ b/core/in_memory_db.go @@ -89,9 +89,11 @@ func (db *InMemoryDb) Renew(id CustomerId) error { if !exists { return errors.New(fmt.Sprintf("customer %s does not exist", id)) } - setXor, attrsXor := customer.RenewKeys() + setXor, attrsXor, err := customer.RenewKeys() + if err != nil { + return err + } db.Customers[id] = customer - var err error for _, user := range db.Users { if user.CustomerId == id { err = user.RenewKeys(setXor, attrsXor) diff --git a/core/jwt_claims.go b/core/jwt_claims.go index 3e38be9..dc43d2e 100644 --- a/core/jwt_claims.go +++ b/core/jwt_claims.go @@ -1,8 +1,6 @@ package core import ( - "errors" - "fmt" "github.com/golang-jwt/jwt/v5" "go-nkode/util" "log" @@ -78,42 +76,37 @@ func EncodeAndSignClaims(claims jwt.Claims) (string, error) { return token.SignedString(secret) } -func ParseRefreshToken(refreshToken string) (*jwt.RegisteredClaims, error) { - token, err := jwt.ParseWithClaims(refreshToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - return secret, nil - }) - if err != nil { - return nil, fmt.Errorf("error parsing refresh token: %w", err) - } - claims, ok := token.Claims.(*jwt.RegisteredClaims) - if !ok { - return nil, errors.New("unable to parse claims") - } - return claims, nil +func ParseRegisteredClaimToken(token string) (*jwt.RegisteredClaims, error) { + return parseJwt[*jwt.RegisteredClaims](token, &jwt.RegisteredClaims{}) } -func ParseAccessToken(accessToken string) (*jwt.RegisteredClaims, error) { - token, err := jwt.ParseWithClaims(accessToken, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { +func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) { + return parseJwt[*ResetNKodeClaims](resetNKodeToken, &ResetNKodeClaims{}) +} + +func parseJwt[T *ResetNKodeClaims | *jwt.RegisteredClaims](tokenStr string, claim jwt.Claims) (T, error) { + token, err := jwt.ParseWithClaims(tokenStr, claim, func(token *jwt.Token) (interface{}, error) { return secret, nil }) if err != nil { - return nil, fmt.Errorf("error parsing refresh token: %w", err) + log.Printf("error parsing refresh token: %v", err) + return nil, ErrInvalidJwt } - claims, ok := token.Claims.(*jwt.RegisteredClaims) + claims, ok := token.Claims.(T) if !ok { - return nil, errors.New("unable to parse claims") + return nil, ErrInvalidJwt } return claims, nil } func ClaimExpired(claims jwt.RegisteredClaims) error { if claims.ExpiresAt == nil { - return errors.New("claim exp is nil") + return ErrClaimExpOrNil } if claims.ExpiresAt.Time.After(time.Now()) { return nil } - return errors.New("claim expired") + return ErrClaimExpOrNil } func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error) { @@ -127,17 +120,3 @@ func ResetNKodeToken(userEmail UserEmail, customerId CustomerId) (string, error) } return EncodeAndSignClaims(resetClaims) } - -func ParseRestNKodeToken(resetNKodeToken string) (*ResetNKodeClaims, error) { - token, err := jwt.ParseWithClaims(resetNKodeToken, &ResetNKodeClaims{}, func(token *jwt.Token) (interface{}, error) { - return secret, nil - }) - if err != nil { - return nil, fmt.Errorf("error parsing refresh token: %w", err) - } - claims, ok := token.Claims.(*ResetNKodeClaims) - if !ok { - return nil, errors.New("unable to parse claims") - } - return claims, nil -} diff --git a/core/jwt_claims_test.go b/core/jwt_claims_test.go index 76ed0ba..e9aec63 100644 --- a/core/jwt_claims_test.go +++ b/core/jwt_claims_test.go @@ -11,11 +11,11 @@ func TestJwtClaims(t *testing.T) { customerId := CustomerId(uuid.New()) authTokens, err := NewAuthenticationTokens(email, customerId) assert.NoError(t, err) - accessToken, err := ParseAccessToken(authTokens.AccessToken) + accessToken, err := ParseRegisteredClaimToken(authTokens.AccessToken) assert.NoError(t, err) assert.Equal(t, accessToken.Subject, email) assert.NoError(t, ClaimExpired(*accessToken)) - refreshToken, err := ParseRefreshToken(authTokens.RefreshToken) + refreshToken, err := ParseRegisteredClaimToken(authTokens.RefreshToken) assert.NoError(t, err) assert.Equal(t, refreshToken.Subject, email) assert.NoError(t, ClaimExpired(*refreshToken)) diff --git a/core/keypad_dimension.go b/core/keypad_dimension.go index 964319e..fbc6b84 100644 --- a/core/keypad_dimension.go +++ b/core/keypad_dimension.go @@ -1,6 +1,8 @@ package core -import "errors" +import ( + py "go-nkode/py-builtin" +) type KeypadDimension struct { AttrsPerKey int `json:"attrs_per_key"` @@ -17,11 +19,23 @@ func (kp *KeypadDimension) IsDispersable() bool { func (kp *KeypadDimension) IsValidKeypadDimension() error { if KeypadMin.AttrsPerKey > kp.AttrsPerKey || KeypadMax.AttrsPerKey < kp.AttrsPerKey || KeypadMin.NumbOfKeys > kp.NumbOfKeys || KeypadMax.NumbOfKeys < kp.NumbOfKeys { - return errors.New("keypad dimensions out of range") + return ErrInvalidKeypadDimensions } return nil } +func (kp *KeypadDimension) ValidKeySelections(selectedKeys []int) bool { + return py.All[int](selectedKeys, func(idx int) bool { + return 0 <= idx && idx < kp.NumbOfKeys + }) +} + +func (kp *KeypadDimension) ValidateAttributeIndices(attrIndicies []int) bool { + return py.All[int](attrIndicies, func(i int) bool { + return i >= 0 && i < kp.TotalAttrs() + }) +} + var ( KeypadMax = KeypadDimension{ AttrsPerKey: 16, diff --git a/core/nkode_api.go b/core/nkode_api.go index 10e863d..177f725 100644 --- a/core/nkode_api.go +++ b/core/nkode_api.go @@ -1,9 +1,9 @@ package core import ( - "errors" "fmt" "github.com/google/uuid" + "log" "os" ) @@ -43,7 +43,8 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId return nil, err } if user != nil && !reset { - return nil, fmt.Errorf("user %s already exists", string(userEmail)) + log.Printf("user %s already exists", string(userEmail)) + return nil, ErrUserAlreadyExists } svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) if err != nil { @@ -74,7 +75,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec } session, exists := n.SignupSessions[sessionId] if !exists { - return nil, errors.New(fmt.Sprintf("session id does not exist %s", sessionId)) + log.Printf("session id does not exist %s", sessionId) + return nil, ErrSignupSessionDNE } confirmInterface, err := session.SetUserNKode(keySelection) if err != nil { @@ -87,7 +89,8 @@ func (n *NKodeAPI) SetNKode(customerId CustomerId, sessionId SessionId, keySelec func (n *NKodeAPI) ConfirmNKode(customerId CustomerId, sessionId SessionId, keySelection KeySelection) error { session, exists := n.SignupSessions[sessionId] if !exists { - return errors.New(fmt.Sprintf("session id does not exist %s", sessionId)) + log.Printf("session id does not exist %s", sessionId) + return ErrSignupSessionDNE } customer, err := n.Db.GetCustomer(customerId) if err != nil { @@ -120,7 +123,8 @@ func (n *NKodeAPI) GetLoginInterface(userEmail UserEmail, customerId CustomerId) return nil, err } if user == nil { - return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) + log.Printf("user %s for customer %s dne", userEmail, customerId) + return nil, ErrUserForCustomerDNE } err = user.Interface.PartialInterfaceShuffle() if err != nil { @@ -153,7 +157,8 @@ func (n *NKodeAPI) Login(customerId CustomerId, userEmail UserEmail, keySelectio return nil, err } if user == nil { - return nil, errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) + log.Printf("user %s for customer %s dne", userEmail, customerId) + return nil, ErrUserForCustomerDNE } passcode, err := ValidKeyEntry(*user, *customer, keySelection) if err != nil { @@ -195,12 +200,13 @@ func (n *NKodeAPI) RefreshToken(userEmail UserEmail, customerId CustomerId, refr return "", err } if user == nil { - return "", errors.New(fmt.Sprintf("user %s for customer %s dne", userEmail, customerId)) + log.Printf("user %s for customer %s dne", userEmail, customerId) + return "", ErrUserForCustomerDNE } if user.RefreshToken != refreshToken { - return "", errors.New("refresh token is invalid") + return "", ErrRefreshTokenInvalid } - refreshClaims, err := ParseRefreshToken(refreshToken) + refreshClaims, err := ParseRegisteredClaimToken(refreshToken) if err != nil { return "", err } @@ -222,7 +228,7 @@ func (n *NKodeAPI) ResetNKode(userEmail UserEmail, customerId CustomerId) error nkodeResetJwt, err := ResetNKodeToken(userEmail, customerId) if err != nil { - return errors.New(fmt.Sprintf("unable to load SDK config, %v", err)) + return err } frontendHost := os.Getenv("FRONTEND_HOST") if frontendHost == "" { diff --git a/core/nkode_handler.go b/core/nkode_handler.go index e7dc486..0f42edf 100644 --- a/core/nkode_handler.go +++ b/core/nkode_handler.go @@ -26,6 +26,12 @@ const ( ResetNKode = "/reset-nkode" ) +const ( + malformedCustomerId = "malformed customer id" + malformedUserEmail = "malformed user email" + malformedSessionId = "malformed session id" +) + func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case CreateNewCustomer: @@ -56,226 +62,148 @@ func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.Request) { + log.Print("create new customer") if r.Method != http.MethodPost { methodNotAllowed(w) return } var customerPost NewCustomerPost - err := decodeJson(w, r, &customerPost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &customerPost); err != nil { return } customerId, err := h.Api.CreateNewCustomer(customerPost.NKodePolicy, nil) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } respBody := CreateNewCustomerResp{ CustomerId: uuid.UUID(*customerId).String(), } - respBytes, err := json.Marshal(respBody) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, respBody) } func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter, r *http.Request) { + log.Print("signup/reset interface") if r.Method != http.MethodPost { methodNotAllowed(w) return } var signupResetPost GenerateSignupRestInterfacePost - err := decodeJson(w, r, &signupResetPost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &signupResetPost); err != nil { return } + kp := KeypadDimension{ AttrsPerKey: signupResetPost.AttrsPerKey, NumbOfKeys: signupResetPost.NumbOfKeys, } - err = kp.IsValidKeypadDimension() - if err != nil { - keypadSizeOutOfRange(w) - log.Println(err) + if err := kp.IsValidKeypadDimension(); err != nil { + badRequest(w, "invalid keypad dimensions") return } customerId, err := uuid.Parse(signupResetPost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } userEmail, err := ParseEmail(signupResetPost.UserEmail) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedUserEmail) return } resp, err := h.Api.GenerateSignupResetInterface(userEmail, CustomerId(customerId), kp, signupResetPost.Reset) if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - respBytes, err := json.Marshal(resp) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, resp) } func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { + log.Print("set nkode") if r.Method != http.MethodPost { methodNotAllowed(w) return } var setNKodePost SetNKodePost - err := decodeJson(w, r, &setNKodePost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &setNKodePost); err != nil { return } customerId, err := uuid.Parse(setNKodePost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } sessionId, err := uuid.Parse(setNKodePost.SessionId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedSessionId) return } confirmInterface, err := h.Api.SetNKode(CustomerId(customerId), SessionId(sessionId), setNKodePost.KeySelection) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } respBody := SetNKodeResp{UserInterface: confirmInterface} - - respBytes, err := json.Marshal(respBody) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, respBody) } func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Request) { + log.Print("confirm nkode") if r.Method != http.MethodPost { methodNotAllowed(w) return } var confirmNKodePost ConfirmNKodePost - err := decodeJson(w, r, &confirmNKodePost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &confirmNKodePost); err != nil { return } customerId, err := uuid.Parse(confirmNKodePost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } sessionId, err := uuid.Parse(confirmNKodePost.SessionId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedSessionId) return } - err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err = h.Api.ConfirmNKode(CustomerId(customerId), SessionId(sessionId), confirmNKodePost.KeySelection); err != nil { + handleError(w, err) return } - w.WriteHeader(http.StatusOK) - } func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { methodNotAllowed(w) return } var loginInterfacePost GetLoginInterfacePost - err := decodeJson(w, r, &loginInterfacePost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &loginInterfacePost); err != nil { return } customerId, err := uuid.Parse(loginInterfacePost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } userEmail, err := ParseEmail(loginInterfacePost.UserEmail) + if err != nil { + badRequest(w, malformedUserEmail) + } loginInterface, err := h.Api.GetLoginInterface(userEmail, CustomerId(customerId)) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } - respBytes, err := json.Marshal(loginInterface) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) - + marshalAndWriteBytes(w, loginInterface) } func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { @@ -284,40 +212,26 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { return } var loginPost LoginPost - err := decodeJson(w, r, &loginPost) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &loginPost); err != nil { return } customerId, err := uuid.Parse(loginPost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } userEmail, err := ParseEmail(loginPost.UserEmail) + if err != nil { + badRequest(w, malformedUserEmail) + return + } jwtTokens, err := h.Api.Login(CustomerId(customerId), userEmail, loginPost.KeySelection) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } - respBytes, err := json.Marshal(jwtTokens) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, jwtTokens) } func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Request) { @@ -326,23 +240,16 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req return } var renewAttributesPost RenewAttributesPost - err := decodeJson(w, r, &renewAttributesPost) - - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err := decodeJson(w, r, &renewAttributesPost); err != nil { return } customerId, err := uuid.Parse(renewAttributesPost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } - err = h.Api.RenewAttributes(CustomerId(customerId)) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) + if err = h.Api.RenewAttributes(CustomerId(customerId)); err != nil { + handleError(w, err) return } @@ -355,26 +262,11 @@ func (h *NKodeHandler) RandomSvgInterfaceHandler(w http.ResponseWriter, r *http. } svgs, err := h.Api.RandomSvgInterface() if err != nil { - internalServerErrorHandler(w) - log.Println(err) + handleError(w, err) return } respBody := RandomSvgInterfaceResp{Svgs: svgs} - respBytes, err := json.Marshal(respBody) - - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - _, err = w.Write(respBytes) - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, respBody) } func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -383,74 +275,54 @@ func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Reques } refreshToken, err := getBearerToken(r) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + forbidden(w) return } - refreshClaims, err := ParseRefreshToken(refreshToken) + refreshClaims, err := ParseRegisteredClaimToken(refreshToken) customerId, err := uuid.Parse(refreshClaims.Issuer) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } userEmail, err := ParseEmail(refreshClaims.Subject) if err != nil { - internalServerErrorHandler(w) + badRequest(w, malformedUserEmail) log.Println(err) return } accessToken, err := h.Api.RefreshToken(userEmail, CustomerId(customerId), refreshToken) if err != nil { - internalServerErrorHandler(w) + handleError(w, err) log.Println(err) return } - respBytes, err := json.Marshal(RefreshTokenResp{AccessToken: accessToken}) - - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - _, err = w.Write(respBytes) - - if err != nil { - internalServerErrorHandler(w) - log.Println(err) - return - } - w.WriteHeader(http.StatusOK) + marshalAndWriteBytes(w, RefreshTokenResp{AccessToken: accessToken}) } func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { methodNotAllowed(w) } - log.Print("Resetting email") var resetNKodePost ResetNKodePost - err := decodeJson(w, r, &resetNKodePost) - if err != nil { - internalServerErrorHandler(w) - log.Println("error decoding reset nkode post: ", err) + if err := decodeJson(w, r, &resetNKodePost); err != nil { return } + customerId, err := uuid.Parse(resetNKodePost.CustomerId) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedCustomerId) return } + userEmail, err := ParseEmail(resetNKodePost.UserEmail) if err != nil { - internalServerErrorHandler(w) - log.Println(err) + badRequest(w, malformedUserEmail) return } - err = h.Api.ResetNKode(userEmail, CustomerId(customerId)) - if err != nil { - internalServerErrorHandler(w) + + if err = h.Api.ResetNKode(userEmail, CustomerId(customerId)); err != nil { + internalServerError(w) log.Println(err) return } @@ -459,43 +331,90 @@ func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) { func decodeJson(w http.ResponseWriter, r *http.Request, post any) error { if r.Body == nil { - invalidJson(w) - return errors.New("invalid json") + badRequest(w, "unable to parse body") + log.Println("error decoding json: body is nil") + return errors.New("body is nil") } err := json.NewDecoder(r.Body).Decode(&post) if err != nil { - internalServerErrorHandler(w) + badRequest(w, "unable to parse body") + log.Println("error decoding json: ", err) return err } return nil } -func internalServerErrorHandler(w http.ResponseWriter) { +func internalServerError(w http.ResponseWriter) { + log.Print("500 internal server error") w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("500 Internal Server Error")) } +func badRequest(w http.ResponseWriter, msg string) { + log.Print("bad request: ", msg) + w.WriteHeader(http.StatusBadRequest) + if msg == "" { + w.Write([]byte("400 Bad Request")) + } else { + w.Write([]byte(msg)) + } +} + func methodNotAllowed(w http.ResponseWriter) { + log.Print("405 method not allowed") w.WriteHeader(http.StatusMethodNotAllowed) w.Write([]byte("405 method not allowed")) } -func keypadSizeOutOfRange(w http.ResponseWriter) { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("invalid keypad dimensions")) +func forbidden(w http.ResponseWriter) { + log.Print("403 forbidden") + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("403 Forbidden")) } -func invalidJson(w http.ResponseWriter) { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("invalid json")) +func handleError(w http.ResponseWriter, err error) { + log.Print("handling error: ", err) + statusCode, exists := HttpErrMap[err] + if !exists { + internalServerError(w) + return + } + switch statusCode { + case http.StatusBadRequest: + badRequest(w, err.Error()) + case http.StatusForbidden: + forbidden(w) + case http.StatusInternalServerError: + internalServerError(w) + default: + log.Print("unknown error: ", err) + internalServerError(w) + } } func getBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") // Check if the Authorization header is present and starts with "Bearer " if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { - return "", errors.New("authorization header missing or invalid") + return "", errors.New("forbidden") } token := strings.TrimPrefix(authHeader, "Bearer ") return token, nil } + +func marshalAndWriteBytes(w http.ResponseWriter, data any) { + respBytes, err := json.Marshal(data) + + if err != nil { + internalServerError(w) + log.Println(err) + return + } + _, err = w.Write(respBytes) + + if err != nil { + internalServerError(w) + log.Println(err) + return + } +} diff --git a/core/nkode_policy.go b/core/nkode_policy.go index e5be258..eb2803c 100644 --- a/core/nkode_policy.go +++ b/core/nkode_policy.go @@ -1,9 +1,5 @@ package core -import ( - "errors" -) - type NKodePolicy struct { MaxNkodeLen int `json:"max_nkode_len"` MinNkodeLen int `json:"min_nkode_len"` @@ -24,12 +20,10 @@ func NewDefaultNKodePolicy() NKodePolicy { } } -var InvalidNKodeLen = errors.New("invalid nkode length") - func (p *NKodePolicy) ValidLength(nkodeLen int) error { if nkodeLen < p.MinNkodeLen || nkodeLen > p.MaxNkodeLen { - return InvalidNKodeLen + return ErrInvalidNKodeLength } // TODO: validate Max > Min // Validate lockout diff --git a/core/secrets.go b/core/secrets.go deleted file mode 100644 index c3204f2..0000000 --- a/core/secrets.go +++ /dev/null @@ -1,43 +0,0 @@ -package core - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "log" -) - -type NKodeSecrets struct { - JwtSecret []byte `json:"jwt_secret"` -} - -func ReadSecrets(filePath string) (NKodeSecrets, error) { - // Initialize an empty NKodeSecrets struct - var secrets NKodeSecrets - - // Read the contents of the file - data, err := ioutil.ReadFile(filePath) - if err != nil { - return secrets, fmt.Errorf("error reading secrets file: %w", err) - } - - // Unmarshal JSON data into the NKodeSecrets struct - err = json.Unmarshal(data, &secrets) - if err != nil { - return secrets, fmt.Errorf("error unmarshaling secrets: %w", err) - } - - return secrets, nil -} - -func GetJwtSecret(filePath string) []byte { - secrets, err := ReadSecrets(filePath) - if err != nil { - log.Fatal("can't read secrets: ", err) - } - if secrets.JwtSecret == nil { - log.Fatal("wt secret is nil") - } - return secrets.JwtSecret - -} diff --git a/core/sqlite_db.go b/core/sqlite_db.go index 170fd98..1881ef6 100644 --- a/core/sqlite_db.go +++ b/core/sqlite_db.go @@ -2,7 +2,6 @@ package core import ( "database/sql" - "errors" "fmt" "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver @@ -125,7 +124,10 @@ func (d *SqliteDB) Renew(id CustomerId) error { if err != nil { return err } - setXor, attrXor := customer.RenewKeys() + setXor, attrXor, err := customer.RenewKeys() + if err != nil { + return err + } renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} // TODO: replace with tx renewQuery := ` @@ -208,9 +210,13 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { }() selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?` rows, err := tx.Query(selectCustomer, uuid.UUID(id)) + if err != nil { + return nil, err + } if !rows.Next() { - return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err())) + log.Printf("no new row for customer %s with err %s", id, rows.Err()) + return nil, ErrCustomerDne } var maxNKodeLen int @@ -225,10 +231,6 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { if err != nil { return nil, err } - - if rows.Next() { - return nil, errors.New(fmt.Sprintf("too many rows for customer %s", id)) - } customer := Customer{ Id: id, NKodePolicy: NKodePolicy{ @@ -241,9 +243,8 @@ func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { }, Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues), } - err = tx.Commit() - if err != nil { - return nil, fmt.Errorf("read customer won't commit %w", err) + if err = tx.Commit(); err != nil { + return nil, err } return &customer, nil } @@ -278,9 +279,6 @@ WHERE user.username = ? AND user.customer_id = ? var svgIdInterface []byte err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) - if rows.Next() { - return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId)) - } userId, err := uuid.Parse(id) if err != nil { @@ -324,8 +322,7 @@ WHERE user.username = ? AND user.customer_id = ? } user.Interface.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp - err = tx.Commit() - if err != nil { + if err = tx.Commit(); err != nil { return nil, err } return &user, nil @@ -360,15 +357,14 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { return nil, err } if !rows.Next() { - return nil, errors.New(fmt.Sprintf("id not found: %d", id)) + log.Printf("id not found: %d", id) + return nil, ErrSvgDne } - err = rows.Scan(&svgs[idx]) - if err != nil { + if err = rows.Scan(&svgs[idx]); err != nil { return nil, err } } - err = tx.Commit() - if err != nil { + if err = tx.Commit(); err != nil { return nil, err } return svgs, nil @@ -383,16 +379,14 @@ func (d *SqliteDB) writeToDb(query string, args []any) error { if err != nil { err = tx.Rollback() if err != nil { - log.Fatal(fmt.Sprintf("Write won't roll back %+v", err)) + log.Fatalf("fatal error: write won't roll back %+v", err) } } }() - _, err = tx.Exec(query, args...) - if err != nil { + if _, err = tx.Exec(query, args...); err != nil { return err } - err = tx.Commit() - if err != nil { + if err = tx.Commit(); err != nil { return err } return nil @@ -400,7 +394,7 @@ func (d *SqliteDB) writeToDb(query string, args []any) error { func (d *SqliteDB) addWriteTx(query string, args []any) error { if d.stop { - return errors.New("stopping database") + return ErrStoppingDatabase } errChan := make(chan error) writeTx := WriteTx{ @@ -416,31 +410,37 @@ func (d *SqliteDB) addWriteTx(query string, args []any) error { func (d *SqliteDB) getRandomIds(count int) ([]int, error) { tx, err := d.db.Begin() if err != nil { - return nil, err + log.Print(err) + return nil, ErrSqliteTx } rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") if err != nil { - return nil, err + log.Print(err) + return nil, ErrSqliteTx } var tableLen int if !rows.Next() { - return nil, errors.New("empty svg_icon table") + return nil, ErrEmptySvgTable } if err = rows.Scan(&tableLen); err != nil { - return nil, err + log.Print(err) + return nil, ErrSqliteTx } - perm, err := util.RandomPermutation(tableLen) + if err != nil { return nil, err } + for idx := range perm { perm[idx] += 1 } if err = tx.Commit(); err != nil { - return nil, err + log.Print(err) + return nil, ErrSqliteTx } + return perm[:count], nil } diff --git a/core/test_helper.go b/core/test_helper.go index 116085b..273c315 100644 --- a/core/test_helper.go +++ b/core/test_helper.go @@ -9,9 +9,9 @@ import ( func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize KeypadDimension) ([]int, error) { selectedKeys := make([]int, len(passcodeIdxs)) for idx := range passcodeIdxs { - attrIdx := util.IndexOf[int](interfaceUser, passcodeIdxs[idx]) - if attrIdx == -1 { - return nil, errors.New(fmt.Sprintf("index: %d out of range 0-%d", passcodeIdxs[idx], keypadSize.TotalAttrs()-1)) + attrIdx, err := util.IndexOf[int](interfaceUser, passcodeIdxs[idx]) + if err != nil { + return nil, err } keyNumb := attrIdx / keypadSize.AttrsPerKey if keyNumb >= keypadSize.NumbOfKeys { diff --git a/core/user.go b/core/user.go index 5b91244..618fed2 100644 --- a/core/user.go +++ b/core/user.go @@ -1,10 +1,9 @@ package core import ( - "errors" "github.com/google/uuid" - "go-nkode/py-builtin" "go-nkode/util" + "log" ) type User struct { @@ -61,31 +60,27 @@ func (u *User) GetLoginInterface() ([]int, error) { return u.Interface.IdxInterface, nil } -var KeyIndexOutOfRange = errors.New("one or more keys is out of range") - func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, error) { - validKeys := py_builtin.All[int](selectedKeys, func(idx int) bool { - return 0 <= idx && idx < user.Kp.NumbOfKeys - }) - if !validKeys { - panic(KeyIndexOutOfRange) + if validKeys := user.Kp.ValidKeySelections(selectedKeys); !validKeys { + + return nil, ErrKeyIndexOutOfRange } - var err error passcodeLen := len(selectedKeys) - err = customer.NKodePolicy.ValidLength(passcodeLen) - if err != nil { + if err := customer.NKodePolicy.ValidLength(passcodeLen); err != nil { return nil, err } setVals, err := customer.Attributes.SetValsForKp(user.Kp) if err != nil { - return nil, err + log.Printf("fatal error in validate key entry;invalid user keypad dimensions for user %s with error %v", user.Email, err) + return nil, ErrInternalValidKeyEntry } passcodeSetVals, err := user.DecipherMask(setVals, passcodeLen) if err != nil { - return nil, err + log.Printf("fatal error in validate key entry;something when wrong deciphering mask;user email %s; error %v", user.Email, err) + return nil, ErrInternalValidKeyEntry } presumedAttrIdxVals := make([]int, passcodeLen) @@ -93,11 +88,13 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err keyNumb := selectedKeys[idx] setIdx, err := customer.Attributes.IndexOfSet(passcodeSetVals[idx]) if err != nil { - return nil, err + log.Printf("fatal error in validate key entry;something when wrong getting the IndexOfSet;user email %s; error %v", user.Email, err) + return nil, ErrInternalValidKeyEntry } selectedAttrIdx, err := user.Interface.GetAttrIdxByKeyNumbSetIdx(setIdx, keyNumb) if err != nil { - return nil, err + log.Printf("fatal error in validate key entry;something when wrong getting the GetAttrIdxByKeyNumbSetIdx;user email %s; error %v", user.Email, err) + return nil, ErrInternalValidKeyEntry } presumedAttrIdxVals[idx] = selectedAttrIdx } diff --git a/core/user_cipher_keys.go b/core/user_cipher_keys.go index 022d46c..b066a4f 100644 --- a/core/user_cipher_keys.go +++ b/core/user_cipher_keys.go @@ -2,7 +2,6 @@ package core import ( "crypto/sha256" - "errors" "go-nkode/util" "golang.org/x/crypto/bcrypt" ) @@ -49,7 +48,7 @@ func NewUserCipherKeys(kp *KeypadDimension, setVals []uint64, maxNKodeLen int) ( func (u *UserCipherKeys) PadUserMask(userMask []uint64, setVals []uint64) ([]uint64, error) { if len(userMask) > u.MaxNKodeLen { - return nil, errors.New("user mask length exceeds max nkode length") + return nil, ErrUserMaskTooLong } paddedUserMask := make([]uint64, len(userMask)) copy(paddedUserMask, userMask) @@ -153,7 +152,10 @@ func (u *UserCipherKeys) DecipherMask(mask string, setVals []uint64, passcodeLen passcodeSet := make([]uint64, passcodeLen) for idx, setCipher := range decipheredMask[:passcodeLen] { - setIdx := util.IndexOf(setKeyRandComponent, setCipher) + setIdx, err := util.IndexOf(setKeyRandComponent, setCipher) + if err != nil { + return nil, err + } passcodeSet[idx] = setVals[setIdx] } return passcodeSet, nil @@ -175,6 +177,9 @@ func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs Cust } } mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.Kp) + if err != nil { + return nil, err + } encipheredCode := EncipheredNKode{ Code: code, Mask: mask, diff --git a/core/user_interface.go b/core/user_interface.go index 804210e..c09d7de 100644 --- a/core/user_interface.go +++ b/core/user_interface.go @@ -1,10 +1,9 @@ package core import ( - "errors" - "fmt" "go-nkode/hashset" "go-nkode/util" + "log" ) type UserInterface struct { @@ -70,7 +69,7 @@ func (u *UserInterface) SetViewMatrix() ([][]int, error) { func (u *UserInterface) DisperseInterface() error { if !u.Kp.IsDispersable() { - panic("interface is not dispersable") + return ErrInterfaceNotDispersible } err := u.shuffleKeys() @@ -180,11 +179,13 @@ func (u *UserInterface) PartialInterfaceShuffle() error { func (u *UserInterface) GetAttrIdxByKeyNumbSetIdx(setIdx int, keyNumb int) (int, error) { if keyNumb < 0 || u.Kp.NumbOfKeys <= keyNumb { - return -1, errors.New(fmt.Sprintf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys)) + log.Printf("keyNumb %d is out of range 0-%d", keyNumb, u.Kp.NumbOfKeys) + return -1, ErrKeyIndexOutOfRange } if setIdx < 0 || u.Kp.AttrsPerKey <= setIdx { - return -1, errors.New(fmt.Sprintf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey)) + log.Printf("setIdx %d is out of range 0-%d", setIdx, u.Kp.AttrsPerKey) + return -1, ErrAttributeIndexOutOfRange } keypadView, err := u.InterfaceMatrix() if err != nil { diff --git a/core/user_signup_session.go b/core/user_signup_session.go index b9f6666..86dd9a4 100644 --- a/core/user_signup_session.go +++ b/core/user_signup_session.go @@ -1,12 +1,11 @@ package core import ( - "errors" - "fmt" "github.com/google/uuid" "go-nkode/hashset" py "go-nkode/py-builtin" "go-nkode/util" + "log" ) type UserSignSession struct { @@ -52,27 +51,33 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e }) if !validEntry { - return nil, errors.New(fmt.Sprintf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys)) + log.Printf("Invalid Key entry. One or more key index: %#v, not in range 0-%d", confirmKeyEntry, s.Kp.NumbOfKeys) + return nil, ErrKeyIndexOutOfRange } if s.SetIdxInterface == nil { - return nil, errors.New("signup session set interface is nil") + log.Print("signup session set interface is nil") + return nil, ErrIncompleteUserSignupSession } if s.ConfirmIdxInterface == nil { - return nil, errors.New("signup session confirm interface is nil") + log.Print("signup session confirm interface is nil") + return nil, ErrIncompleteUserSignupSession } if s.SetKeySelection == nil { - return nil, errors.New("signup session set key entry is nil") + log.Print("signup session set key entry is nil") + return nil, ErrIncompleteUserSignupSession } if s.UserEmail == "" { - return nil, errors.New("signup session username is nil") + log.Print("signup session username is nil") + return nil, ErrIncompleteUserSignupSession } if len(confirmKeyEntry) != len(s.SetKeySelection) { - return nil, errors.New(fmt.Sprintf("confirm and set key entry lenght mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection))) + log.Printf("confirm and set key entry length mismatch %d != %d", len(confirmKeyEntry), len(s.SetKeySelection)) + return nil, ErrSetConfirmSignupMismatch } passcodeLen := len(confirmKeyEntry) @@ -88,10 +93,12 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e confirmKey := hashset.NewSetFromSlice[int](confirmKeyVals[idx]) intersection := setKey.Intersect(confirmKey) if intersection.Size() < 1 { - return nil, errors.New(fmt.Sprintf("set and confirm do not intersect at index %d", idx)) + log.Printf("set and confirm do not intersect at index %d", idx) + return nil, ErrSetConfirmSignupMismatch } if intersection.Size() > 1 { - return nil, errors.New(fmt.Sprintf("set and confirm intersect at more than one point at index %d", idx)) + log.Printf("set and confirm intersect at more than one point at index %d", idx) + return nil, ErrSetConfirmSignupMismatch } intersectionSlice := intersection.ToSlice() passcode[idx] = intersectionSlice[0] @@ -104,7 +111,8 @@ func (s *UserSignSession) SetUserNKode(keySelection KeySelection) (IdxInterface, return 0 <= i && i < s.Kp.NumbOfKeys }) if !validKeySelection { - return nil, errors.New(fmt.Sprintf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1)) + log.Printf("one or key selection is out of range 0-%d", s.Kp.NumbOfKeys-1) + return nil, ErrKeyIndexOutOfRange } s.SetKeySelection = keySelection @@ -134,7 +142,7 @@ func (s *UserSignSession) getSelectedKeyVals(keySelections KeySelection, userInt func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, error) { if kp.IsDispersable() { - return nil, errors.New("keypad is dispersable, can't use signupInterface") + return nil, ErrKeypadIsNotDispersible } err := baseUserInterface.RandomShuffle() if err != nil { diff --git a/main_test.go b/main_test.go index f7460ee..931a937 100644 --- a/main_test.go +++ b/main_test.go @@ -78,9 +78,9 @@ func TestApi(t *testing.T) { var jwtTokens core.AuthenticationTokens testApiPost(t, base+core.Login, loginBody, &jwtTokens) - refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken) + refreshClaims, err := core.ParseRegisteredClaimToken(jwtTokens.RefreshToken) assert.Equal(t, refreshClaims.Subject, userEmail) - accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken) + accessClaims, err := core.ParseRegisteredClaimToken(jwtTokens.AccessToken) assert.Equal(t, accessClaims.Subject, userEmail) renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} testApiPost(t, base+core.RenewAttributes, renewBody, nil) @@ -102,7 +102,7 @@ func TestApi(t *testing.T) { var refreshTokenResp core.RefreshTokenResp testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken) - accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken) + accessClaims, err = core.ParseRegisteredClaimToken(refreshTokenResp.AccessToken) assert.NoError(t, err) assert.Equal(t, accessClaims.Subject, userEmail) } diff --git a/util/util.go b/util/util.go index 208fa43..68712d7 100644 --- a/util/util.go +++ b/util/util.go @@ -6,8 +6,8 @@ import ( "encoding/binary" "encoding/hex" "errors" - "fmt" "go-nkode/hashset" + "log" "math/big" r "math/rand" "time" @@ -17,12 +17,25 @@ type ShuffleTypes interface { []int | int | []uint64 | uint64 } -// fisherYatesShuffle shuffles a slice of bytes in place using the Fisher-Yates algorithm. +var ( + ErrFisherYatesShuffle = errors.New("unable to shuffle array") + ErrRandomBytes = errors.New("random bytes error") + ErrRandNonRepeatingUint64 = errors.New("list length must be less than 2^32") + ErrParseHexString = errors.New("parse hex string error") + ErrMatrixTranspose = errors.New("matrix cannot be nil or empty") + ErrListToMatrixNotDivisible = errors.New("list to matrix not possible") + ErrElementNotInArray = errors.New("element not in array") + ErrDecodeBase64Str = errors.New("decode base64 err") + ErrRandNonRepeatingInt = errors.New("list length must be less than 2^31") + ErrXorLengthMismatch = errors.New("xor length mismatch") +) + func fisherYatesShuffle[T ShuffleTypes](b *[]T) error { for i := len(*b) - 1; i > 0; i-- { bigJ, err := rand.Int(rand.Reader, big.NewInt(int64(i+1))) if err != nil { - return err + log.Print("fisher yates shuffle error: ", err) + return ErrFisherYatesShuffle } j := bigJ.Int64() (*b)[i], (*b)[j] = (*b)[j], (*b)[i] @@ -38,7 +51,8 @@ func RandomBytes(n int) ([]byte, error) { b := make([]byte, n) _, err := rand.Read(b) if err != nil { - return nil, err + log.Print("error in random bytes: ", err) + return nil, ErrRandomBytes } return b, nil } @@ -72,7 +86,7 @@ func GenerateRandomInt() (int, error) { func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) { if listLen > int(1)<<32 { - return nil, errors.New("list length must be less than 2^32") + return nil, ErrRandNonRepeatingUint64 } listSet := make(hashset.Set[uint64]) for { @@ -92,7 +106,7 @@ func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) { func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) { if listLen > int(1)<<31 { - return nil, errors.New("list length must be less than 2^31") + return nil, ErrRandNonRepeatingInt } listSet := make(hashset.Set[int]) for { @@ -112,7 +126,8 @@ func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) { func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) { if len(l0) != len(l1) { - return nil, errors.New(fmt.Sprintf("list len mismatch %d, %d", len(l0), len(l1))) + log.Printf("list len mismatch %d, %d", len(l0), len(l1)) + return nil, ErrXorLengthMismatch } xorList := make([]uint64, len(l0)) @@ -131,7 +146,8 @@ func EncodeBase64Str(data []uint64) string { func DecodeBase64Str(encoded string) ([]uint64, error) { dataBytes, err := base64.StdEncoding.DecodeString(encoded) if err != nil { - return nil, err + log.Print("error decoding base64 str: ", err) + return nil, ErrDecodeBase64Str } data := ByteArrToUint64Arr(dataBytes) return data, nil @@ -179,13 +195,13 @@ func ByteArrToIntArr(byteArr []byte) []int { return intArr } -func IndexOf[T uint64 | int](arr []T, el T) int { +func IndexOf[T uint64 | int](arr []T, el T) (int, error) { for idx, val := range arr { if val == el { - return idx + return idx, nil } } - return -1 + return -1, ErrElementNotInArray } func IdentityArray(arrLen int) []int { @@ -199,7 +215,8 @@ func IdentityArray(arrLen int) []int { func ListToMatrix(listArr []int, numbCols int) ([][]int, error) { if len(listArr)%numbCols != 0 { - panic(fmt.Sprintf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols)) + log.Printf("Array is not evenly divisible by number of columns: %d mod %d = %d", len(listArr), numbCols, len(listArr)%numbCols) + return nil, ErrListToMatrixNotDivisible } numbRows := len(listArr) / numbCols matrix := make([][]int, numbRows) @@ -213,7 +230,8 @@ func ListToMatrix(listArr []int, numbCols int) ([][]int, error) { func MatrixTranspose(matrix [][]int) ([][]int, error) { if matrix == nil || len(matrix) == 0 { - return nil, fmt.Errorf("matrix cannot be nil or empty") + log.Print("can't transpose nil or zero len matrix") + return nil, ErrMatrixTranspose } rows := len(matrix) @@ -222,7 +240,8 @@ func MatrixTranspose(matrix [][]int) ([][]int, error) { // Check if the matrix is not rectangular for _, row := range matrix { if len(row) != cols { - return nil, fmt.Errorf("all rows must have the same number of columns") + log.Print("all rows must have the same number of columns") + return nil, ErrMatrixTranspose } } @@ -267,7 +286,8 @@ func ParseHexString(hexStr string) ([]byte, error) { // Decode the hex string into bytes bytes, err := hex.DecodeString(hexStr) if err != nil { - return nil, err + log.Print("parse hex string err: ", err) + return nil, ErrParseHexString } return bytes, nil }