From fe06a95c9861fdeb5ca7d7b628d5706b339a342b Mon Sep 17 00:00:00 2001 From: Donovan Date: Tue, 27 Aug 2024 19:27:52 -0500 Subject: [PATCH] implement and test sql db accessor --- core/model/customer.go | 91 ++++++ core/model/customer_attributes.go | 97 +++++++ core/{api => model}/nkode_handler.go | 42 +-- core/model/nkode_policy.go | 3 + core/{nkode => model}/test_helper.go | 5 +- core/{nkode => model}/user.go | 24 +- core/{nkode => model}/user_cipher_keys.go | 27 +- core/{nkode => model}/user_interface.go | 9 +- core/{nkode => model}/user_test.go | 21 +- core/nkode/common.go | 14 +- core/nkode/customer.go | 200 ------------- core/nkode/customer_attributes.go | 83 ------ core/nkode/customer_test.go | 14 +- core/nkode/db_accessor.go | 11 +- core/nkode/in_memory_db.go | 31 +- core/nkode/nkode_api.go | 4 +- core/nkode/nkode_api_test.go | 30 +- core/nkode/sqlite_db.go | 338 ++++++++++++++++++++++ core/nkode/sqlite_db_test.go | 44 +++ core/nkode/user_signup_session.go | 10 +- main.go | 11 +- main_test.go | 9 +- util/util.go | 22 ++ util/util_test.go | 8 + 24 files changed, 745 insertions(+), 403 deletions(-) create mode 100644 core/model/customer.go create mode 100644 core/model/customer_attributes.go rename core/{api => model}/nkode_handler.go (88%) rename core/{nkode => model}/test_helper.go (86%) rename core/{nkode => model}/user.go (70%) rename core/{nkode => model}/user_cipher_keys.go (88%) rename core/{nkode => model}/user_interface.go (96%) rename core/{nkode => model}/user_test.go (87%) delete mode 100644 core/nkode/customer.go delete mode 100644 core/nkode/customer_attributes.go create mode 100644 core/nkode/sqlite_db_test.go diff --git a/core/model/customer.go b/core/model/customer.go new file mode 100644 index 0000000..c7bb4f4 --- /dev/null +++ b/core/model/customer.go @@ -0,0 +1,91 @@ +package m + +import ( + "errors" + "fmt" + "github.com/google/uuid" + "go-nkode/hashset" + py "go-nkode/py-builtin" + "go-nkode/util" +) + +type Customer struct { + Id CustomerId + NKodePolicy NKodePolicy + Attributes CustomerAttributes +} + +func NewCustomer(nkodePolicy NKodePolicy) (*Customer, error) { + customerAttrs, err := NewCustomerAttributes() + if err != nil { + return nil, err + } + customer := Customer{ + Id: CustomerId(uuid.New()), + NKodePolicy: nkodePolicy, + Attributes: *customerAttrs, + } + + return &customer, nil +} + +func (c *Customer) IsValidNKode(kp KeypadDimension, passcodeAttrIdx []int) error { + nkodeLen := len(passcodeAttrIdx) + if nkodeLen < c.NKodePolicy.MinNkodeLen { + return errors.New(fmt.Sprintf("NKode length %d is too short. Minimum nKode length is %d", nkodeLen, c.NKodePolicy.MinNkodeLen)) + } + + validIdx := py.All[int](passcodeAttrIdx, func(i int) bool { + return i >= 0 && i < kp.TotalAttrs() + }) + + if !validIdx { + return errors.New(fmt.Sprintf("One or more idx out of range 0-%d in IsValidNKode", kp.TotalAttrs()-1)) + } + passcodeSetVals := make(hashset.Set[uint64]) + passcodeAttrVals := make(hashset.Set[uint64]) + attrVals, err := c.Attributes.AttrValsForKp(kp) + if err != nil { + return err + } + for idx := 0; idx < nkodeLen; idx++ { + attrVal := attrVals[passcodeAttrIdx[idx]] + setVal, err := c.Attributes.GetAttrSetVal(attrVal, kp) + if err != nil { + return err + } + passcodeSetVals.Add(setVal) + passcodeAttrVals.Add(attrVal) + } + + if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets { + return errors.New(fmt.Sprintf("passcode has two few distinct sets min %d, has %d", c.NKodePolicy.DistinctSets, passcodeSetVals.Size())) + } + + if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes { + return errors.New(fmt.Sprintf("passcode has two few distinct attributes min %d, has %d", c.NKodePolicy.DistinctAttributes, passcodeAttrVals.Size())) + } + return nil +} + +func (c *Customer) RenewKeys() ([]uint64, []uint64) { + oldAttrs := make([]uint64, len(c.Attributes.AttrVals)) + oldSets := make([]uint64, len(c.Attributes.SetVals)) + + copy(oldAttrs, c.Attributes.AttrVals) + copy(oldSets, c.Attributes.SetVals) + + err := c.Attributes.Renew() + if err != nil { + panic(err) + } + attrsXor, err := util.XorLists(oldAttrs, c.Attributes.AttrVals) + if err != nil { + panic(err) + } + setXor, err := util.XorLists(oldSets, c.Attributes.SetVals) + if err != nil { + panic(err) + } + return setXor, attrsXor +} diff --git a/core/model/customer_attributes.go b/core/model/customer_attributes.go new file mode 100644 index 0000000..7b24a81 --- /dev/null +++ b/core/model/customer_attributes.go @@ -0,0 +1,97 @@ +package m + +import ( + "errors" + "fmt" + "go-nkode/util" +) + +type CustomerAttributes struct { + AttrVals []uint64 + SetVals []uint64 +} + +func NewCustomerAttributes() (*CustomerAttributes, error) { + attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) + if errAttr != nil { + return nil, errAttr + } + setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) + if errSet != nil { + return nil, errSet + } + + customerAttrs := CustomerAttributes{ + AttrVals: attrVals, + SetVals: setVals, + } + return &customerAttrs, nil +} + +func NewCustomerAttributesFromBytes(attrBytes []byte, setBytes []byte) CustomerAttributes { + return CustomerAttributes{ + AttrVals: util.ByteArrToUint64Arr(attrBytes), + SetVals: util.ByteArrToUint64Arr(setBytes), + } +} + +func (c *CustomerAttributes) Renew() error { + attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(KeypadMax.TotalAttrs()) + if errAttr != nil { + return errAttr + } + setVals, errSet := util.GenerateRandomNonRepeatingUint64(KeypadMax.AttrsPerKey) + if errSet != nil { + return errSet + } + c.AttrVals = attrVals + c.SetVals = setVals + return nil +} + +func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) int { + // TODO: should this be mapped instead? + return util.IndexOf[uint64](c.AttrVals, attrVal) +} + +func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) { + // TODO: should this be mapped instead? + idx := util.IndexOf[uint64](c.SetVals, setVal) + if idx == -1 { + return -1, errors.New(fmt.Sprintf("Set Val %d is invalid", setVal)) + } + return idx, nil +} + +func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad KeypadDimension) (uint64, error) { + indexOfAttr := c.IndexOfAttr(attrVal) + if indexOfAttr == -1 { + return 0, errors.New(fmt.Sprintf("No attribute %d", attrVal)) + } + setIdx := indexOfAttr % userKeypad.AttrsPerKey + return c.SetVals[setIdx], nil +} + +func (c *CustomerAttributes) AttrValsForKp(userKp KeypadDimension) ([]uint64, error) { + err := userKp.IsValidKeypadDimension() + if err != nil { + return nil, err + } + return c.AttrVals[:userKp.TotalAttrs()], nil +} + +func (c *CustomerAttributes) SetValsForKp(userKp KeypadDimension) ([]uint64, error) { + err := userKp.IsValidKeypadDimension() + if err != nil { + return nil, err + } + return c.SetVals[:userKp.AttrsPerKey], nil +} + +func (c *CustomerAttributes) AttrBytes() []byte { + return util.Uint64ArrToByteArr(c.AttrVals) +} + +func (c *CustomerAttributes) SetBytes() []byte { + return util.Uint64ArrToByteArr(c.SetVals) +} diff --git a/core/api/nkode_handler.go b/core/model/nkode_handler.go similarity index 88% rename from core/api/nkode_handler.go rename to core/model/nkode_handler.go index 9956562..ebf344a 100644 --- a/core/api/nkode_handler.go +++ b/core/model/nkode_handler.go @@ -1,31 +1,31 @@ -package api +package m import ( "encoding/json" - m "go-nkode/core/model" + "go-nkode/core/api" "log" "net/http" ) type NKodeHandler struct { - Api m.NKodeAPIInterface + Api NKodeAPIInterface } func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case CreateNewCustomer: + case api.CreateNewCustomer: h.CreateNewCustomerHandler(w, r) - case GenerateSignupInterface: + case api.GenerateSignupInterface: h.GenerateSignupInterfaceHandler(w, r) - case SetNKode: + case api.SetNKode: h.SetNKodeHandler(w, r) - case ConfirmNKode: + case api.ConfirmNKode: h.ConfirmNKodeHandler(w, r) - case GetLoginInterface: + case api.GetLoginInterface: h.GetLoginInterfaceHandler(w, r) - case Login: + case api.Login: h.LoginHandler(w, r) - case RenewAttributes: + case api.RenewAttributes: h.RenewAttributesHandler(w, r) default: w.WriteHeader(http.StatusNotFound) @@ -39,7 +39,7 @@ func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.R methodNotAllowed(w) return } - var customerPost m.NewCustomerPost + var customerPost NewCustomerPost err := decodeJson(w, r, &customerPost) if err != nil { internalServerErrorHandler(w) @@ -52,7 +52,7 @@ func (h *NKodeHandler) CreateNewCustomerHandler(w http.ResponseWriter, r *http.R log.Fatal(err) return } - respBody := m.CreateNewCustomerResp{ + respBody := CreateNewCustomerResp{ CustomerId: *customerId, } respBytes, err := json.Marshal(respBody) @@ -77,14 +77,14 @@ func (h *NKodeHandler) GenerateSignupInterfaceHandler(w http.ResponseWriter, r * return } - var signupPost m.GenerateSignupInterfacePost + var signupPost GenerateSignupInterfacePost err := decodeJson(w, r, &signupPost) if err != nil { internalServerErrorHandler(w) log.Fatal(err) return } - resp, err := h.Api.GenerateSignupInterface(signupPost.CustomerId, m.KeypadDefault) + resp, err := h.Api.GenerateSignupInterface(signupPost.CustomerId, KeypadDefault) if err != nil { internalServerErrorHandler(w) log.Fatal(err) @@ -112,7 +112,7 @@ func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { methodNotAllowed(w) return } - var setNKodePost m.SetNKodePost + var setNKodePost SetNKodePost err := decodeJson(w, r, &setNKodePost) if err != nil { internalServerErrorHandler(w) @@ -125,7 +125,7 @@ func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { log.Fatal(err) return } - respBody := m.SetNKodeResp{UserInterface: confirmInterface} + respBody := SetNKodeResp{UserInterface: confirmInterface} respBytes, err := json.Marshal(respBody) if err != nil { @@ -150,7 +150,7 @@ func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Reques return } - var confirmNKodePost m.ConfirmNKodePost + var confirmNKodePost ConfirmNKodePost err := decodeJson(w, r, &confirmNKodePost) if err != nil { internalServerErrorHandler(w) @@ -173,7 +173,7 @@ func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.R methodNotAllowed(w) return } - var loginInterfacePost m.GetLoginInterfacePost + var loginInterfacePost GetLoginInterfacePost err := decodeJson(w, r, &loginInterfacePost) if err != nil { internalServerErrorHandler(w) @@ -187,7 +187,7 @@ func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.R return } - respBody := m.GetLoginInterfaceResp{UserInterface: loginInterface} + respBody := GetLoginInterfaceResp{UserInterface: loginInterface} respBytes, err := json.Marshal(respBody) if err != nil { internalServerErrorHandler(w) @@ -210,7 +210,7 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { methodNotAllowed(w) return } - var loginPost m.LoginPost + var loginPost LoginPost err := decodeJson(w, r, &loginPost) if err != nil { internalServerErrorHandler(w) @@ -232,7 +232,7 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req methodNotAllowed(w) return } - var renewAttributesPost m.RenewAttributesPost + var renewAttributesPost RenewAttributesPost err := decodeJson(w, r, &renewAttributesPost) if err != nil { diff --git a/core/model/nkode_policy.go b/core/model/nkode_policy.go index 35b168a..1f3030a 100644 --- a/core/model/nkode_policy.go +++ b/core/model/nkode_policy.go @@ -31,5 +31,8 @@ func (p *NKodePolicy) ValidLength(nkodeLen int) error { if nkodeLen < p.MinNkodeLen || nkodeLen > p.MaxNkodeLen { return InvalidNKodeLen } + // TODO: validate Max > Min + // Validate lockout + // Add Lockout To User return nil } diff --git a/core/nkode/test_helper.go b/core/model/test_helper.go similarity index 86% rename from core/nkode/test_helper.go rename to core/model/test_helper.go index c82b943..1688c3d 100644 --- a/core/nkode/test_helper.go +++ b/core/model/test_helper.go @@ -1,13 +1,12 @@ -package nkode +package m import ( "errors" "fmt" - m "go-nkode/core/model" "go-nkode/util" ) -func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize m.KeypadDimension) ([]int, error) { +func SelectKeyByAttrIdx(interfaceUser []int, passcodeIdxs []int, keypadSize KeypadDimension) ([]int, error) { selectedKeys := make([]int, len(passcodeIdxs)) for idx := range passcodeIdxs { attrIdx := util.IndexOf[int](interfaceUser, passcodeIdxs[idx]) diff --git a/core/nkode/user.go b/core/model/user.go similarity index 70% rename from core/nkode/user.go rename to core/model/user.go index 8138297..f69282d 100644 --- a/core/nkode/user.go +++ b/core/model/user.go @@ -1,16 +1,15 @@ -package nkode +package m import ( - m "go-nkode/core/model" "go-nkode/util" ) type User struct { - Id m.UserId - CustomerId m.CustomerId - Username m.Username - EncipheredPasscode m.EncipheredNKode - Kp m.KeypadDimension + Id UserId + CustomerId CustomerId + Username Username + EncipheredPasscode EncipheredNKode + Kp KeypadDimension UserKeys UserCipherKeys Interface UserInterface Renew bool @@ -23,19 +22,16 @@ func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) func (u *User) RenewKeys(setXor []uint64, attrXor []uint64) error { u.Renew = true var err error - u.UserKeys.SetKey, err = util.XorLists(setXor, u.UserKeys.SetKey) + u.UserKeys.SetKey, err = util.XorLists(setXor[:u.Kp.AttrsPerKey], u.UserKeys.SetKey) if err != nil { panic(err) } - u.UserKeys.AlphaKey, err = util.XorLists(attrXor, u.UserKeys.AlphaKey) - if err != nil { - panic(err) - } - return nil + u.UserKeys.AlphaKey, err = util.XorLists(attrXor[:u.Kp.TotalAttrs()], u.UserKeys.AlphaKey) + return err } func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error { - setVals, err := customerAttributes.SetVals(u.Kp) + setVals, err := customerAttributes.SetValsForKp(u.Kp) newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.UserKeys.MaxNKodeLen) if err != nil { return err diff --git a/core/nkode/user_cipher_keys.go b/core/model/user_cipher_keys.go similarity index 88% rename from core/nkode/user_cipher_keys.go rename to core/model/user_cipher_keys.go index 1e0c5ea..2434bae 100644 --- a/core/nkode/user_cipher_keys.go +++ b/core/model/user_cipher_keys.go @@ -1,27 +1,23 @@ -package nkode +package m import ( "crypto/sha256" "errors" - "github.com/google/uuid" - m "go-nkode/core/model" "go-nkode/util" "golang.org/x/crypto/bcrypt" ) type UserCipherKeys struct { - Id uuid.UUID - UserId m.UserId AlphaKey []uint64 SetKey []uint64 PassKey []uint64 MaskKey []uint64 Salt []byte MaxNKodeLen int - kp *m.KeypadDimension + Kp *KeypadDimension } -func NewUserCipherKeys(kp *m.KeypadDimension, setVals []uint64, maxNKodeLen int) (*UserCipherKeys, error) { +func NewUserCipherKeys(kp *KeypadDimension, setVals []uint64, maxNKodeLen int) (*UserCipherKeys, error) { err := kp.IsValidKeypadDimension() if err != nil { return nil, err @@ -40,14 +36,13 @@ func NewUserCipherKeys(kp *m.KeypadDimension, setVals []uint64, maxNKodeLen int) maskKey, _ := util.GenerateRandomNonRepeatingUint64(maxNKodeLen) salt, _ := util.RandomBytes(10) userCipherKeys := UserCipherKeys{ - Id: uuid.New(), AlphaKey: alphakey, PassKey: passKey, MaskKey: maskKey, SetKey: setKey, Salt: salt, MaxNKodeLen: maxNKodeLen, - kp: kp, + Kp: kp, } return &userCipherKeys, nil } @@ -127,8 +122,8 @@ func (u *UserCipherKeys) hashPasscode(passcodeDigest []byte) ([]byte, error) { } return hashedPassword, nil } -func (u *UserCipherKeys) EncipherMask(passcodeSet []uint64, customerAttrs CustomerAttributes, userKp m.KeypadDimension) (string, error) { - setVals, err := customerAttrs.SetVals(userKp) +func (u *UserCipherKeys) EncipherMask(passcodeSet []uint64, customerAttrs CustomerAttributes, userKp KeypadDimension) (string, error) { + setVals, err := customerAttrs.SetValsForKp(userKp) if err != nil { return "", err } @@ -174,8 +169,8 @@ func (u *UserCipherKeys) DecipherMask(mask string, setVals []uint64, passcodeLen return passcodeSet, nil } -func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs CustomerAttributes) (*m.EncipheredNKode, error) { - attrVals, err := customerAttrs.AttrVals(*u.kp) +func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs CustomerAttributes) (*EncipheredNKode, error) { + attrVals, err := customerAttrs.AttrValsForKp(*u.Kp) code, err := u.EncipherSaltHashCode(passcodeAttrIdx, attrVals) if err != nil { return nil, err @@ -184,13 +179,13 @@ func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs Cust for idx := range passcodeSet { passcodeAttr := attrVals[passcodeAttrIdx[idx]] - passcodeSet[idx], err = customerAttrs.GetAttrSetVal(passcodeAttr, *u.kp) + passcodeSet[idx], err = customerAttrs.GetAttrSetVal(passcodeAttr, *u.Kp) if err != nil { return nil, err } } - mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.kp) - encipheredCode := m.EncipheredNKode{ + mask, err := u.EncipherMask(passcodeSet, customerAttrs, *u.Kp) + encipheredCode := EncipheredNKode{ Code: code, Mask: mask, } diff --git a/core/nkode/user_interface.go b/core/model/user_interface.go similarity index 96% rename from core/nkode/user_interface.go rename to core/model/user_interface.go index 2470889..bcc46fc 100644 --- a/core/nkode/user_interface.go +++ b/core/model/user_interface.go @@ -1,19 +1,18 @@ -package nkode +package m import ( "errors" "fmt" - m "go-nkode/core/model" "go-nkode/hashset" "go-nkode/util" ) type UserInterface struct { - IdxInterface m.IdxInterface - Kp *m.KeypadDimension + IdxInterface IdxInterface + Kp *KeypadDimension } -func NewUserInterface(kp *m.KeypadDimension) (*UserInterface, error) { +func NewUserInterface(kp *KeypadDimension) (*UserInterface, error) { idxInterface := util.IdentityArray(kp.TotalAttrs()) userInterface := UserInterface{ IdxInterface: idxInterface, diff --git a/core/nkode/user_test.go b/core/model/user_test.go similarity index 87% rename from core/nkode/user_test.go rename to core/model/user_test.go index 470a275..bfeb51a 100644 --- a/core/nkode/user_test.go +++ b/core/model/user_test.go @@ -1,20 +1,19 @@ -package nkode +package m import ( "github.com/stretchr/testify/assert" - m "go-nkode/core/model" py "go-nkode/py-builtin" "testing" ) func TestUserCipherKeys_EncipherSaltHashCode(t *testing.T) { - kp := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} + kp := KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} maxNKodeLen := 10 customerAttrs, err := NewCustomerAttributes() assert.NoError(t, err) - setVals, err := customerAttrs.SetVals(kp) + setVals, err := customerAttrs.SetValsForKp(kp) assert.NoError(t, err) - attrVals, err := customerAttrs.AttrVals(kp) + attrVals, err := customerAttrs.AttrValsForKp(kp) assert.NoError(t, err) newUser, err := NewUserCipherKeys(&kp, setVals, maxNKodeLen) assert.NoError(t, err) @@ -26,14 +25,14 @@ func TestUserCipherKeys_EncipherSaltHashCode(t *testing.T) { } func TestUserCipherKeys_EncipherDecipherMask(t *testing.T) { - kp := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} + kp := KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} maxNKodeLen := 10 customerAttrs, err := NewCustomerAttributes() assert.NoError(t, err) - setVals, err := customerAttrs.SetVals(kp) + setVals, err := customerAttrs.SetValsForKp(kp) assert.NoError(t, err) - attrVals, err := customerAttrs.AttrVals(kp) + attrVals, err := customerAttrs.AttrValsForKp(kp) assert.NoError(t, err) newUser, err := NewUserCipherKeys(&kp, setVals, maxNKodeLen) assert.NoError(t, err) @@ -56,7 +55,7 @@ func TestUserCipherKeys_EncipherDecipherMask(t *testing.T) { } func TestUserInterface_RandomShuffle(t *testing.T) { - kp := m.KeypadDimension{ + kp := KeypadDimension{ AttrsPerKey: 10, NumbOfKeys: 8, } @@ -81,7 +80,7 @@ func TestUserInterface_RandomShuffle(t *testing.T) { func TestUserInterface_DisperseInterface(t *testing.T) { for idx := 0; idx < 10000; idx++ { - kp := m.KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} + kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} userInterface, err := NewUserInterface(&kp) assert.NoError(t, err) @@ -100,7 +99,7 @@ func TestUserInterface_DisperseInterface(t *testing.T) { } func TestUserInterface_PartialInterfaceShuffle(t *testing.T) { - kp := m.KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} + kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} userInterface, err := NewUserInterface(&kp) assert.NoError(t, err) preShuffle := userInterface.IdxInterface diff --git a/core/nkode/common.go b/core/nkode/common.go index 99f58e5..5b06bf6 100644 --- a/core/nkode/common.go +++ b/core/nkode/common.go @@ -9,7 +9,7 @@ import ( var KeyIndexOutOfRange error = errors.New("one or more keys is out of range") -func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, error) { +func ValidKeyEntry(user m.User, customer m.Customer, selectedKeys []int) ([]int, error) { validKeys := py.All[int](selectedKeys, func(idx int) bool { return 0 <= idx && idx < user.Kp.NumbOfKeys }) @@ -24,7 +24,7 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err return nil, err } - setVals, err := customer.Attributes.SetVals(user.Kp) + setVals, err := customer.Attributes.SetValsForKp(user.Kp) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err if err != nil { panic(err) } - attrVals, err := customer.Attributes.AttrVals(user.Kp) + attrVals, err := customer.Attributes.AttrValsForKp(user.Kp) if err != nil { panic(err) } @@ -63,12 +63,12 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err return presumedAttrIdxVals, nil } -func NewUser(customer Customer, username m.Username, passcodeIdx []int, ui UserInterface, kp m.KeypadDimension) (*User, error) { - setVals, err := customer.Attributes.SetVals(kp) +func NewUser(customer m.Customer, username m.Username, passcodeIdx []int, ui m.UserInterface, kp m.KeypadDimension) (*m.User, error) { + setVals, err := customer.Attributes.SetValsForKp(kp) if err != nil { return nil, err } - newKeys, err := NewUserCipherKeys(&kp, setVals, customer.NKodePolicy.MaxNkodeLen) + newKeys, err := m.NewUserCipherKeys(&kp, setVals, customer.NKodePolicy.MaxNkodeLen) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func NewUser(customer Customer, username m.Username, passcodeIdx []int, ui UserI if err != nil { return nil, err } - newUser := User{ + newUser := m.User{ Id: m.UserId(uuid.New()), Username: username, EncipheredPasscode: *encipheredNKode, diff --git a/core/nkode/customer.go b/core/nkode/customer.go deleted file mode 100644 index 6dbc5dc..0000000 --- a/core/nkode/customer.go +++ /dev/null @@ -1,200 +0,0 @@ -package nkode - -import ( - "errors" - "fmt" - "github.com/google/uuid" - m "go-nkode/core/model" - "go-nkode/hashset" - py "go-nkode/py-builtin" - "go-nkode/util" -) - -type Customer struct { - Id m.CustomerId - NKodePolicy m.NKodePolicy - Attributes CustomerAttributes -} - -func NewCustomer(nkodePolicy m.NKodePolicy) (*Customer, error) { - customerAttrs, err := NewCustomerAttributes() - if err != nil { - return nil, err - } - customer := Customer{ - Id: m.CustomerId(uuid.New()), - NKodePolicy: nkodePolicy, - Attributes: *customerAttrs, - } - - return &customer, nil -} - -//func (c *Customer) AddNewUser(username m.Username, passcodeIdx []int, ui UserInterface, kp m.KeypadDimension) error { -// _, exists := c.Users[username] -// if exists { -// return errors.New(fmt.Sprintf("User %s already exists for customer %+v exists", username, c.Id)) -// } -// setVals, err := c.Attributes.SetVals(kp) -// if err != nil { -// return err -// } -// newKeys, err := NewUserCipherKeys(&kp, setVals, c.NKodePolicy.MaxNkodeLen) -// if err != nil { -// return err -// } -// encipheredNKode, err := newKeys.EncipherNKode(passcodeIdx, c.Attributes) -// if err != nil { -// return err -// } -// newUser := User{ -// Id: m.UserId(uuid.New()), -// Username: username, -// EncipheredPasscode: *encipheredNKode, -// UserKeys: *newKeys, -// Interface: ui, -// Kp: kp, -// } -// c.Users[username] = newUser -// return nil -//} - -//func (c *Customer) ValidKeyEntry(username m.Username, selectedKeys []int) ([]int, error) { -// user, exists := c.Users[username] -// if !exists { -// return nil, errors.New(fmt.Sprintf("user %s does not exist for customer %+v", username, c.Id)) -// } -// -// validKeys := py.All[int](selectedKeys, func(idx int) bool { -// return 0 <= idx && idx < user.Kp.NumbOfKeys -// }) -// if !validKeys { -// return nil, errors.New(fmt.Sprintf("one or more keys not in range 0-%d", user.Kp.NumbOfKeys-1)) -// } -// presumedAttrIdxVals, err := c.getPresumedAttributeIdxVals(user, selectedKeys) -// if err != nil { -// return nil, err -// } -// err = c.IsValidNKode(user.Kp, presumedAttrIdxVals) -// if err != nil { -// return nil, err -// } -// attrVals, err := c.Attributes.AttrVals(user.Kp) -// if err != nil { -// return nil, err -// } -// err = user.UserKeys.ValidPassword(user.EncipheredPasscode.Code, presumedAttrIdxVals, attrVals) -// if err != nil { -// return nil, err -// } -// -// return presumedAttrIdxVals, nil -//} - -//func (c *Customer) getPresumedAttributeIdxVals(user User, selectedKeys []int) ([]int, error) { -// -// passcodeLen := len(selectedKeys) -// if passcodeLen < c.NKodePolicy.MinNkodeLen || passcodeLen > c.NKodePolicy.MaxNkodeLen { -// return nil, errors.New(fmt.Sprintf("Invalid passcode length of %d. Passcode length must be in range %d-%d", passcodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.MaxNkodeLen)) -// } -// -// setVals, err := c.Attributes.SetVals(user.Kp) -// if err != nil { -// return nil, err -// } -// passcodeSetVals, err := user.DecipherMask(setVals, passcodeLen) -// if err != nil { -// return nil, err -// } -// presumedAttrIdxVals := make([]int, passcodeLen) -// -// for idx := range presumedAttrIdxVals { -// keyNumb := selectedKeys[idx] -// setIdx, err := c.Attributes.IndexOfSet(passcodeSetVals[idx]) -// if err != nil { -// return nil, err -// } -// selectedAttrIdx, err := user.Interface.GetAttrIdxByKeyNumbSetIdx(setIdx, keyNumb) -// if err != nil { -// return nil, err -// } -// presumedAttrIdxVals[idx] = selectedAttrIdx -// } -// return presumedAttrIdxVals, nil -//} - -func (c *Customer) IsValidNKode(kp m.KeypadDimension, passcodeAttrIdx []int) error { - nkodeLen := len(passcodeAttrIdx) - if nkodeLen < c.NKodePolicy.MinNkodeLen { - return errors.New(fmt.Sprintf("NKode length %d is too short. Minimum nKode length is %d", nkodeLen, c.NKodePolicy.MinNkodeLen)) - } - - validIdx := py.All[int](passcodeAttrIdx, func(i int) bool { - return i >= 0 && i < kp.TotalAttrs() - }) - - if !validIdx { - return errors.New(fmt.Sprintf("One or more idx out of range 0-%d in IsValidNKode", kp.TotalAttrs()-1)) - } - passcodeSetVals := make(hashset.Set[uint64]) - passcodeAttrVals := make(hashset.Set[uint64]) - attrVals, err := c.Attributes.AttrVals(kp) - if err != nil { - return err - } - for idx := 0; idx < nkodeLen; idx++ { - attrVal := attrVals[passcodeAttrIdx[idx]] - setVal, err := c.Attributes.GetAttrSetVal(attrVal, kp) - if err != nil { - return err - } - passcodeSetVals.Add(setVal) - passcodeAttrVals.Add(attrVal) - } - - if passcodeSetVals.Size() < c.NKodePolicy.DistinctSets { - return errors.New(fmt.Sprintf("passcode has two few distinct sets min %d, has %d", c.NKodePolicy.DistinctSets, passcodeSetVals.Size())) - } - - if passcodeAttrVals.Size() < c.NKodePolicy.DistinctAttributes { - return errors.New(fmt.Sprintf("passcode has two few distinct attributes min %d, has %d", c.NKodePolicy.DistinctAttributes, passcodeAttrVals.Size())) - } - return nil -} - -func (c *Customer) RenewKeys() ([]uint64, []uint64) { - oldAttrs := make([]uint64, m.KeypadMax.TotalAttrs()) - oldSets := make([]uint64, m.KeypadMax.AttrsPerKey) - allAttrVals, err := c.Attributes.AttrVals(m.KeypadMax) - if err != nil { - panic(err) - } - allSetVals, err := c.Attributes.AttrVals(m.KeypadMax) - if err != nil { - panic(err) - } - copy(oldAttrs, allAttrVals) - copy(oldSets, allSetVals) - - err = c.Attributes.Renew() - if err != nil { - panic(err) - } - allAttrVals, err = c.Attributes.AttrVals(m.KeypadMax) - if err != nil { - panic(err) - } - allSetVals, err = c.Attributes.SetVals(m.KeypadMax) - if err != nil { - panic(err) - } - attrsXor, err := util.XorLists(oldAttrs, allAttrVals) - if err != nil { - panic(err) - } - setXor, err := util.XorLists(oldSets, allSetVals) - if err != nil { - panic(err) - } - return setXor, attrsXor -} diff --git a/core/nkode/customer_attributes.go b/core/nkode/customer_attributes.go deleted file mode 100644 index a0a44d5..0000000 --- a/core/nkode/customer_attributes.go +++ /dev/null @@ -1,83 +0,0 @@ -package nkode - -import ( - "errors" - "fmt" - "go-nkode/core/model" - "go-nkode/util" -) - -type CustomerAttributes struct { - attrVals []uint64 - setVals []uint64 -} - -func NewCustomerAttributes() (*CustomerAttributes, error) { - attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(m.KeypadMax.TotalAttrs()) - if errAttr != nil { - return nil, errAttr - } - setVals, errSet := util.GenerateRandomNonRepeatingUint64(m.KeypadMax.AttrsPerKey) - if errSet != nil { - return nil, errSet - } - - customerAttrs := CustomerAttributes{ - attrVals: attrVals, - setVals: setVals, - } - return &customerAttrs, nil -} - -func (c *CustomerAttributes) Renew() error { - attrVals, errAttr := util.GenerateRandomNonRepeatingUint64(m.KeypadMax.TotalAttrs()) - if errAttr != nil { - return errAttr - } - setVals, errSet := util.GenerateRandomNonRepeatingUint64(m.KeypadMax.AttrsPerKey) - if errSet != nil { - return errSet - } - c.attrVals = attrVals - c.setVals = setVals - return nil -} - -func (c *CustomerAttributes) IndexOfAttr(attrVal uint64) int { - // TODO: should this be mapped instead? - return util.IndexOf[uint64](c.attrVals, attrVal) -} - -func (c *CustomerAttributes) IndexOfSet(setVal uint64) (int, error) { - // TODO: should this be mapped instead? - idx := util.IndexOf[uint64](c.setVals, setVal) - if idx == -1 { - return -1, errors.New(fmt.Sprintf("Set Val %d is invalid", setVal)) - } - return idx, nil -} - -func (c *CustomerAttributes) GetAttrSetVal(attrVal uint64, userKeypad m.KeypadDimension) (uint64, error) { - indexOfAttr := c.IndexOfAttr(attrVal) - if indexOfAttr == -1 { - return 0, errors.New(fmt.Sprintf("No attribute %d", attrVal)) - } - setIdx := indexOfAttr % userKeypad.AttrsPerKey - return c.setVals[setIdx], nil -} - -func (c *CustomerAttributes) AttrVals(userKp m.KeypadDimension) ([]uint64, error) { - err := userKp.IsValidKeypadDimension() - if err != nil { - return nil, err - } - return c.attrVals[:userKp.TotalAttrs()], nil -} - -func (c *CustomerAttributes) SetVals(userKp m.KeypadDimension) ([]uint64, error) { - err := userKp.IsValidKeypadDimension() - if err != nil { - return nil, err - } - return c.setVals[:userKp.AttrsPerKey], nil -} diff --git a/core/nkode/customer_test.go b/core/nkode/customer_test.go index 8c4d0d1..2d5be1f 100644 --- a/core/nkode/customer_test.go +++ b/core/nkode/customer_test.go @@ -2,7 +2,7 @@ package nkode import ( "github.com/stretchr/testify/assert" - m "go-nkode/core/model" + "go-nkode/core/model" "testing" ) @@ -14,16 +14,16 @@ func TestCustomer(t *testing.T) { func testNewCustomerAttributes(t *testing.T) { // keypad := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 5} - _, nil := NewCustomerAttributes() + _, nil := m.NewCustomerAttributes() assert.NoError(t, nil) } func testCustomerValidKeyEntry(t *testing.T) { kp := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 9} nkodePolicy := m.NewDefaultNKodePolicy() - customer, err := NewCustomer(nkodePolicy) + customer, err := m.NewCustomer(nkodePolicy) assert.NoError(t, err) - newUserInterface, err := NewUserInterface(&kp) + newUserInterface, err := m.NewUserInterface(&kp) assert.NoError(t, err) username := m.Username("testing123") passcodeIdx := []int{0, 1, 2, 3} @@ -31,7 +31,7 @@ func testCustomerValidKeyEntry(t *testing.T) { assert.NoError(t, err) userLoginInterface, err := user.GetLoginInterface() assert.NoError(t, err) - selectedKeys, err := SelectKeyByAttrIdx(userLoginInterface, passcodeIdx, kp) + selectedKeys, err := m.SelectKeyByAttrIdx(userLoginInterface, passcodeIdx, kp) assert.NoError(t, err) validatedPasscode, err := ValidKeyEntry(*user, *customer, selectedKeys) assert.NoError(t, err) @@ -44,9 +44,9 @@ func testCustomerValidKeyEntry(t *testing.T) { func testCustomerIsValidNKode(t *testing.T) { kp := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 7} nkodePolicy := m.NewDefaultNKodePolicy() - customer, err := NewCustomer(nkodePolicy) + customer, err := m.NewCustomer(nkodePolicy) assert.NoError(t, err) - newUserInterface, err := NewUserInterface(&kp) + newUserInterface, err := m.NewUserInterface(&kp) assert.NoError(t, err) username := m.Username("testing123") passcodeIdx := []int{0, 1, 2, 3} diff --git a/core/nkode/db_accessor.go b/core/nkode/db_accessor.go index 220152b..adc8123 100644 --- a/core/nkode/db_accessor.go +++ b/core/nkode/db_accessor.go @@ -5,10 +5,11 @@ import ( ) type DbAccessor interface { - GetCustomer(m.CustomerId) (*Customer, error) - GetUser(m.Username, m.CustomerId) (*User, error) - WriteNewCustomer(Customer) error - WriteNewUser(User) error - UpdateUserInterface(m.UserId, UserInterface) error + GetCustomer(m.CustomerId) (*m.Customer, error) + GetUser(m.Username, m.CustomerId) (*m.User, error) + WriteNewCustomer(m.Customer) error + WriteNewUser(m.User) error + UpdateUserInterface(m.UserId, m.UserInterface) error Renew(m.CustomerId) error + RefreshUser(m.User, []int, m.CustomerAttributes) error } diff --git a/core/nkode/in_memory_db.go b/core/nkode/in_memory_db.go index 1eae1ba..943c6bb 100644 --- a/core/nkode/in_memory_db.go +++ b/core/nkode/in_memory_db.go @@ -7,20 +7,20 @@ import ( ) type InMemoryDb struct { - Customers map[m.CustomerId]Customer - Users map[m.UserId]User + Customers map[m.CustomerId]m.Customer + Users map[m.UserId]m.User userIdMap map[string]m.UserId } func NewInMemoryDb() InMemoryDb { return InMemoryDb{ - Customers: make(map[m.CustomerId]Customer), - Users: make(map[m.UserId]User), + Customers: make(map[m.CustomerId]m.Customer), + Users: make(map[m.UserId]m.User), userIdMap: make(map[string]m.UserId), } } -func (db *InMemoryDb) GetCustomer(id m.CustomerId) (*Customer, error) { +func (db *InMemoryDb) GetCustomer(id m.CustomerId) (*m.Customer, error) { customer, exists := db.Customers[id] if !exists { return nil, errors.New(fmt.Sprintf("customer %s dne", customer.Id)) @@ -28,7 +28,7 @@ func (db *InMemoryDb) GetCustomer(id m.CustomerId) (*Customer, error) { return &customer, nil } -func (db *InMemoryDb) GetUser(username m.Username, customerId m.CustomerId) (*User, error) { +func (db *InMemoryDb) GetUser(username m.Username, customerId m.CustomerId) (*m.User, error) { key := userIdKey(customerId, username) userId, exists := db.userIdMap[key] if !exists { @@ -41,7 +41,7 @@ func (db *InMemoryDb) GetUser(username m.Username, customerId m.CustomerId) (*Us return &user, nil } -func (db *InMemoryDb) WriteNewCustomer(customer Customer) error { +func (db *InMemoryDb) WriteNewCustomer(customer m.Customer) error { _, exists := db.Customers[customer.Id] if exists { @@ -51,7 +51,7 @@ func (db *InMemoryDb) WriteNewCustomer(customer Customer) error { return nil } -func (db *InMemoryDb) WriteNewUser(user User) error { +func (db *InMemoryDb) WriteNewUser(user m.User) error { _, exists := db.Customers[user.CustomerId] if !exists { return errors.New(fmt.Sprintf("can't add user %s to customer %s: customer dne", user.Username, user.CustomerId)) @@ -67,7 +67,7 @@ func (db *InMemoryDb) WriteNewUser(user User) error { return nil } -func (db *InMemoryDb) UpdateUserInterface(userId m.UserId, ui UserInterface) error { +func (db *InMemoryDb) UpdateUserInterface(userId m.UserId, ui m.UserInterface) error { user, exists := db.Users[userId] if !exists { return errors.New(fmt.Sprintf("can't update user %s, dne", user.Id)) @@ -82,18 +82,29 @@ func (db *InMemoryDb) Renew(id m.CustomerId) error { return errors.New(fmt.Sprintf("customer %s does not exist", id)) } setXor, attrsXor := customer.RenewKeys() + db.Customers[id] = customer var err error for _, user := range db.Users { if user.CustomerId == id { - err = user.RenewKeys(setXor[:user.Kp.AttrsPerKey], attrsXor[:user.Kp.TotalAttrs()]) + err = user.RenewKeys(setXor, attrsXor) if err != nil { panic(err) } + db.Users[user.Id] = user } } return nil } +func (db *InMemoryDb) RefreshUser(user m.User, passocode []int, customerAttr m.CustomerAttributes) error { + err := user.RefreshPasscode(passocode, customerAttr) + if err != nil { + return err + } + db.Users[user.Id] = user + return nil +} + func userIdKey(customerId m.CustomerId, username m.Username) string { key := fmt.Sprintf("%s:%s", customerId, username) return key diff --git a/core/nkode/nkode_api.go b/core/nkode/nkode_api.go index 7b6a253..be15b23 100644 --- a/core/nkode/nkode_api.go +++ b/core/nkode/nkode_api.go @@ -19,7 +19,7 @@ func NewNKodeAPI(db DbAccessor) NKodeAPI { } func (n *NKodeAPI) CreateNewCustomer(nkodePolicy m.NKodePolicy) (*m.CustomerId, error) { - newCustomer, err := NewCustomer(nkodePolicy) + newCustomer, err := m.NewCustomer(nkodePolicy) if err != nil { return nil, err } @@ -118,7 +118,7 @@ func (n *NKodeAPI) Login(customerId m.CustomerId, username m.Username, keySelect return err } if user.Renew { - err = user.RefreshPasscode(passcode, customer.Attributes) + err = n.Db.RefreshUser(*user, passcode, customer.Attributes) if err != nil { return err } diff --git a/core/nkode/nkode_api_test.go b/core/nkode/nkode_api_test.go index aa72782..7252393 100644 --- a/core/nkode/nkode_api_test.go +++ b/core/nkode/nkode_api_test.go @@ -3,17 +3,34 @@ package nkode import ( "github.com/stretchr/testify/assert" m "go-nkode/core/model" + "os" "testing" ) func TestNKodeAPI(t *testing.T) { + //db1 := NewInMemoryDb() + //1testNKodeAPI(t, &db1) + + dbFile := "test.db" + db2, err := NewSqliteDB(dbFile) + assert.NoError(t, err) + testNKodeAPI(t, db2) + if _, err := os.Stat(dbFile); err == nil { + err = os.Remove(dbFile) + assert.NoError(t, err) + } else { + assert.NoError(t, err) + } +} + +func testNKodeAPI(t *testing.T, db DbAccessor) { + for idx := 0; idx < 10; idx++ { - db := NewInMemoryDb() username := m.Username("test_username") passcodeLen := 4 nkodePolicy := m.NewDefaultNKodePolicy() keypadSize := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} - nkodeApi := NewNKodeAPI(&db) + nkodeApi := NewNKodeAPI(db) customerId, err := nkodeApi.CreateNewCustomer(nkodePolicy) assert.NoError(t, err) signupResponse, err := nkodeApi.GenerateSignupInterface(*customerId, keypadSize) @@ -22,18 +39,18 @@ func TestNKodeAPI(t *testing.T) { sessionId := signupResponse.SessionId keypadSize = m.KeypadDimension{AttrsPerKey: 8, NumbOfKeys: 8} userPasscode := setInterface[:passcodeLen] - setKeySelect, err := SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) + setKeySelect, err := m.SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) assert.NoError(t, err) confirmInterface, err := nkodeApi.SetNKode(username, *customerId, sessionId, setKeySelect) assert.NoError(t, err) - confirmKeySelect, err := SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) + confirmKeySelect, err := m.SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) err = nkodeApi.ConfirmNKode(*customerId, sessionId, confirmKeySelect) assert.NoError(t, err) keypadSize = m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} loginInterface, err := nkodeApi.GetLoginInterface(username, *customerId) assert.NoError(t, err) - loginKeySelection, err := SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) + loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) assert.NoError(t, err) err = nkodeApi.Login(*customerId, username, loginKeySelection) assert.NoError(t, err) @@ -43,9 +60,10 @@ func TestNKodeAPI(t *testing.T) { loginInterface, err = nkodeApi.GetLoginInterface(username, *customerId) assert.NoError(t, err) - loginKeySelection, err = SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) + loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) assert.NoError(t, err) err = nkodeApi.Login(*customerId, username, loginKeySelection) assert.NoError(t, err) + } } diff --git a/core/nkode/sqlite_db.go b/core/nkode/sqlite_db.go index 324c377..1df26de 100644 --- a/core/nkode/sqlite_db.go +++ b/core/nkode/sqlite_db.go @@ -1 +1,339 @@ package nkode + +import ( + "database/sql" + "errors" + "fmt" + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver + m "go-nkode/core/model" + "go-nkode/util" + "log" +) + +type SqliteDB struct { + path string +} + +func NewSqliteDB(path string) (*SqliteDB, error) { + sqldb := SqliteDB{path: path} + err := sqldb.NewTables() + + return &sqldb, err +} + +func (d *SqliteDB) NewTables() error { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + createTables := ` +PRAGMA foreign_keys = ON; + +CREATE TABLE IF NOT EXISTS customer ( + id TEXT NOT NULL PRIMARY KEY, + max_nkode_len INTEGER NOT NULL, + min_nkode_len INTEGER NOT NULL, + distinct_sets INTEGER NOT NULL, + distinct_attributes INTEGER NOT NULL, + lock_out INTEGER NOT NULL, + expiration INTEGER NOT NULL, + attribute_values BLOB NOT NULL, + set_values BLOB NOT NULL +); + +CREATE TABLE IF NOT EXISTS user ( + id TEXT NOT NULL PRIMARY KEY, + username TEXT NOT NULL, + renew INT NOT NULL, + customer_id TEXT NOT NULL, + +-- Enciphered Passcode + code TEXT NOT NULL, + mask TEXT NOT NULL, + +-- Keypad Dimensions + attributes_per_key INT NOT NULL, + number_of_keys INT NOT NULL, + +-- User Keys + alpha_key BLOB NOT NULL, + set_key BLOB NOT NULL, + pass_key BLOB NOT NULL, + mask_key BLOB NOT NULL, + salt BLOB NOT NULL, + max_nkode_len INT NOT NULL, + +-- User Interface + idx_interface BLOB NOT NULL, + + + FOREIGN KEY (customer_id) REFERENCES customers(id), + UNIQUE(customer_id, username) +); +` + + _, err = db.Exec(createTables) + if err != nil { + return err + } + + return nil +} + +func (d *SqliteDB) WriteNewCustomer(c m.Customer) error { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + log.Fatal(err) + } + defer db.Close() + insertCustomer := ` +INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values) +VALUES (?,?,?,?,?,?,?,?,?) +` + _, err = db.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()) + + return err +} + +func (d *SqliteDB) WriteNewUser(u m.User) error { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + log.Fatal(err) + } + defer db.Close() + insertUser := ` +INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface) +VALUES (?,?, ?,?,?,?,?,?,?,?,?,?,?,?,?) +` + var renew int + if u.Renew { + renew = 1 + } else { + renew = 0 + } + _, err = db.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.UserKeys.AlphaKey), util.Uint64ArrToByteArr(u.UserKeys.SetKey), util.Uint64ArrToByteArr(u.UserKeys.PassKey), util.Uint64ArrToByteArr(u.UserKeys.MaskKey), u.UserKeys.Salt, u.UserKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface)) + + return err +} + +func (d *SqliteDB) GetCustomer(id m.CustomerId) (*m.Customer, error) { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return nil, err + } + defer db.Close() + selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?` + rows, err := db.Query(selectCustomer, uuid.UUID(id)) + + if !rows.Next() { + return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err())) + } + + var maxNKodeLen int + var minNKodeLen int + var distinctSets int + var distinctAttributes int + var lockOut int + var expiration int + var attributeValues []byte + var setValues []byte + err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues) + if err != nil { + return nil, err + } + + if rows.Next() { + return nil, errors.New(fmt.Sprintf("too many rows for customer %s", id)) + } + customer := m.Customer{ + Id: id, + NKodePolicy: m.NKodePolicy{ + MaxNkodeLen: maxNKodeLen, + MinNkodeLen: minNKodeLen, + DistinctSets: distinctSets, + DistinctAttributes: distinctAttributes, + LockOut: lockOut, + Expiration: expiration, + }, + Attributes: m.NewCustomerAttributesFromBytes(attributeValues, setValues), + } + + return &customer, nil +} + +func (d *SqliteDB) GetUser(username m.Username, customerId m.CustomerId) (*m.User, error) { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return nil, err + } + defer db.Close() + + userSelect := ` +SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface FROM user +WHERE user.username = ? AND user.customer_id = ? +` + rows, err := db.Query(userSelect, string(username), uuid.UUID(customerId).String()) + if !rows.Next() { + return nil, errors.New(fmt.Sprintf("no new rows for user %s of customer %s", string(username), uuid.UUID(customerId).String())) + } + var id string + var renewVal int + var code string + var mask string + var attrsPerKey int + var numbOfKeys int + var alphaKey []byte + var setKey []byte + var passKey []byte + var maskKey []byte + var salt []byte + var maxNKodeLen int + var idxInterface []byte + + err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface) + if rows.Next() { + return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId)) + } + + userId, err := uuid.Parse(id) + if err != nil { + return nil, err + } + var renew bool + if renewVal == 0 { + renew = false + } else { + renew = true + } + + user := m.User{ + Id: m.UserId(userId), + CustomerId: customerId, + Username: username, + EncipheredPasscode: m.EncipheredNKode{ + Code: code, + Mask: mask, + }, + Kp: m.KeypadDimension{ + AttrsPerKey: attrsPerKey, + NumbOfKeys: numbOfKeys, + }, + UserKeys: m.UserCipherKeys{ + AlphaKey: util.ByteArrToUint64Arr(alphaKey), + SetKey: util.ByteArrToUint64Arr(setKey), + PassKey: util.ByteArrToUint64Arr(passKey), + MaskKey: util.ByteArrToUint64Arr(maskKey), + Salt: salt, + MaxNKodeLen: maxNKodeLen, + Kp: nil, + }, + Interface: m.UserInterface{ + IdxInterface: util.ByteArrToIntArr(idxInterface), + Kp: nil, + }, + Renew: renew, + } + user.Interface.Kp = &user.Kp + user.UserKeys.Kp = &user.Kp + + return &user, nil +} + +func (d *SqliteDB) UpdateUserInterface(id m.UserId, ui m.UserInterface) error { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return err + } + defer db.Close() + updateUserInterface := ` +UPDATE user SET idx_interface = ? WHERE id = ? +` + _, err = db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()) + + return err +} + +func (d *SqliteDB) Renew(id m.CustomerId) error { + customer, err := d.GetCustomer(id) + if err != nil { + return err + } + setXor, attrXor := customer.RenewKeys() + renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} + + renewExec := ` +BEGIN TRANSACTION; + +UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?; +` + + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return err + } + defer db.Close() + + userQuery := ` +SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ? +` + rows, err := db.Query(userQuery, uuid.UUID(id).String()) + for rows.Next() { + var userId string + var alphaBytes []byte + var setBytes []byte + var attrsPerKey int + var numbOfKeys int + err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys) + if err != nil { + return err + } + user := m.User{ + Id: m.UserId{}, + CustomerId: m.CustomerId{}, + Username: "", + EncipheredPasscode: m.EncipheredNKode{}, + Kp: m.KeypadDimension{ + AttrsPerKey: attrsPerKey, + NumbOfKeys: numbOfKeys, + }, + UserKeys: m.UserCipherKeys{ + AlphaKey: util.ByteArrToUint64Arr(alphaBytes), + SetKey: util.ByteArrToUint64Arr(setBytes), + }, + Interface: m.UserInterface{}, + Renew: false, + } + err = user.RenewKeys(setXor, attrXor) + if err != nil { + return err + } + renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;" + renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.UserKeys.AlphaKey), util.Uint64ArrToByteArr(user.UserKeys.SetKey), 1, userId) + } + renewExec += ` +COMMIT; +` + + _, err = db.Exec(renewExec, renewArgs...) + return err +} + +func (d *SqliteDB) RefreshUser(user m.User, passcodeIdx []int, customerAttr m.CustomerAttributes) error { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return err + } + defer db.Close() + err = user.RefreshPasscode(passcodeIdx, customerAttr) + if err != nil { + return err + } + updateUser := ` +UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; +` + _, err = db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.UserKeys.AlphaKey), util.Uint64ArrToByteArr(user.UserKeys.SetKey), util.Uint64ArrToByteArr(user.UserKeys.PassKey), util.Uint64ArrToByteArr(user.UserKeys.MaskKey), user.UserKeys.Salt, uuid.UUID(user.Id).String()) + return err +} diff --git a/core/nkode/sqlite_db_test.go b/core/nkode/sqlite_db_test.go new file mode 100644 index 0000000..552b093 --- /dev/null +++ b/core/nkode/sqlite_db_test.go @@ -0,0 +1,44 @@ +package nkode + +import ( + "github.com/stretchr/testify/assert" + m "go-nkode/core/model" + "os" + "testing" +) + +func TestNewSqliteDB(t *testing.T) { + dbFile := "test.db" + db, err := NewSqliteDB(dbFile) + assert.NoError(t, err) + nkode_policy := m.NewDefaultNKodePolicy() + customerOrig, err := m.NewCustomer(nkode_policy) + assert.NoError(t, err) + err = db.WriteNewCustomer(*customerOrig) + assert.NoError(t, err) + customer, err := db.GetCustomer(customerOrig.Id) + assert.NoError(t, err) + assert.Equal(t, customerOrig, customer) + username := m.Username("test_user") + kp := m.KeypadDefault + passcodeIdx := []int{0, 1, 2, 3} + ui, err := m.NewUserInterface(&kp) + assert.NoError(t, err) + userOrig, err := NewUser(*customer, username, passcodeIdx, *ui, kp) + assert.NoError(t, err) + err = db.WriteNewUser(*userOrig) + assert.NoError(t, err) + user, err := db.GetUser(username, customer.Id) + assert.NoError(t, err) + assert.Equal(t, userOrig, user) + + err = db.Renew(customer.Id) + assert.NoError(t, err) + + if _, err := os.Stat(dbFile); err == nil { + err = os.Remove(dbFile) + assert.NoError(t, err) + } else { + assert.NoError(t, err) + } +} diff --git a/core/nkode/user_signup_session.go b/core/nkode/user_signup_session.go index 8d96870..a54c59e 100644 --- a/core/nkode/user_signup_session.go +++ b/core/nkode/user_signup_session.go @@ -13,7 +13,7 @@ import ( type UserSignSession struct { Id m.SessionId CustomerId m.CustomerId - LoginUserInterface UserInterface + LoginUserInterface m.UserInterface Kp m.KeypadDimension SetIdxInterface m.IdxInterface ConfirmIdxInterface m.IdxInterface @@ -23,7 +23,7 @@ type UserSignSession struct { } func NewSignupSession(kp m.KeypadDimension, customerId m.CustomerId) (*UserSignSession, error) { - loginInterface, err := NewUserInterface(&kp) + loginInterface, err := m.NewUserInterface(&kp) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func (s *UserSignSession) SetUserNKode(username m.Username, keySelection m.KeySe s.SetKeySelection = keySelection s.Username = username setKp := s.SignupKeypad() - setInterface := UserInterface{IdxInterface: s.SetIdxInterface, Kp: &setKp} + setInterface := m.UserInterface{IdxInterface: s.SetIdxInterface, Kp: &setKp} err := setInterface.DisperseInterface() if err != nil { return nil, err @@ -132,7 +132,7 @@ func (s *UserSignSession) getSelectedKeyVals(keySelections m.KeySelection, userI return keyVals, nil } -func signupInterface(baseUserInterface UserInterface, kp m.KeypadDimension) (*UserInterface, error) { +func signupInterface(baseUserInterface m.UserInterface, kp m.KeypadDimension) (*m.UserInterface, error) { if kp.IsDispersable() { return nil, errors.New("keypad is dispersable, can't use signupInterface") } @@ -158,7 +158,7 @@ func signupInterface(baseUserInterface UserInterface, kp m.KeypadDimension) (*Us if err != nil { return nil, err } - signupUserInterface := UserInterface{ + signupUserInterface := m.UserInterface{ IdxInterface: util.MatrixToList(attrSetView), Kp: &m.KeypadDimension{ AttrsPerKey: numbOfKeys, diff --git a/main.go b/main.go index 21ef794..2be02f0 100644 --- a/main.go +++ b/main.go @@ -3,15 +3,20 @@ package main import ( "fmt" "go-nkode/core/api" + "go-nkode/core/model" "go-nkode/core/nkode" "log" "net/http" ) func main() { - db := nkode.NewInMemoryDb() - nkodeApi := nkode.NewNKodeAPI(&db) - handler := api.NKodeHandler{Api: &nkodeApi} + //db := nkode.NewInMemoryDb() + db, err := nkode.NewSqliteDB("nkode.db") + if err != nil { + log.Fatal(err) + } + nkodeApi := nkode.NewNKodeAPI(db) + handler := m.NKodeHandler{Api: &nkodeApi} mux := http.NewServeMux() mux.Handle(api.CreateNewCustomer, &handler) mux.Handle(api.GenerateSignupInterface, &handler) diff --git a/main_test.go b/main_test.go index 04f1a36..b1eb2eb 100644 --- a/main_test.go +++ b/main_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" "go-nkode/core/api" m "go-nkode/core/model" - "go-nkode/core/nkode" "io" "net/http" "testing" @@ -30,7 +29,7 @@ func TestApi(t *testing.T) { setInterface := signupInterfaceResp.UserInterface userPasscode := setInterface[:passcodeLen] kp = m.KeypadDimension{NumbOfKeys: kp.NumbOfKeys, AttrsPerKey: kp.NumbOfKeys} - setKeySelection, err := nkode.SelectKeyByAttrIdx(setInterface, userPasscode, kp) + setKeySelection, err := m.SelectKeyByAttrIdx(setInterface, userPasscode, kp) assert.NoError(t, err) setNKodeBody := m.SetNKodePost{ CustomerId: customerResp.CustomerId, @@ -41,7 +40,7 @@ func TestApi(t *testing.T) { var setNKodeResp m.SetNKodeResp testApiCall(t, base+api.SetNKode, setNKodeBody, &setNKodeResp) confirmInterface := setNKodeResp.UserInterface - confirmKeySelection, err := nkode.SelectKeyByAttrIdx(confirmInterface, userPasscode, kp) + confirmKeySelection, err := m.SelectKeyByAttrIdx(confirmInterface, userPasscode, kp) assert.NoError(t, err) confirmNKodeBody := m.ConfirmNKodePost{ CustomerId: customerResp.CustomerId, @@ -59,7 +58,7 @@ func TestApi(t *testing.T) { testApiCall(t, base+api.GetLoginInterface, loginInterfaceBody, &loginInterfaceResp) kp = m.KeypadDefault - loginKeySelection, err := nkode.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) + loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) assert.NoError(t, err) loginBody := m.LoginPost{ CustomerId: customerResp.CustomerId, @@ -72,7 +71,7 @@ func TestApi(t *testing.T) { renewBody := m.RenewAttributesPost{CustomerId: customerResp.CustomerId} testApiCall(t, base+api.RenewAttributes, renewBody, nil) - loginKeySelection, err = nkode.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) + loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) assert.NoError(t, err) loginBody = m.LoginPost{ CustomerId: customerResp.CustomerId, diff --git a/util/util.go b/util/util.go index 87c165d..94702c7 100644 --- a/util/util.go +++ b/util/util.go @@ -117,6 +117,17 @@ func Uint64ArrToByteArr(intArr []uint64) []byte { return byteArr } +func IntArrToByteArr(intArr []int) []byte { + byteArr := make([]byte, len(intArr)*4) + for idx, val := range intArr { + uval := uint32(val) + startIdx := idx * 4 + endIdx := (idx + 1) * 4 + binary.LittleEndian.PutUint32(byteArr[startIdx:endIdx], uval) + } + return byteArr +} + func ByteArrToUint64Arr(byteArr []byte) []uint64 { intArr := make([]uint64, len(byteArr)/8) for idx := 0; idx < len(intArr); idx++ { @@ -127,6 +138,17 @@ func ByteArrToUint64Arr(byteArr []byte) []uint64 { return intArr } +func ByteArrToIntArr(byteArr []byte) []int { + intArr := make([]int, len(byteArr)/4) + for idx := 0; idx < len(intArr); idx++ { + startIdx := idx * 4 + endIdx := (idx + 1) * 4 + uval := binary.LittleEndian.Uint32(byteArr[startIdx:endIdx]) + intArr[idx] = int(uval) + } + return intArr +} + func IndexOf[T uint64 | int](arr []T, el T) int { for idx, val := range arr { if val == el { diff --git a/util/util_test.go b/util/util_test.go index 76bae04..fe8ab9a 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -44,3 +44,11 @@ func TestMatrixTranspose(t *testing.T) { assert.Equal(t, expectedFlatMat[idx], flatMat[idx]) } } + +func TestIntToByteAndBack(t *testing.T) { + origIntArr := []int{1, 2, 3, 4, 5} + byteArr := IntArrToByteArr(origIntArr) + intArr := ByteArrToIntArr(byteArr) + + assert.ElementsMatch(t, origIntArr, intArr) +}