diff --git a/.gitignore b/.gitignore index d5355b5..54b95df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ .idea tmp *.db +go-nkode +*.db-shm +*.db-wal diff --git a/core/api/endpoints.go b/core/api/endpoints.go index 48995e1..08b1a94 100644 --- a/core/api/endpoints.go +++ b/core/api/endpoints.go @@ -8,4 +8,5 @@ const ( GetLoginInterface = "/get-login-interface" Login = "/login" RenewAttributes = "/renew-attributes" + RandomSvgInterface = "/random-svg-interface" ) diff --git a/core/model/nkode_handler.go b/core/model/nkode_handler.go index 0596348..9686399 100644 --- a/core/model/nkode_handler.go +++ b/core/model/nkode_handler.go @@ -28,6 +28,8 @@ func (h *NKodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.LoginHandler(w, r) case api.RenewAttributes: h.RenewAttributesHandler(w, r) + case api.RandomSvgInterface: + h.RandomSvgInterfaceHandler(w, r) default: w.WriteHeader(http.StatusNotFound) _, err := w.Write([]byte("404 not found")) @@ -77,6 +79,7 @@ func (h *NKodeHandler) GenerateSignupInterfaceHandler(w http.ResponseWriter, r * methodNotAllowed(w) return } + log.Print("signup interface") var signupPost GenerateSignupInterfacePost err := decodeJson(w, r, &signupPost) @@ -95,7 +98,7 @@ func (h *NKodeHandler) GenerateSignupInterfaceHandler(w http.ResponseWriter, r * log.Println(err) return } - resp, err := h.Api.GenerateSignupInterface(signupPost.Username, signupPost.CustomerId, kp) + resp, err := h.Api.GenerateSignupInterface(signupPost.Username, CustomerId(signupPost.CustomerId), kp) if err != nil { internalServerErrorHandler(w) log.Println(err) @@ -130,7 +133,7 @@ func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { log.Println(err) return } - confirmInterface, err := h.Api.SetNKode(setNKodePost.CustomerId, setNKodePost.SessionId, setNKodePost.KeySelection) + confirmInterface, err := h.Api.SetNKode(CustomerId(setNKodePost.CustomerId), SessionId(setNKodePost.SessionId), setNKodePost.KeySelection) if err != nil { internalServerErrorHandler(w) log.Println(err) @@ -168,7 +171,7 @@ func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Reques log.Println(err) return } - err = h.Api.ConfirmNKode(confirmNKodePost.CustomerId, confirmNKodePost.SessionId, confirmNKodePost.KeySelection) + err = h.Api.ConfirmNKode(CustomerId(confirmNKodePost.CustomerId), SessionId(confirmNKodePost.SessionId), confirmNKodePost.KeySelection) if err != nil { internalServerErrorHandler(w) log.Println(err) @@ -191,15 +194,14 @@ func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.R log.Println(err) return } - loginInterface, err := h.Api.GetLoginInterface(loginInterfacePost.Username, loginInterfacePost.CustomerId) + loginInterface, err := h.Api.GetLoginInterface(loginInterfacePost.Username, CustomerId(loginInterfacePost.CustomerId)) if err != nil { internalServerErrorHandler(w) log.Println(err) return } - respBody := GetLoginInterfaceResp{UserInterface: loginInterface} - respBytes, err := json.Marshal(respBody) + respBytes, err := json.Marshal(loginInterface) if err != nil { internalServerErrorHandler(w) log.Println(err) @@ -228,7 +230,7 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { log.Println(err) return } - err = h.Api.Login(loginPost.CustomerId, loginPost.Username, loginPost.KeySelection) + err = h.Api.Login(CustomerId(loginPost.CustomerId), loginPost.Username, loginPost.KeySelection) if err != nil { internalServerErrorHandler(w) log.Println(err) @@ -252,7 +254,35 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req return } - err = h.Api.RenewAttributes(renewAttributesPost.CustomerId) + err = h.Api.RenewAttributes(CustomerId(renewAttributesPost.CustomerId)) + if err != nil { + internalServerErrorHandler(w) + log.Println(err) + return + } + + w.WriteHeader(http.StatusOK) +} + +func (h *NKodeHandler) RandomSvgInterfaceHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + } + svgs, err := h.Api.RandomSvgInterface() + if err != nil { + internalServerErrorHandler(w) + log.Println(err) + return + } + respBody := RandomSvgInterfaceResp{Svgs: svgs} + respBytes, err := json.Marshal(respBody) + + if err != nil { + internalServerErrorHandler(w) + log.Println(err) + return + } + _, err = w.Write(respBytes) if err != nil { internalServerErrorHandler(w) log.Println(err) diff --git a/core/model/type.go b/core/model/type.go index 8c0b4d1..c7f09bc 100644 --- a/core/model/type.go +++ b/core/model/type.go @@ -8,42 +8,46 @@ type SetNKodeResp struct { UserInterface []int `json:"user_interface"` } +type RandomSvgInterfaceResp struct { + Svgs []string `json:"svgs"` +} + type NewCustomerPost struct { NKodePolicy NKodePolicy `json:"nkode_policy"` } type GenerateSignupInterfacePost struct { - CustomerId CustomerId `json:"customer_id"` - AttrsPerKey int `json:"attrs_per_key"` - NumbOfKeys int `json:"numb_of_keys"` - Username Username `json:"username"` + CustomerId uuid.UUID `json:"customer_id"` + AttrsPerKey int `json:"attrs_per_key"` + NumbOfKeys int `json:"numb_of_keys"` + Username Username `json:"username"` } type SetNKodePost struct { - CustomerId CustomerId `json:"customer_id"` + CustomerId uuid.UUID `json:"customer_id"` KeySelection KeySelection `json:"key_selection"` - SessionId SessionId `json:"session_id"` + SessionId uuid.UUID `json:"session_id"` } type ConfirmNKodePost struct { - CustomerId CustomerId `json:"customer_id"` + CustomerId uuid.UUID `json:"customer_id"` KeySelection KeySelection `json:"key_selection"` - SessionId SessionId `json:"session_id"` + SessionId uuid.UUID `json:"session_id"` } type GetLoginInterfacePost struct { - Username Username `json:"username"` - CustomerId CustomerId `json:"customer_id"` + Username Username `json:"username"` + CustomerId uuid.UUID `json:"customer_id"` } type LoginPost struct { - CustomerId CustomerId `json:"customer_id"` + CustomerId uuid.UUID `json:"customer_id"` Username Username `json:"username"` KeySelection KeySelection `json:"key_selection"` } type RenewAttributesPost struct { - CustomerId CustomerId `json:"customer_id"` + CustomerId uuid.UUID `json:"customer_id"` } type CreateNewCustomerResp struct { @@ -51,12 +55,14 @@ type CreateNewCustomerResp struct { } type GenerateSignupInterfaceResp struct { - SessionId SessionId `json:"session_id"` - UserInterface IdxInterface `json:"user_interface"` + SessionId SessionId `json:"session_id"` + UserIdxInterface IdxInterface `json:"user_interface"` + SvgInterface []string `json:"svg_interface"` } type GetLoginInterfaceResp struct { - UserInterface IdxInterface `json:"user_interface"` + UserIdxInterface IdxInterface `json:"user_interface"` + SvgInterface []string `json:"svg_interface"` } type KeySelection []int @@ -66,15 +72,18 @@ type UserId uuid.UUID type Username string type IdxInterface []int +type SvgIdInterface []int type NKodeAPIInterface interface { CreateNewCustomer(NKodePolicy, *CustomerId) (*CustomerId, error) GenerateSignupInterface(Username, CustomerId, KeypadDimension) (*GenerateSignupInterfaceResp, error) SetNKode(CustomerId, SessionId, KeySelection) (IdxInterface, error) ConfirmNKode(CustomerId, SessionId, KeySelection) error - GetLoginInterface(Username, CustomerId) (IdxInterface, error) + GetLoginInterface(Username, CustomerId) (*GetLoginInterfaceResp, error) Login(CustomerId, Username, KeySelection) error RenewAttributes(CustomerId) error + RandomSvgInterface() ([]string, error) + GetSvgStringInterface(idInterface SvgIdInterface) ([]string, error) } type EncipheredNKode struct { diff --git a/core/model/user.go b/core/model/user.go index c569b93..fe9dff5 100644 --- a/core/model/user.go +++ b/core/model/user.go @@ -10,29 +10,29 @@ type User struct { Username Username EncipheredPasscode EncipheredNKode Kp KeypadDimension - UserKeys UserCipherKeys + CipherKeys UserCipherKeys Interface UserInterface Renew bool } func (u *User) DecipherMask(setVals []uint64, passcodeLen int) ([]uint64, error) { - return u.UserKeys.DecipherMask(u.EncipheredPasscode.Mask, setVals, passcodeLen) + return u.CipherKeys.DecipherMask(u.EncipheredPasscode.Mask, setVals, passcodeLen) } func (u *User) RenewKeys(setXor []uint64, attrXor []uint64) error { u.Renew = true var err error - u.UserKeys.SetKey, err = util.XorLists(setXor[:u.Kp.AttrsPerKey], u.UserKeys.SetKey) + u.CipherKeys.SetKey, err = util.XorLists(setXor[:u.Kp.AttrsPerKey], u.CipherKeys.SetKey) if err != nil { panic(err) } - u.UserKeys.AlphaKey, err = util.XorLists(attrXor[:u.Kp.TotalAttrs()], u.UserKeys.AlphaKey) + u.CipherKeys.AlphaKey, err = util.XorLists(attrXor[:u.Kp.TotalAttrs()], u.CipherKeys.AlphaKey) return err } func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error { setVals, err := customerAttributes.SetValsForKp(u.Kp) - newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.UserKeys.MaxNKodeLen) + newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen) if err != nil { return err } @@ -42,7 +42,7 @@ func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes Custome return err } - u.UserKeys = *newKeys + u.CipherKeys = *newKeys u.EncipheredPasscode = *encipheredPasscode u.Renew = false return nil diff --git a/core/model/user_interface.go b/core/model/user_interface.go index 1a078a5..1c0b064 100644 --- a/core/model/user_interface.go +++ b/core/model/user_interface.go @@ -9,13 +9,16 @@ import ( type UserInterface struct { IdxInterface IdxInterface + SvgId SvgIdInterface Kp *KeypadDimension } -func NewUserInterface(kp *KeypadDimension) (*UserInterface, error) { +func NewUserInterface(kp *KeypadDimension, svgId SvgIdInterface) (*UserInterface, error) { idxInterface := util.IdentityArray(kp.TotalAttrs()) + userInterface := UserInterface{ IdxInterface: idxInterface, + SvgId: svgId, Kp: kp, } err := userInterface.RandomShuffle() diff --git a/core/model/user_test.go b/core/model/user_test.go index 476ecba..aa9aeb6 100644 --- a/core/model/user_test.go +++ b/core/model/user_test.go @@ -59,7 +59,8 @@ func TestUserInterface_RandomShuffle(t *testing.T) { AttrsPerKey: 10, NumbOfKeys: 8, } - userInterface, err := NewUserInterface(&kp) + mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) userInterfaceCopy := make([]int, len(userInterface.IdxInterface)) copy(userInterfaceCopy, userInterface.IdxInterface) @@ -81,8 +82,8 @@ func TestUserInterface_DisperseInterface(t *testing.T) { for idx := 0; idx < 10000; idx++ { kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} - - userInterface, err := NewUserInterface(&kp) + mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) preDispersion, err := userInterface.AttributeAdjacencyGraph() assert.NoError(t, err) @@ -100,7 +101,8 @@ func TestUserInterface_DisperseInterface(t *testing.T) { func TestUserInterface_PartialInterfaceShuffle(t *testing.T) { kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} - userInterface, err := NewUserInterface(&kp) + mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) preShuffle := userInterface.IdxInterface err = userInterface.PartialInterfaceShuffle() diff --git a/core/nkode/common.go b/core/nkode/common.go index 5b06bf6..6b8bf8a 100644 --- a/core/nkode/common.go +++ b/core/nkode/common.go @@ -7,7 +7,7 @@ import ( py "go-nkode/py-builtin" ) -var KeyIndexOutOfRange error = errors.New("one or more keys is out of range") +var KeyIndexOutOfRange = errors.New("one or more keys is out of range") func ValidKeyEntry(user m.User, customer m.Customer, selectedKeys []int) ([]int, error) { validKeys := py.All[int](selectedKeys, func(idx int) bool { @@ -55,7 +55,7 @@ func ValidKeyEntry(user m.User, customer m.Customer, selectedKeys []int) ([]int, if err != nil { panic(err) } - err = user.UserKeys.ValidPassword(user.EncipheredPasscode.Code, presumedAttrIdxVals, attrVals) + err = user.CipherKeys.ValidPassword(user.EncipheredPasscode.Code, presumedAttrIdxVals, attrVals) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func NewUser(customer m.Customer, username m.Username, passcodeIdx []int, ui m.U Id: m.UserId(uuid.New()), Username: username, EncipheredPasscode: *encipheredNKode, - UserKeys: *newKeys, + CipherKeys: *newKeys, Interface: ui, Kp: kp, CustomerId: customer.Id, diff --git a/core/nkode/customer_test.go b/core/nkode/customer_test.go index 382685e..cb04a1c 100644 --- a/core/nkode/customer_test.go +++ b/core/nkode/customer_test.go @@ -23,11 +23,12 @@ func testCustomerValidKeyEntry(t *testing.T) { nkodePolicy := model.NewDefaultNKodePolicy() customer, err := model.NewCustomer(nkodePolicy) assert.NoError(t, err) - newUserInterface, err := model.NewUserInterface(&kp) + mockSvgInterface := make(model.SvgIdInterface, kp.TotalAttrs()) + userInterface, err := model.NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) username := model.Username("testing123") passcodeIdx := []int{0, 1, 2, 3} - user, err := NewUser(*customer, username, passcodeIdx, *newUserInterface, kp) + user, err := NewUser(*customer, username, passcodeIdx, *userInterface, kp) assert.NoError(t, err) userLoginInterface, err := user.GetLoginInterface() assert.NoError(t, err) @@ -46,11 +47,12 @@ func testCustomerIsValidNKode(t *testing.T) { nkodePolicy := model.NewDefaultNKodePolicy() customer, err := model.NewCustomer(nkodePolicy) assert.NoError(t, err) - newUserInterface, err := model.NewUserInterface(&kp) + mockSvgInterface := make(model.SvgIdInterface, kp.TotalAttrs()) + userInterface, err := model.NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) username := model.Username("testing123") passcodeIdx := []int{0, 1, 2, 3} - user, err := NewUser(*customer, username, passcodeIdx, *newUserInterface, kp) + user, err := NewUser(*customer, username, passcodeIdx, *userInterface, kp) assert.NoError(t, err) err = customer.IsValidNKode(user.Kp, passcodeIdx) assert.NoError(t, err) diff --git a/core/nkode/db_accessor.go b/core/nkode/db_accessor.go index c1e2279..a8a8da3 100644 --- a/core/nkode/db_accessor.go +++ b/core/nkode/db_accessor.go @@ -12,4 +12,7 @@ type DbAccessor interface { UpdateUserInterface(model.UserId, model.UserInterface) error Renew(model.CustomerId) error RefreshUser(model.User, []int, model.CustomerAttributes) error + RandomSvgInterface(model.KeypadDimension) ([]string, error) + RandomSvgIdxInterface(model.KeypadDimension) (model.SvgIdInterface, error) + GetSvgStringInterface(model.SvgIdInterface) ([]string, error) } diff --git a/core/nkode/in_memory_db.go b/core/nkode/in_memory_db.go index 943c6bb..1aed53c 100644 --- a/core/nkode/in_memory_db.go +++ b/core/nkode/in_memory_db.go @@ -105,6 +105,18 @@ func (db *InMemoryDb) RefreshUser(user m.User, passocode []int, customerAttr m.C return nil } +func (db *InMemoryDb) RandomSvgInterface(kp m.KeypadDimension) ([]string, error) { + return nil, errors.ErrUnsupported +} + +func (db *InMemoryDb) RandomSvgIdxInterface(kp m.KeypadDimension) (m.SvgIdInterface, error) { + return nil, errors.ErrUnsupported +} + +func (db InMemoryDb) GetSvgStringInterface(idxs m.SvgIdInterface) ([]string, error) { + return nil, errors.ErrUnsupported +} + func userIdKey(customerId m.CustomerId, username m.Username) string { key := fmt.Sprintf("%s:%s", customerId, username) return key diff --git a/core/nkode/nkode_api.go b/core/nkode/nkode_api.go index bfe7f5c..57d3e65 100644 --- a/core/nkode/nkode_api.go +++ b/core/nkode/nkode_api.go @@ -35,14 +35,23 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy m.NKodePolicy, id *m.CustomerId } func (n *NKodeAPI) GenerateSignupInterface(username m.Username, customerId m.CustomerId, kp m.KeypadDimension) (*m.GenerateSignupInterfaceResp, error) { - signupSession, err := NewSignupSession(username, kp, customerId) + svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) + if err != nil { + return nil, err + } + signupSession, err := NewSignupSession(username, kp, customerId, svgIdxInterface) if err != nil { return nil, err } n.SignupSessions[signupSession.Id] = *signupSession + svgInterface, err := n.Db.GetSvgStringInterface(signupSession.LoginUserInterface.SvgId) + if err != nil { + return nil, err + } resp := m.GenerateSignupInterfaceResp{ - UserInterface: signupSession.SetIdxInterface, - SessionId: signupSession.Id, + UserIdxInterface: signupSession.SetIdxInterface, + SvgInterface: svgInterface, + SessionId: signupSession.Id, } return &resp, nil } @@ -86,12 +95,12 @@ func (n *NKodeAPI) ConfirmNKode(customerId m.CustomerId, sessionId m.SessionId, if err != nil { return err } - delete(n.SignupSessions, session.Id) err = n.Db.WriteNewUser(*user) + delete(n.SignupSessions, session.Id) return err } -func (n *NKodeAPI) GetLoginInterface(username m.Username, customerId m.CustomerId) (m.IdxInterface, error) { +func (n *NKodeAPI) GetLoginInterface(username m.Username, customerId m.CustomerId) (*m.GetLoginInterfaceResp, error) { user, err := n.Db.GetUser(username, customerId) if err != nil { return nil, err @@ -104,7 +113,15 @@ func (n *NKodeAPI) GetLoginInterface(username m.Username, customerId m.CustomerI if err != nil { return nil, err } - return user.Interface.IdxInterface, nil + svgInterface, err := n.Db.GetSvgStringInterface(user.Interface.SvgId) + if err != nil { + return nil, err + } + resp := m.GetLoginInterfaceResp{ + UserIdxInterface: user.Interface.IdxInterface, + SvgInterface: svgInterface, + } + return &resp, nil } func (n *NKodeAPI) Login(customerId m.CustomerId, username m.Username, keySelection m.KeySelection) error { @@ -132,3 +149,11 @@ func (n *NKodeAPI) Login(customerId m.CustomerId, username m.Username, keySelect func (n *NKodeAPI) RenewAttributes(customerId m.CustomerId) error { return n.Db.Renew(customerId) } + +func (n *NKodeAPI) RandomSvgInterface() ([]string, error) { + return n.Db.RandomSvgInterface(m.KeypadMax) +} + +func (n *NKodeAPI) GetSvgStringInterface(svgId m.SvgIdInterface) ([]string, error) { + return n.Db.GetSvgStringInterface(svgId) +} diff --git a/core/nkode/nkode_api_test.go b/core/nkode/nkode_api_test.go index f4cd318..6f2e952 100644 --- a/core/nkode/nkode_api_test.go +++ b/core/nkode/nkode_api_test.go @@ -3,7 +3,7 @@ package nkode import ( "github.com/stretchr/testify/assert" m "go-nkode/core/model" - "os" + "go-nkode/util" "testing" ) @@ -11,22 +11,25 @@ func TestNKodeAPI(t *testing.T) { //db1 := NewInMemoryDb() //1testNKodeAPI(t, &db1) - dbFile := "test.db" - db2, err := NewSqliteDB(dbFile) - assert.NoError(t, err) + dbFile := "../../test.db" + + // sql_driver.MakeTables(dbFile) + db2 := NewSqliteDB(dbFile) + defer db2.CloseDb() testNKodeAPI(t, db2) - if _, err := os.Stat(dbFile); err == nil { - err = os.Remove(dbFile) - assert.NoError(t, err) - } else { - assert.NoError(t, err) - } + + // if _, err := os.Stat(dbFile); err == nil { + // err = os.Remove(dbFile) + // assert.NoError(t, err) + // } else { + // assert.NoError(t, err) + // } } func testNKodeAPI(t *testing.T, db DbAccessor) { - for idx := 0; idx < 10; idx++ { - username := m.Username("test_username") + for idx := 0; idx < 1; idx++ { + username := m.Username("test_username" + util.GenerateRandomString(12)) passcodeLen := 4 nkodePolicy := m.NewDefaultNKodePolicy() keypadSize := m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} @@ -35,7 +38,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { assert.NoError(t, err) signupResponse, err := nkodeApi.GenerateSignupInterface(username, *customerId, keypadSize) assert.NoError(t, err) - setInterface := signupResponse.UserInterface + setInterface := signupResponse.UserIdxInterface sessionId := signupResponse.SessionId keypadSize = m.KeypadDimension{AttrsPerKey: 8, NumbOfKeys: 8} userPasscode := setInterface[:passcodeLen] @@ -50,7 +53,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { keypadSize = m.KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 8} loginInterface, err := nkodeApi.GetLoginInterface(username, *customerId) assert.NoError(t, err) - loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) + loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) assert.NoError(t, err) err = nkodeApi.Login(*customerId, username, loginKeySelection) assert.NoError(t, err) @@ -60,7 +63,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { loginInterface, err = nkodeApi.GetLoginInterface(username, *customerId) assert.NoError(t, err) - loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterface, userPasscode, keypadSize) + loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) assert.NoError(t, err) err = nkodeApi.Login(*customerId, username, loginKeySelection) assert.NoError(t, err) diff --git a/core/nkode/sqlite_db.go b/core/nkode/sqlite_db.go index 1df26de..630e0fc 100644 --- a/core/nkode/sqlite_db.go +++ b/core/nkode/sqlite_db.go @@ -12,101 +12,70 @@ import ( ) type SqliteDB struct { - path string + db *sql.DB } -func NewSqliteDB(path string) (*SqliteDB, error) { - sqldb := SqliteDB{path: path} - err := sqldb.NewTables() +func NewSqliteDB(path string) *SqliteDB { + db, err := sql.Open("sqlite3", path) + if err != nil { + log.Fatal("database didn't open ", err) + } + sqldb := SqliteDB{db: db} - return &sqldb, err + return &sqldb } -func (d *SqliteDB) NewTables() error { - db, err := sql.Open("sqlite3", d.path) - if err != nil { - log.Fatal(err) +func (d *SqliteDB) CloseDb() { + if err := d.db.Close(); err != nil { + // If db.Close() returns an error, panic + panic(fmt.Sprintf("Failed to close the database: %v", err)) } - defer db.Close() - - createTables := ` -PRAGMA foreign_keys = ON; - -CREATE TABLE IF NOT EXISTS customer ( - id TEXT NOT NULL PRIMARY KEY, - max_nkode_len INTEGER NOT NULL, - min_nkode_len INTEGER NOT NULL, - distinct_sets INTEGER NOT NULL, - distinct_attributes INTEGER NOT NULL, - lock_out INTEGER NOT NULL, - expiration INTEGER NOT NULL, - attribute_values BLOB NOT NULL, - set_values BLOB NOT NULL -); - -CREATE TABLE IF NOT EXISTS user ( - id TEXT NOT NULL PRIMARY KEY, - username TEXT NOT NULL, - renew INT NOT NULL, - customer_id TEXT NOT NULL, - --- Enciphered Passcode - code TEXT NOT NULL, - mask TEXT NOT NULL, - --- Keypad Dimensions - attributes_per_key INT NOT NULL, - number_of_keys INT NOT NULL, - --- User Keys - alpha_key BLOB NOT NULL, - set_key BLOB NOT NULL, - pass_key BLOB NOT NULL, - mask_key BLOB NOT NULL, - salt BLOB NOT NULL, - max_nkode_len INT NOT NULL, - --- User Interface - idx_interface BLOB NOT NULL, - - - FOREIGN KEY (customer_id) REFERENCES customers(id), - UNIQUE(customer_id, username) -); -` - - _, err = db.Exec(createTables) - if err != nil { - return err - } - - return nil } func (d *SqliteDB) WriteNewCustomer(c m.Customer) error { - db, err := sql.Open("sqlite3", d.path) + tx, err := d.db.Begin() if err != nil { - log.Fatal(err) + return err } - defer db.Close() + defer func() { + if err != nil { + err = tx.Rollback() + if err != nil { + log.Fatal(fmt.Sprintf("Write new customer won't roll back %+v", err)) + } + } + }() insertCustomer := ` INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values) VALUES (?,?,?,?,?,?,?,?,?) ` - _, err = db.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()) - - return err + _, err = tx.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()) + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + return nil } func (d *SqliteDB) WriteNewUser(u m.User) error { - db, err := sql.Open("sqlite3", d.path) + tx, err := d.db.Begin() if err != nil { - log.Fatal(err) + return err } - defer db.Close() + defer func() { + if err != nil { + err = tx.Rollback() + if err != nil { + log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err)) + } + } + }() insertUser := ` -INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface) -VALUES (?,?, ?,?,?,?,?,?,?,?,?,?,?,?,?) +INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) +VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` var renew int if u.Renew { @@ -114,19 +83,21 @@ VALUES (?,?, ?,?,?,?,?,?,?,?,?,?,?,?,?) } else { renew = 0 } - _, err = db.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.UserKeys.AlphaKey), util.Uint64ArrToByteArr(u.UserKeys.SetKey), util.Uint64ArrToByteArr(u.UserKeys.PassKey), util.Uint64ArrToByteArr(u.UserKeys.MaskKey), u.UserKeys.Salt, u.UserKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface)) + _, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)) - return err + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + return nil } func (d *SqliteDB) GetCustomer(id m.CustomerId) (*m.Customer, error) { - db, err := sql.Open("sqlite3", d.path) - if err != nil { - return nil, err - } - defer db.Close() selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?` - rows, err := db.Query(selectCustomer, uuid.UUID(id)) + rows, err := d.db.Query(selectCustomer, uuid.UUID(id)) if !rows.Next() { return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err())) @@ -165,17 +136,11 @@ func (d *SqliteDB) GetCustomer(id m.CustomerId) (*m.Customer, error) { } func (d *SqliteDB) GetUser(username m.Username, customerId m.CustomerId) (*m.User, error) { - db, err := sql.Open("sqlite3", d.path) - if err != nil { - return nil, err - } - defer db.Close() - userSelect := ` -SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface FROM user +SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user WHERE user.username = ? AND user.customer_id = ? ` - rows, err := db.Query(userSelect, string(username), uuid.UUID(customerId).String()) + rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String()) if !rows.Next() { return nil, errors.New(fmt.Sprintf("no new rows for user %s of customer %s", string(username), uuid.UUID(customerId).String())) } @@ -192,8 +157,9 @@ WHERE user.username = ? AND user.customer_id = ? var salt []byte var maxNKodeLen int var idxInterface []byte + var svgIdInterface []byte - err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface) + err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) if rows.Next() { return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId)) } @@ -221,7 +187,7 @@ WHERE user.username = ? AND user.customer_id = ? AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, - UserKeys: m.UserCipherKeys{ + CipherKeys: m.UserCipherKeys{ AlphaKey: util.ByteArrToUint64Arr(alphaKey), SetKey: util.ByteArrToUint64Arr(setKey), PassKey: util.ByteArrToUint64Arr(passKey), @@ -232,26 +198,22 @@ WHERE user.username = ? AND user.customer_id = ? }, Interface: m.UserInterface{ IdxInterface: util.ByteArrToIntArr(idxInterface), + SvgId: util.ByteArrToIntArr(svgIdInterface), Kp: nil, }, Renew: renew, } user.Interface.Kp = &user.Kp - user.UserKeys.Kp = &user.Kp + user.CipherKeys.Kp = &user.Kp return &user, nil } func (d *SqliteDB) UpdateUserInterface(id m.UserId, ui m.UserInterface) error { - db, err := sql.Open("sqlite3", d.path) - if err != nil { - return err - } - defer db.Close() updateUserInterface := ` UPDATE user SET idx_interface = ? WHERE id = ? ` - _, err = db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()) + _, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()) return err } @@ -263,23 +225,17 @@ func (d *SqliteDB) Renew(id m.CustomerId) error { } setXor, attrXor := customer.RenewKeys() renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} - + // TODO: replace with tx renewExec := ` BEGIN TRANSACTION; UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?; ` - db, err := sql.Open("sqlite3", d.path) - if err != nil { - return err - } - defer db.Close() - userQuery := ` SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ? ` - rows, err := db.Query(userQuery, uuid.UUID(id).String()) + rows, err := d.db.Query(userQuery, uuid.UUID(id).String()) for rows.Next() { var userId string var alphaBytes []byte @@ -299,7 +255,7 @@ SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHER AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, - UserKeys: m.UserCipherKeys{ + CipherKeys: m.UserCipherKeys{ AlphaKey: util.ByteArrToUint64Arr(alphaBytes), SetKey: util.ByteArrToUint64Arr(setBytes), }, @@ -311,29 +267,82 @@ SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHER return err } renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;" - renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.UserKeys.AlphaKey), util.Uint64ArrToByteArr(user.UserKeys.SetKey), 1, userId) + renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId) } renewExec += ` COMMIT; ` - _, err = db.Exec(renewExec, renewArgs...) + _, err = d.db.Exec(renewExec, renewArgs...) return err } func (d *SqliteDB) RefreshUser(user m.User, passcodeIdx []int, customerAttr m.CustomerAttributes) error { - db, err := sql.Open("sqlite3", d.path) - if err != nil { - return err - } - defer db.Close() - err = user.RefreshPasscode(passcodeIdx, customerAttr) + err := user.RefreshPasscode(passcodeIdx, customerAttr) if err != nil { return err } updateUser := ` UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; ` - _, err = db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.UserKeys.AlphaKey), util.Uint64ArrToByteArr(user.UserKeys.SetKey), util.Uint64ArrToByteArr(user.UserKeys.PassKey), util.Uint64ArrToByteArr(user.UserKeys.MaskKey), user.UserKeys.Salt, uuid.UUID(user.Id).String()) + _, err = d.db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()) return err } + +func (d *SqliteDB) RandomSvgInterface(kp m.KeypadDimension) ([]string, error) { + ids, err := d.getRandomIds(kp.TotalAttrs()) + if err != nil { + return nil, err + } + return d.getSvgsById(ids) +} + +func (d *SqliteDB) RandomSvgIdxInterface(kp m.KeypadDimension) (m.SvgIdInterface, error) { + return d.getRandomIds(kp.TotalAttrs()) +} + +func (d *SqliteDB) GetSvgStringInterface(idxs m.SvgIdInterface) ([]string, error) { + return d.getSvgsById(idxs) +} + +func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { + selectId := "SELECT svg FROM svg_icon where id = ?" + svgs := make([]string, len(ids)) + for idx, id := range ids { + rows, err := d.db.Query(selectId, id) + if err != nil { + return nil, err + } + if !rows.Next() { + return nil, errors.New(fmt.Sprintf("id not found: %d", id)) + } + err = rows.Scan(&svgs[idx]) + if err != nil { + return nil, err + } + } + return svgs, nil +} + +func (d *SqliteDB) getRandomIds(count int) ([]int, error) { + rows, err := d.db.Query("SELECT COUNT(*) as count FROM svg_icon;") + if err != nil { + return nil, err + } + var tableLen int + if !rows.Next() { + return nil, errors.New("empty svg_icon table") + } + err = rows.Scan(&tableLen) + if err != nil { + return nil, err + } + perm, err := util.RandomPermutation(tableLen) + for idx := range perm { + perm[idx] += 1 + } + if err != nil { + return nil, err + } + return perm[:count], nil +} diff --git a/core/nkode/sqlite_db_test.go b/core/nkode/sqlite_db_test.go index 552b093..15ee21f 100644 --- a/core/nkode/sqlite_db_test.go +++ b/core/nkode/sqlite_db_test.go @@ -3,14 +3,26 @@ package nkode import ( "github.com/stretchr/testify/assert" m "go-nkode/core/model" - "os" "testing" ) func TestNewSqliteDB(t *testing.T) { - dbFile := "test.db" - db, err := NewSqliteDB(dbFile) - assert.NoError(t, err) + dbFile := "../../test.db" + // sql_driver.MakeTables(dbFile) + db := NewSqliteDB(dbFile) + defer db.CloseDb() + + testSignupLoginRenew(t, db) + testSqliteDBRandomSvgInterface(t, db) + // if _, err := os.Stat(dbFile); err == nil { + // err = os.Remove(dbFile) + // assert.NoError(t, err) + // } else { + // assert.NoError(t, err) + // } +} + +func testSignupLoginRenew(t *testing.T, db DbAccessor) { nkode_policy := m.NewDefaultNKodePolicy() customerOrig, err := m.NewCustomer(nkode_policy) assert.NoError(t, err) @@ -22,7 +34,8 @@ func TestNewSqliteDB(t *testing.T) { username := m.Username("test_user") kp := m.KeypadDefault passcodeIdx := []int{0, 1, 2, 3} - ui, err := m.NewUserInterface(&kp) + mockSvgInterface := make(m.SvgIdInterface, kp.TotalAttrs()) + ui, err := m.NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) userOrig, err := NewUser(*customer, username, passcodeIdx, *ui, kp) assert.NoError(t, err) @@ -35,10 +48,11 @@ func TestNewSqliteDB(t *testing.T) { err = db.Renew(customer.Id) assert.NoError(t, err) - if _, err := os.Stat(dbFile); err == nil { - err = os.Remove(dbFile) - assert.NoError(t, err) - } else { - assert.NoError(t, err) - } +} + +func testSqliteDBRandomSvgInterface(t *testing.T, db DbAccessor) { + kp := m.KeypadMax + svgs, err := db.RandomSvgInterface(kp) + assert.NoError(t, err) + assert.Len(t, svgs, kp.TotalAttrs()) } diff --git a/core/nkode/user_signup_session.go b/core/nkode/user_signup_session.go index affd99b..6fdd87e 100644 --- a/core/nkode/user_signup_session.go +++ b/core/nkode/user_signup_session.go @@ -22,8 +22,8 @@ type UserSignSession struct { Expire int } -func NewSignupSession(username m.Username, kp m.KeypadDimension, customerId m.CustomerId) (*UserSignSession, error) { - loginInterface, err := m.NewUserInterface(&kp) +func NewSignupSession(username m.Username, kp m.KeypadDimension, customerId m.CustomerId, svgInterface m.SvgIdInterface) (*UserSignSession, error) { + loginInterface, err := m.NewUserInterface(&kp, svgInterface) if err != nil { return nil, err } diff --git a/core/svg-icon/db_interface.go b/core/svg-icon/db_interface.go new file mode 100644 index 0000000..37603a1 --- /dev/null +++ b/core/svg-icon/db_interface.go @@ -0,0 +1,58 @@ +package svg_icon + +import ( + "database/sql" + "go-nkode/util" +) + +type SvgIcon struct { + Id int + Svg string +} + +type SvgIconDb struct { + path string +} + +func (d *SvgIconDb) GetSvgsById(ids []int) ([]string, error) { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return nil, err + } + defer db.Close() + selectId := "SELECT svg FROM svg_icon where id = ?" + svgs := make([]string, len(ids)) + for idx, id := range ids { + rows, err := db.Query(selectId, id) + if err != nil { + return nil, err + } + err = rows.Scan(&svgs[idx]) + if err != nil { + return nil, err + } + } + return svgs, nil +} + +func (d *SvgIconDb) GetRandomIds(count int) ([]int, error) { + db, err := sql.Open("sqlite3", d.path) + if err != nil { + return nil, err + } + defer db.Close() + rows, err := db.Query("SELECT COUNT(*) FROM svg_icon;") + if err != nil { + return nil, err + } + var tableLen int + err = rows.Scan(&tableLen) + if err != nil { + return nil, err + } + perm, err := util.RandomPermutation(tableLen) + if err != nil { + return nil, err + } + return perm[:count], nil +} diff --git a/go.mod b/go.mod index b9a78fa..770c591 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,15 @@ module go-nkode -go 1.19 +go 1.22.0 + +toolchain go1.23.0 require ( github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.26.0 + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 ) require ( diff --git a/go.sum b/go.sum index 815b869..1d4e786 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 49da9e8..4bda7cc 100644 --- a/main.go +++ b/main.go @@ -12,10 +12,8 @@ import ( func main() { //db := nkode.NewInMemoryDb() - db, err := nkode.NewSqliteDB("nkode.db") - if err != nil { - log.Fatal(err) - } + db := nkode.NewSqliteDB("nkode.db") + defer db.CloseDb() nkodeApi := nkode.NewNKodeAPI(db) AddDefaultCustomer(nkodeApi) handler := model.NKodeHandler{Api: &nkodeApi} @@ -27,6 +25,7 @@ func main() { mux.Handle(api.GetLoginInterface, &handler) mux.Handle(api.Login, &handler) mux.Handle(api.RenewAttributes, &handler) + mux.Handle(api.RandomSvgInterface, &handler) fmt.Println("Running on localhost:8080...") log.Fatal(http.ListenAndServe("localhost:8080", corsMiddleware(mux))) } diff --git a/main_test.go b/main_test.go index 20753fb..b6afe0f 100644 --- a/main_test.go +++ b/main_test.go @@ -21,7 +21,7 @@ func TestApi(t *testing.T) { NumbOfKeys: 10, } var customerResp m.CreateNewCustomerResp - testApiCall(t, base+api.CreateNewCustomer, newCustomerBody, &customerResp) + testApiPost(t, base+api.CreateNewCustomer, newCustomerBody, &customerResp) username := m.Username("test_username") signupInterfaceBody := m.GenerateSignupInterfacePost{ @@ -31,10 +31,10 @@ func TestApi(t *testing.T) { Username: username, } var signupInterfaceResp m.GenerateSignupInterfaceResp - testApiCall(t, base+api.GenerateSignupInterface, signupInterfaceBody, &signupInterfaceResp) - + testApiPost(t, base+api.GenerateSignupInterface, signupInterfaceBody, &signupInterfaceResp) + assert.Len(t, signupInterfaceResp.SvgInterface, kp.TotalAttrs()) passcodeLen := 4 - setInterface := signupInterfaceResp.UserInterface + setInterface := signupInterfaceResp.UserIdxInterface userPasscode := setInterface[:passcodeLen] kp_set := m.KeypadDimension{NumbOfKeys: kp.NumbOfKeys, AttrsPerKey: kp.NumbOfKeys} setKeySelection, err := m.SelectKeyByAttrIdx(setInterface, userPasscode, kp_set) @@ -45,7 +45,7 @@ func TestApi(t *testing.T) { KeySelection: setKeySelection, } var setNKodeResp m.SetNKodeResp - testApiCall(t, base+api.SetNKode, setNKodeBody, &setNKodeResp) + testApiPost(t, base+api.SetNKode, setNKodeBody, &setNKodeResp) confirmInterface := setNKodeResp.UserInterface confirmKeySelection, err := m.SelectKeyByAttrIdx(confirmInterface, userPasscode, kp_set) assert.NoError(t, err) @@ -54,7 +54,7 @@ func TestApi(t *testing.T) { KeySelection: confirmKeySelection, SessionId: signupInterfaceResp.SessionId, } - testApiCall(t, base+api.ConfirmNKode, confirmNKodeBody, nil) + testApiPost(t, base+api.ConfirmNKode, confirmNKodeBody, nil) loginInterfaceBody := m.GetLoginInterfacePost{ CustomerId: customerResp.CustomerId, @@ -62,9 +62,9 @@ func TestApi(t *testing.T) { } var loginInterfaceResp m.GetLoginInterfaceResp - testApiCall(t, base+api.GetLoginInterface, loginInterfaceBody, &loginInterfaceResp) + testApiPost(t, base+api.GetLoginInterface, loginInterfaceBody, &loginInterfaceResp) - loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) + loginKeySelection, err := m.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) assert.NoError(t, err) loginBody := m.LoginPost{ CustomerId: customerResp.CustomerId, @@ -72,12 +72,12 @@ func TestApi(t *testing.T) { KeySelection: loginKeySelection, } - testApiCall(t, base+api.Login, loginBody, nil) + testApiPost(t, base+api.Login, loginBody, nil) renewBody := m.RenewAttributesPost{CustomerId: customerResp.CustomerId} - testApiCall(t, base+api.RenewAttributes, renewBody, nil) + testApiPost(t, base+api.RenewAttributes, renewBody, nil) - loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterfaceResp.UserInterface, userPasscode, kp) + loginKeySelection, err = m.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) assert.NoError(t, err) loginBody = m.LoginPost{ CustomerId: customerResp.CustomerId, @@ -85,7 +85,12 @@ func TestApi(t *testing.T) { KeySelection: loginKeySelection, } - testApiCall(t, base+api.Login, loginBody, nil) + testApiPost(t, base+api.Login, loginBody, nil) + + var randomSvgInterfaceResp m.RandomSvgInterfaceResp + testApiGet(t, base+api.RandomSvgInterface, &randomSvgInterfaceResp) + assert.Equal(t, m.KeypadMax.TotalAttrs(), len(randomSvgInterfaceResp.Svgs)) + } func Unmarshal(t *testing.T, resp *http.Response, data any) { @@ -102,7 +107,7 @@ func Marshal(t *testing.T, data any) *bytes.Reader { return reader } -func testApiCall(t *testing.T, endpointStr string, postBody any, respBody any) { +func testApiPost(t *testing.T, endpointStr string, postBody any, respBody any) { reader := Marshal(t, postBody) resp, err := http.Post(endpointStr, "application/json", reader) assert.NoError(t, err) @@ -111,3 +116,12 @@ func testApiCall(t *testing.T, endpointStr string, postBody any, respBody any) { Unmarshal(t, resp, respBody) } } + +func testApiGet(t *testing.T, endpointStr string, respBody any) { + resp, err := http.Get(endpointStr) + assert.NoError(t, err) + assert.Equal(t, resp.StatusCode, http.StatusOK) + if respBody != nil { + Unmarshal(t, resp, respBody) + } +} diff --git a/sql-driver/sql_driver.go b/sql-driver/sql_driver.go index 441c094..0af4846 100644 --- a/sql-driver/sql_driver.go +++ b/sql-driver/sql_driver.go @@ -2,86 +2,70 @@ package sql_driver import ( "database/sql" - "fmt" - "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "log" ) -func InitTables() { - db, err := sql.Open("sqlite3", "./example.db") +func MakeTables(dbPath string) { + db, err := sql.Open("sqlite3", dbPath) if err != nil { log.Fatal(err) } defer db.Close() + createTable := ` +PRAGMA journal_mode=WAL; +PRAGMA foreign_keys = ON; - // Create a table - createTableSQL := ` - CREATE TABLE IF NOT EXISTS users ( - id TEXT NOT NULL PRIMARY KEY, - name TEXT, - age INTEGER - ); - ` - _, err = db.Exec(createTableSQL) +CREATE TABLE IF NOT EXISTS customer ( + id TEXT NOT NULL PRIMARY KEY, + max_nkode_len INTEGER NOT NULL, + min_nkode_len INTEGER NOT NULL, + distinct_sets INTEGER NOT NULL, + distinct_attributes INTEGER NOT NULL, + lock_out INTEGER NOT NULL, + expiration INTEGER NOT NULL, + attribute_values BLOB NOT NULL, + set_values BLOB NOT NULL +); + +CREATE TABLE IF NOT EXISTS user ( + id TEXT NOT NULL PRIMARY KEY, + username TEXT NOT NULL, + renew INT NOT NULL, + customer_id TEXT NOT NULL, + +-- Enciphered Passcode + code TEXT NOT NULL, + mask TEXT NOT NULL, + +-- Keypad Dimensions + attributes_per_key INT NOT NULL, + number_of_keys INT NOT NULL, + +-- User Keys + alpha_key BLOB NOT NULL, + set_key BLOB NOT NULL, + pass_key BLOB NOT NULL, + mask_key BLOB NOT NULL, + salt BLOB NOT NULL, + max_nkode_len INT NOT NULL, + +-- User Interface + idx_interface BLOB NOT NULL, + svg_id_interface BLOB NOT NULL, + + + FOREIGN KEY (customer_id) REFERENCES customers(id), + UNIQUE(customer_id, username) +); + +CREATE TABLE IF NOT EXISTS svg_icon ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + svg TEXT NOT NULL +); +` + _, err = db.Exec(createTable) if err != nil { - log.Fatalf("Error creating table: %s", err) - } - - // Insert data into the table - insertUserSQL := `INSERT INTO users (id, name, age) VALUES (?, ?, ?)` - _, err = db.Exec(insertUserSQL, uuid.New(), "Alice", 30) - if err != nil { - log.Fatalf("Error inserting data: %s", err) - } - - _, err = db.Exec(insertUserSQL, uuid.New(), "Bob", 25) - if err != nil { - log.Fatalf("Error inserting data: %s", err) - } - - // Query the data - queryUserSQL := `SELECT id, name, age FROM users` - rows, err := db.Query(queryUserSQL) - if err != nil { - log.Fatalf("Error querying data: %s", err) - } - defer rows.Close() - - for rows.Next() { - var id string - var name string - var age int - err = rows.Scan(&id, &name, &age) - if err != nil { - log.Fatalf("Error scanning data: %s", err) - } - fmt.Printf("User: ID=%s, Name=%s, Age=%d\n", id, name, age) - } - - // Update data - updateUserSQL := `UPDATE users SET age = ? WHERE name = ?` - _, err = db.Exec(updateUserSQL, 35, "Alice") - if err != nil { - log.Fatalf("Error updating data: %s", err) - } - - // Verify the update - fmt.Println("After update:") - rows, err = db.Query(queryUserSQL) - if err != nil { - log.Fatalf("Error querying data: %s", err) - } - defer rows.Close() - - for rows.Next() { - var id string - var name string - var age int - err = rows.Scan(&id, &name, &age) - if err != nil { - log.Fatalf("Error scanning data: %s", err) - } - fmt.Printf("User: ID=%s, Name=%s, Age=%d\n", id, name, age) + log.Fatal(err) } } diff --git a/sql-driver/sql_driver_test.go b/sql-driver/sql_driver_test.go deleted file mode 100644 index b3f2fe4..0000000 --- a/sql-driver/sql_driver_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package sql_driver - -import "testing" - -func TestInitTables(t *testing.T) { - InitTables() -} diff --git a/svg-builder/json/academicons.json b/sqlite-init/json/academicons.json similarity index 100% rename from svg-builder/json/academicons.json rename to sqlite-init/json/academicons.json diff --git a/svg-builder/json/akar-icons.json b/sqlite-init/json/akar-icons.json similarity index 100% rename from svg-builder/json/akar-icons.json rename to sqlite-init/json/akar-icons.json diff --git a/svg-builder/json/ant-design.json b/sqlite-init/json/ant-design.json similarity index 100% rename from svg-builder/json/ant-design.json rename to sqlite-init/json/ant-design.json diff --git a/svg-builder/json/arcticons.json b/sqlite-init/json/arcticons.json similarity index 100% rename from svg-builder/json/arcticons.json rename to sqlite-init/json/arcticons.json diff --git a/svg-builder/json/basil.json b/sqlite-init/json/basil.json similarity index 100% rename from svg-builder/json/basil.json rename to sqlite-init/json/basil.json diff --git a/svg-builder/json/bitcoin-icons.json b/sqlite-init/json/bitcoin-icons.json similarity index 100% rename from svg-builder/json/bitcoin-icons.json rename to sqlite-init/json/bitcoin-icons.json diff --git a/svg-builder/svg_builder.go b/sqlite-init/sqlite_init.go similarity index 59% rename from svg-builder/svg_builder.go rename to sqlite-init/sqlite_init.go index 480b806..f8d5435 100644 --- a/svg-builder/svg_builder.go +++ b/sqlite-init/sqlite_init.go @@ -5,10 +5,9 @@ import ( "encoding/json" "fmt" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver - "html/template" + "go-nkode/sql-driver" "io/ioutil" "log" - "os" "path/filepath" "strings" ) @@ -26,28 +25,21 @@ type Root struct { } func main() { + dbPaths := []string{"test.db", "nkode.db"} outputStr := MakeSvgFiles() - SaveToSqlite(outputStr) - + for _, path := range dbPaths { + sql_driver.MakeTables(path) + SaveToSqlite(path, outputStr) + } } -func SaveToSqlite(outputStr string) { - svgIconDbPath := "./svg-builder/svg_icon.db" - db, err := sql.Open("sqlite3", svgIconDbPath) +func SaveToSqlite(dbPath string, outputStr string) { + db, err := sql.Open("sqlite3", dbPath) if err != nil { log.Fatal(err) } defer db.Close() - createTable := ` -CREATE TABLE IF NOT EXISTS svg_icon ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - svg TEXT NOT NULL -); -` - _, err = db.Exec(createTable) - if err != nil { - log.Fatal(err) - } + lines := strings.Split(outputStr, "\n") insertSql := ` INSERT INTO svg_icon (svg) @@ -65,18 +57,8 @@ VALUES (?) } -func SaveToTextFile(outputStr string) { - outputFile := "./output.txt" - file, err := os.Create(outputFile) - file.WriteString(outputStr) - defer file.Close() - if err != nil { - log.Print("Error loading JSON file: ", err) - } -} - func MakeSvgFiles() string { - jsonFiles, err := GetAllFiles("./svg-builder/json") + jsonFiles, err := GetAllFiles("./sqlite-init/json") if err != nil { log.Fatalf("Error getting JSON files: %v", err) } @@ -120,48 +102,6 @@ func MakeSvgFiles() string { return outputStr } -func MakeSVGTemplates() { - // Step 1: Get all JSON files from the ./json folder - jsonFiles, err := GetAllFiles("./json") - if err != nil { - log.Fatalf("Error getting JSON files: %v", err) - } - - // Check if there are any JSON files found - if len(jsonFiles) == 0 { - log.Fatal("No JSON files found in ./json folder") - } - - // Step 2: Load the first JSON file into a Root struct - rootData, err := LoadJson(jsonFiles[0]) - if err != nil { - log.Fatalf("Error loading JSON file: %v", err) - } - - // Step 3: Load the HTML template from the templates directory - tmplPath := filepath.Join("templates", "icon_grid.html") - tmpl, err := template.ParseFiles(tmplPath) - if err != nil { - log.Fatalf("Error loading template file: %v", err) - } - - // Step 4: Create an output file for the rendered HTML - outputFile, err := os.Create("output.html") - if err != nil { - log.Fatalf("Error creating output file: %v", err) - } - defer outputFile.Close() - - // Step 5: Execute the template with the parsed JSON data and write to the output file - err = tmpl.Execute(outputFile, rootData) - if err != nil { - log.Fatalf("Error rendering template: %v", err) - } - - fmt.Println("HTML output has been generated successfully: output.html") - -} - func GetAllFiles(dir string) ([]string, error) { // Use ioutil.ReadDir to list all files in the directory files, err := ioutil.ReadDir(dir) diff --git a/util/util.go b/util/util.go index 94702c7..78948b1 100644 --- a/util/util.go +++ b/util/util.go @@ -60,6 +60,15 @@ func GenerateRandomUInt64() (uint64, error) { return val, nil } +func GenerateRandomInt() (int, error) { + randBytes, err := RandomBytes(8) + if err != nil { + return 0, err + } + val := int(binary.LittleEndian.Uint64(randBytes) & 0x7FFFFFFFFFFFFFFF) // Ensure it's positive + return val, nil +} + func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) { if listLen > int(1)<<32 { return nil, errors.New("list length must be less than 2^32") @@ -80,6 +89,26 @@ func GenerateRandomNonRepeatingUint64(listLen int) ([]uint64, error) { return data, nil } +func GenerateRandomNonRepeatingInt(listLen int) ([]int, error) { + if listLen > int(1)<<31 { + return nil, errors.New("list length must be less than 2^31") + } + listSet := make(hashset.Set[int]) + for { + if listSet.Size() == listLen { + break + } + randNum, err := GenerateRandomInt() + if err != nil { + return nil, err + } + listSet.Add(randNum) + } + + data := listSet.ToSlice() + return data, nil +} + func XorLists(l0 []uint64, l1 []uint64) ([]uint64, error) { if len(l0) != len(l1) { return nil, errors.New(fmt.Sprintf("list len mismatch %d, %d", len(l0), len(l1))) @@ -222,3 +251,13 @@ func Choice[T any](items []T) T { r.Seed(time.Now().UnixNano()) // Seed the random number generator return items[r.Intn(len(items))] } + +// GenerateRandomString creates a random string of a specified length. +func GenerateRandomString(length int) string { + charset := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + b := make([]rune, length) + for i := range b { + b[i] = Choice[rune](charset) + } + return string(b) +} diff --git a/util/util_test.go b/util/util_test.go index fe8ab9a..7c49f21 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -13,6 +13,14 @@ func TestGenerateRandomNonRepeatingUint64(t *testing.T) { assert.Equal(t, len(randNumbs), arrLen) } +func TestGenerateRandomNonRepeatingInt(t *testing.T) { + arrLen := 100000 + randNumbs, err := GenerateRandomNonRepeatingInt(arrLen) + assert.NoError(t, err) + + assert.Equal(t, len(randNumbs), arrLen) +} + func TestEncodeDecode(t *testing.T) { testArr := []uint64{1, 2, 3, 4, 5, 6} testEncode := EncodeBase64Str(testArr)