From c0b785ca8ddd5c2571f9f2e98ce61d0d98d87899 Mon Sep 17 00:00:00 2001 From: Donovan Date: Wed, 27 Nov 2024 09:41:31 -0600 Subject: [PATCH 1/4] more refactoring --- .env.test.example | 1 + cmd/main_test.go | 15 +++---- internal/api/db_interface.go | 20 ---------- internal/api/handler.go | 3 +- internal/api/nkode_api.go | 24 ++++++------ internal/api/nkode_api_test.go | 29 +++++++------- internal/db/customer_user_repository.go | 21 ++++++++++ internal/db/in_memory_db.go | 27 ++++++------- internal/db/sqlite_db.go | 39 ++++++++++--------- internal/db/sqlite_db_test.go | 16 ++++---- internal/email/queue.go | 20 +++++----- internal/{models => entities}/customer.go | 11 +++--- .../customer_attributes.go | 2 +- .../{models => entities}/customer_test.go | 11 +++--- .../{models => entities}/keypad_dimension.go | 2 +- internal/{models => entities}/test_helper.go | 2 +- internal/{models => entities}/user.go | 17 ++++---- .../{models => entities}/user_cipher_keys.go | 7 ++-- .../{models => entities}/user_interface.go | 9 +++-- .../user_signup_session.go | 33 ++++++++-------- internal/{models => entities}/user_test.go | 9 +++-- 21 files changed, 167 insertions(+), 151 deletions(-) create mode 100644 .env.test.example delete mode 100644 internal/api/db_interface.go create mode 100644 internal/db/customer_user_repository.go rename internal/{models => entities}/customer.go (89%) rename internal/{models => entities}/customer_attributes.go (99%) rename internal/{models => entities}/customer_test.go (86%) rename internal/{models => entities}/keypad_dimension.go (98%) rename internal/{models => entities}/test_helper.go (97%) rename internal/{models => entities}/user.go (92%) rename internal/{models => entities}/user_cipher_keys.go (97%) rename internal/{models => entities}/user_interface.go (95%) rename internal/{models => entities}/user_signup_session.go (83%) rename internal/{models => entities}/user_test.go (94%) diff --git a/.env.test.example b/.env.test.example new file mode 100644 index 0000000..c8b012b --- /dev/null +++ b/.env.test.example @@ -0,0 +1 @@ +TEST_DB=/path/to/test.db \ No newline at end of file diff --git a/cmd/main_test.go b/cmd/main_test.go index 0834b66..e20f4c3 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "go-nkode/internal/api" + "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "io" @@ -19,7 +20,7 @@ func TestApi(t *testing.T) { newCustomerBody := models.NewCustomerPost{ NKodePolicy: models.NewDefaultNKodePolicy(), } - kp := models.KeypadDimension{ + kp := entities.KeypadDimension{ AttrsPerKey: 14, NumbOfKeys: 10, } @@ -40,8 +41,8 @@ func TestApi(t *testing.T) { passcodeLen := 4 setInterface := signupInterfaceResp.UserIdxInterface userPasscode := setInterface[:passcodeLen] - kpSet := models.KeypadDimension{NumbOfKeys: kp.NumbOfKeys, AttrsPerKey: kp.NumbOfKeys} - setKeySelection, err := models.SelectKeyByAttrIdx(setInterface, userPasscode, kpSet) + kpSet := entities.KeypadDimension{NumbOfKeys: kp.NumbOfKeys, AttrsPerKey: kp.NumbOfKeys} + setKeySelection, err := entities.SelectKeyByAttrIdx(setInterface, userPasscode, kpSet) assert.NoError(t, err) setNKodeBody := models.SetNKodePost{ CustomerId: customerResp.CustomerId, @@ -51,7 +52,7 @@ func TestApi(t *testing.T) { var setNKodeResp models.SetNKodeResp testApiPost(t, base+api.SetNKode, setNKodeBody, &setNKodeResp) confirmInterface := setNKodeResp.UserInterface - confirmKeySelection, err := models.SelectKeyByAttrIdx(confirmInterface, userPasscode, kpSet) + confirmKeySelection, err := entities.SelectKeyByAttrIdx(confirmInterface, userPasscode, kpSet) assert.NoError(t, err) confirmNKodeBody := models.ConfirmNKodePost{ CustomerId: customerResp.CustomerId, @@ -69,7 +70,7 @@ func TestApi(t *testing.T) { testApiPost(t, base+api.GetLoginInterface, loginInterfaceBody, &loginInterfaceResp) assert.Equal(t, loginInterfaceResp.AttrsPerKey, kp.AttrsPerKey) assert.Equal(t, loginInterfaceResp.NumbOfKeys, kp.NumbOfKeys) - loginKeySelection, err := models.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) + loginKeySelection, err := entities.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) assert.NoError(t, err) loginBody := models.LoginPost{ CustomerId: customerResp.CustomerId, @@ -86,7 +87,7 @@ func TestApi(t *testing.T) { renewBody := models.RenewAttributesPost{CustomerId: customerResp.CustomerId} testApiPost(t, base+api.RenewAttributes, renewBody, nil) - loginKeySelection, err = models.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) + loginKeySelection, err = entities.SelectKeyByAttrIdx(loginInterfaceResp.UserIdxInterface, userPasscode, kp) assert.NoError(t, err) loginBody = models.LoginPost{ CustomerId: customerResp.CustomerId, @@ -98,7 +99,7 @@ func TestApi(t *testing.T) { var randomSvgInterfaceResp models.RandomSvgInterfaceResp testApiGet(t, base+api.RandomSvgInterface, &randomSvgInterfaceResp, "") - assert.Equal(t, models.KeypadMax.TotalAttrs(), len(randomSvgInterfaceResp.Svgs)) + assert.Equal(t, entities.KeypadMax.TotalAttrs(), len(randomSvgInterfaceResp.Svgs)) var refreshTokenResp models.RefreshTokenResp diff --git a/internal/api/db_interface.go b/internal/api/db_interface.go deleted file mode 100644 index 945b5a4..0000000 --- a/internal/api/db_interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package api - -import ( - "go-nkode/internal/models" -) - -type DbAccessor interface { - GetCustomer(models.CustomerId) (*models.Customer, error) - GetUser(models.UserEmail, models.CustomerId) (*models.User, error) - WriteNewCustomer(models.Customer) error - WriteNewUser(models.User) error - UpdateUserNKode(models.User) error - UpdateUserInterface(models.UserId, models.UserInterface) error - UpdateUserRefreshToken(models.UserId, string) error - Renew(models.CustomerId) error - RefreshUserPasscode(models.User, []int, models.CustomerAttributes) error - RandomSvgInterface(models.KeypadDimension) ([]string, error) - RandomSvgIdxInterface(models.KeypadDimension) (models.SvgIdInterface, error) - GetSvgStringInterface(models.SvgIdInterface) ([]string, error) -} diff --git a/internal/api/handler.go b/internal/api/handler.go index cbcea0d..f7dac38 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/google/uuid" "go-nkode/config" + "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "log" @@ -108,7 +109,7 @@ func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter return } - kp := models.KeypadDimension{ + kp := entities.KeypadDimension{ AttrsPerKey: signupResetPost.AttrsPerKey, NumbOfKeys: signupResetPost.NumbOfKeys, } diff --git a/internal/api/nkode_api.go b/internal/api/nkode_api.go index 69fe69a..eeaf9c5 100644 --- a/internal/api/nkode_api.go +++ b/internal/api/nkode_api.go @@ -5,7 +5,9 @@ import ( "github.com/google/uuid" "github.com/patrickmn/go-cache" "go-nkode/config" + "go-nkode/internal/db" "go-nkode/internal/email" + "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "log" @@ -19,12 +21,12 @@ const ( ) type NKodeAPI struct { - Db DbAccessor + Db db.CustomerUserRepository SignupSessionCache *cache.Cache - EmailQueue *email.EmailQueue + EmailQueue *email.Queue } -func NewNKodeAPI(db DbAccessor, queue *email.EmailQueue) NKodeAPI { +func NewNKodeAPI(db db.CustomerUserRepository, queue *email.Queue) NKodeAPI { return NKodeAPI{ Db: db, EmailQueue: queue, @@ -33,7 +35,7 @@ func NewNKodeAPI(db DbAccessor, queue *email.EmailQueue) NKodeAPI { } func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models.CustomerId) (*models.CustomerId, error) { - newCustomer, err := models.NewCustomer(nkodePolicy) + newCustomer, err := entities.NewCustomer(nkodePolicy) if id != nil { newCustomer.Id = *id } @@ -48,7 +50,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models. return &newCustomer.Id, nil } -func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, customerId models.CustomerId, kp models.KeypadDimension, reset bool) (*models.GenerateSignupResetInterfaceResp, error) { +func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, customerId models.CustomerId, kp entities.KeypadDimension, reset bool) (*models.GenerateSignupResetInterfaceResp, error) { user, err := n.Db.GetUser(userEmail, customerId) if err != nil { return nil, err @@ -61,7 +63,7 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, cust if err != nil { return nil, err } - signupSession, err := models.NewSignupResetSession(userEmail, kp, customerId, svgIdxInterface, reset) + signupSession, err := entities.NewSignupResetSession(userEmail, kp, customerId, svgIdxInterface, reset) if err != nil { return nil, err } @@ -94,7 +96,7 @@ func (n *NKodeAPI) SetNKode(customerId models.CustomerId, sessionId models.Sessi log.Printf("session id does not exist %s", sessionId) return nil, config.ErrSignupSessionDNE } - userSession, ok := session.(models.UserSignSession) + userSession, ok := session.(entities.UserSignSession) if !ok { // handle the case where the type assertion fails return nil, config.ErrSignupSessionDNE @@ -113,7 +115,7 @@ func (n *NKodeAPI) ConfirmNKode(customerId models.CustomerId, sessionId models.S log.Printf("session id does not exist %s", sessionId) return config.ErrSignupSessionDNE } - userSession, ok := session.(models.UserSignSession) + userSession, ok := session.(entities.UserSignSession) if !ok { // handle the case where the type assertion fails return config.ErrSignupSessionDNE @@ -129,7 +131,7 @@ func (n *NKodeAPI) ConfirmNKode(customerId models.CustomerId, sessionId models.S if err = customer.IsValidNKode(userSession.Kp, passcode); err != nil { return err } - user, err := models.NewUser(*customer, string(userSession.UserEmail), passcode, userSession.LoginUserInterface, userSession.Kp) + user, err := entities.NewUser(*customer, string(userSession.UserEmail), passcode, userSession.LoginUserInterface, userSession.Kp) if err != nil { return err } @@ -186,7 +188,7 @@ func (n *NKodeAPI) Login(customerId models.CustomerId, userEmail models.UserEmai log.Printf("user %s for customer %s dne", userEmail, customerId) return nil, config.ErrUserForCustomerDNE } - passcode, err := models.ValidKeyEntry(*user, *customer, keySelection) + passcode, err := entities.ValidKeyEntry(*user, *customer, keySelection) if err != nil { return nil, err } @@ -213,7 +215,7 @@ func (n *NKodeAPI) RenewAttributes(customerId models.CustomerId) error { } func (n *NKodeAPI) RandomSvgInterface() ([]string, error) { - return n.Db.RandomSvgInterface(models.KeypadMax) + return n.Db.RandomSvgInterface(entities.KeypadMax) } func (n *NKodeAPI) RefreshToken(userEmail models.UserEmail, customerId models.CustomerId, refreshToken string) (string, error) { diff --git a/internal/api/nkode_api_test.go b/internal/api/nkode_api_test.go index 80d1b1a..ebc217f 100644 --- a/internal/api/nkode_api_test.go +++ b/internal/api/nkode_api_test.go @@ -4,6 +4,7 @@ import ( "github.com/stretchr/testify/assert" "go-nkode/internal/db" "go-nkode/internal/email" + "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "os" @@ -28,7 +29,7 @@ func TestNKodeAPI(t *testing.T) { //} } -func testNKodeAPI(t *testing.T, db DbAccessor) { +func testNKodeAPI(t *testing.T, db db.CustomerUserRepository) { bufferSize := 100 emailsPerSec := 14 testClient := email.TestEmailClient{} @@ -41,7 +42,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { userEmail := models.UserEmail("test_username" + security.GenerateRandomString(12) + "@example.com") passcodeLen := 4 nkodePolicy := models.NewDefaultNKodePolicy() - keypadSize := models.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} + keypadSize := entities.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} nkodeApi := NewNKodeAPI(db, queue) customerId, err := nkodeApi.CreateNewCustomer(nkodePolicy, nil) assert.NoError(t, err) @@ -51,20 +52,20 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { sessionIdStr := signupResponse.SessionId sessionId, err := models.SessionIdFromString(sessionIdStr) assert.NoError(t, err) - keypadSize = models.KeypadDimension{AttrsPerKey: numbOfKeys, NumbOfKeys: numbOfKeys} + keypadSize = entities.KeypadDimension{AttrsPerKey: numbOfKeys, NumbOfKeys: numbOfKeys} userPasscode := setInterface[:passcodeLen] - setKeySelect, err := models.SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) + setKeySelect, err := entities.SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) assert.NoError(t, err) confirmInterface, err := nkodeApi.SetNKode(*customerId, sessionId, setKeySelect) assert.NoError(t, err) - confirmKeySelect, err := models.SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) + confirmKeySelect, err := entities.SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) err = nkodeApi.ConfirmNKode(*customerId, sessionId, confirmKeySelect) assert.NoError(t, err) - keypadSize = models.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} + keypadSize = entities.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} loginInterface, err := nkodeApi.GetLoginInterface(userEmail, *customerId) assert.NoError(t, err) - loginKeySelection, err := models.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) + loginKeySelection, err := entities.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) assert.NoError(t, err) _, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection) assert.NoError(t, err) @@ -74,34 +75,34 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { loginInterface, err = nkodeApi.GetLoginInterface(userEmail, *customerId) assert.NoError(t, err) - loginKeySelection, err = models.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) + loginKeySelection, err = entities.SelectKeyByAttrIdx(loginInterface.UserIdxInterface, userPasscode, keypadSize) assert.NoError(t, err) _, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection) assert.NoError(t, err) /// Reset nKode attrsPerKey = 6 - keypadSize = models.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} + keypadSize = entities.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} resetResponse, err := nkodeApi.GenerateSignupResetInterface(userEmail, *customerId, keypadSize, true) assert.NoError(t, err) setInterface = resetResponse.UserIdxInterface sessionIdStr = resetResponse.SessionId sessionId, err = models.SessionIdFromString(sessionIdStr) assert.NoError(t, err) - keypadSize = models.KeypadDimension{AttrsPerKey: numbOfKeys, NumbOfKeys: numbOfKeys} + keypadSize = entities.KeypadDimension{AttrsPerKey: numbOfKeys, NumbOfKeys: numbOfKeys} userPasscode = setInterface[:passcodeLen] - setKeySelect, err = models.SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) + setKeySelect, err = entities.SelectKeyByAttrIdx(setInterface, userPasscode, keypadSize) assert.NoError(t, err) confirmInterface, err = nkodeApi.SetNKode(*customerId, sessionId, setKeySelect) assert.NoError(t, err) - confirmKeySelect, err = models.SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) + confirmKeySelect, err = entities.SelectKeyByAttrIdx(confirmInterface, userPasscode, keypadSize) err = nkodeApi.ConfirmNKode(*customerId, sessionId, confirmKeySelect) assert.NoError(t, err) - keypadSize = models.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} + keypadSize = entities.KeypadDimension{AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys} loginInterface2, err := nkodeApi.GetLoginInterface(userEmail, *customerId) assert.NoError(t, err) - loginKeySelection, err = models.SelectKeyByAttrIdx(loginInterface2.UserIdxInterface, userPasscode, keypadSize) + loginKeySelection, err = entities.SelectKeyByAttrIdx(loginInterface2.UserIdxInterface, userPasscode, keypadSize) assert.NoError(t, err) _, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection) assert.NoError(t, err) diff --git a/internal/db/customer_user_repository.go b/internal/db/customer_user_repository.go new file mode 100644 index 0000000..4cccc0b --- /dev/null +++ b/internal/db/customer_user_repository.go @@ -0,0 +1,21 @@ +package db + +import ( + "go-nkode/internal/entities" + "go-nkode/internal/models" +) + +type CustomerUserRepository interface { + GetCustomer(models.CustomerId) (*entities.Customer, error) + GetUser(models.UserEmail, models.CustomerId) (*entities.User, error) + WriteNewCustomer(entities.Customer) error + WriteNewUser(entities.User) error + UpdateUserNKode(entities.User) error + UpdateUserInterface(models.UserId, entities.UserInterface) error + UpdateUserRefreshToken(models.UserId, string) error + Renew(models.CustomerId) error + RefreshUserPasscode(entities.User, []int, entities.CustomerAttributes) error + RandomSvgInterface(entities.KeypadDimension) ([]string, error) + RandomSvgIdxInterface(entities.KeypadDimension) (models.SvgIdInterface, error) + GetSvgStringInterface(models.SvgIdInterface) ([]string, error) +} diff --git a/internal/db/in_memory_db.go b/internal/db/in_memory_db.go index 2e3197e..05405fb 100644 --- a/internal/db/in_memory_db.go +++ b/internal/db/in_memory_db.go @@ -3,24 +3,25 @@ package db import ( "errors" "fmt" + "go-nkode/internal/entities" "go-nkode/internal/models" ) type InMemoryDb struct { - Customers map[models.CustomerId]models.Customer - Users map[models.UserId]models.User + Customers map[models.CustomerId]entities.Customer + Users map[models.UserId]entities.User userIdMap map[string]models.UserId } func NewInMemoryDb() InMemoryDb { return InMemoryDb{ - Customers: make(map[models.CustomerId]models.Customer), - Users: make(map[models.UserId]models.User), + Customers: make(map[models.CustomerId]entities.Customer), + Users: make(map[models.UserId]entities.User), userIdMap: make(map[string]models.UserId), } } -func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*models.Customer, error) { +func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*entities.Customer, error) { customer, exists := db.Customers[id] if !exists { return nil, errors.New(fmt.Sprintf("customer %s dne", customer.Id)) @@ -28,7 +29,7 @@ func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*models.Customer, error return &customer, nil } -func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerId) (*models.User, error) { +func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerId) (*entities.User, error) { key := userIdKey(customerId, username) userId, exists := db.userIdMap[key] if !exists { @@ -41,7 +42,7 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo return &user, nil } -func (db *InMemoryDb) WriteNewCustomer(customer models.Customer) error { +func (db *InMemoryDb) WriteNewCustomer(customer entities.Customer) error { _, exists := db.Customers[customer.Id] if exists { @@ -51,7 +52,7 @@ func (db *InMemoryDb) WriteNewCustomer(customer models.Customer) error { return nil } -func (db *InMemoryDb) WriteNewUser(user models.User) error { +func (db *InMemoryDb) WriteNewUser(user entities.User) error { _, exists := db.Customers[user.CustomerId] if !exists { return errors.New(fmt.Sprintf("can't add user %s to customer %s: customer dne", user.Email, user.CustomerId)) @@ -67,11 +68,11 @@ func (db *InMemoryDb) WriteNewUser(user models.User) error { return nil } -func (db *InMemoryDb) UpdateUserNKode(user models.User) error { +func (db *InMemoryDb) UpdateUserNKode(user entities.User) error { return errors.ErrUnsupported } -func (db *InMemoryDb) UpdateUserInterface(userId models.UserId, ui models.UserInterface) error { +func (db *InMemoryDb) UpdateUserInterface(userId models.UserId, ui entities.UserInterface) error { user, exists := db.Users[userId] if !exists { return errors.New(fmt.Sprintf("can't update user %s, dne", user.Id)) @@ -107,7 +108,7 @@ func (db *InMemoryDb) Renew(id models.CustomerId) error { return nil } -func (db *InMemoryDb) RefreshUserPasscode(user models.User, passocode []int, customerAttr models.CustomerAttributes) error { +func (db *InMemoryDb) RefreshUserPasscode(user entities.User, passocode []int, customerAttr entities.CustomerAttributes) error { err := user.RefreshPasscode(passocode, customerAttr) if err != nil { return err @@ -116,11 +117,11 @@ func (db *InMemoryDb) RefreshUserPasscode(user models.User, passocode []int, cus return nil } -func (db *InMemoryDb) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) { +func (db *InMemoryDb) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { return make([]string, kp.TotalAttrs()), nil } -func (db *InMemoryDb) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) { +func (db *InMemoryDb) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { svgs := make(models.SvgIdInterface, kp.TotalAttrs()) for idx := range svgs { svgs[idx] = idx diff --git a/internal/db/sqlite_db.go b/internal/db/sqlite_db.go index a236780..827995d 100644 --- a/internal/db/sqlite_db.go +++ b/internal/db/sqlite_db.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "go-nkode/config" + "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "log" @@ -60,7 +61,7 @@ func (d *SqliteDB) CloseDb() { } } -func (d *SqliteDB) WriteNewCustomer(c models.Customer) error { +func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error { query := ` INSERT INTO customer ( id @@ -85,7 +86,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?) return d.addWriteTx(query, args) } -func (d *SqliteDB) WriteNewUser(u models.User) error { +func (d *SqliteDB) WriteNewUser(u entities.User) error { query := ` INSERT INTO user ( id @@ -128,7 +129,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) return d.addWriteTx(query, args) } -func (d *SqliteDB) UpdateUserNKode(u models.User) error { +func (d *SqliteDB) UpdateUserNKode(u entities.User) error { query := ` UPDATE user SET renew = ? @@ -158,7 +159,7 @@ WHERE email = ? AND customer_id = ? return d.addWriteTx(query, args) } -func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui models.UserInterface) error { +func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { query := ` UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ? ` @@ -219,20 +220,20 @@ WHERE customer_id = ? if err != nil { return err } - user := models.User{ + user := entities.User{ Id: models.UserId{}, CustomerId: models.CustomerId{}, Email: "", EncipheredPasscode: models.EncipheredNKode{}, - Kp: models.KeypadDimension{ + Kp: entities.KeypadDimension{ AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, - CipherKeys: models.UserCipherKeys{ + CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(alphaBytes), SetKey: security.ByteArrToUint64Arr(setBytes), }, - Interface: models.UserInterface{}, + Interface: entities.UserInterface{}, Renew: false, } err = user.RenewKeys(setXor, attrXor) @@ -255,7 +256,7 @@ WHERE id = ?; return d.addWriteTx(renewQuery, renewArgs) } -func (d *SqliteDB) RefreshUserPasscode(user models.User, passcodeIdx []int, customerAttr models.CustomerAttributes) error { +func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { err := user.RefreshPasscode(passcodeIdx, customerAttr) if err != nil { return err @@ -276,7 +277,7 @@ WHERE id = ?; args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()} return d.addWriteTx(query, args) } -func (d *SqliteDB) GetCustomer(id models.CustomerId) (*models.Customer, error) { +func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { tx, err := d.db.Begin() if err != nil { return nil, err @@ -324,7 +325,7 @@ WHERE id = ? if err != nil { return nil, err } - customer := models.Customer{ + customer := entities.Customer{ Id: id, NKodePolicy: models.NKodePolicy{ MaxNkodeLen: maxNKodeLen, @@ -334,7 +335,7 @@ WHERE id = ? LockOut: lockOut, Expiration: expiration, }, - Attributes: models.NewCustomerAttributesFromBytes(attributeValues, setValues), + Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues), } if err = tx.Commit(); err != nil { return nil, err @@ -342,7 +343,7 @@ WHERE id = ? return &customer, nil } -func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*models.User, error) { +func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { tx, err := d.db.Begin() if err != nil { return nil, err @@ -401,7 +402,7 @@ WHERE user.email = ? AND user.customer_id = ? renew = true } - user := models.User{ + user := entities.User{ Id: models.UserId(userId), CustomerId: customerId, Email: email, @@ -409,11 +410,11 @@ WHERE user.email = ? AND user.customer_id = ? Code: code, Mask: mask, }, - Kp: models.KeypadDimension{ + Kp: entities.KeypadDimension{ AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, - CipherKeys: models.UserCipherKeys{ + CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(alphaKey), SetKey: security.ByteArrToUint64Arr(setKey), PassKey: security.ByteArrToUint64Arr(passKey), @@ -422,7 +423,7 @@ WHERE user.email = ? AND user.customer_id = ? MaxNKodeLen: maxNKodeLen, Kp: nil, }, - Interface: models.UserInterface{ + Interface: entities.UserInterface{ IdxInterface: security.ByteArrToIntArr(idxInterface), SvgId: security.ByteArrToIntArr(svgIdInterface), Kp: nil, @@ -438,7 +439,7 @@ WHERE user.email = ? AND user.customer_id = ? return &user, nil } -func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) { +func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { ids, err := d.getRandomIds(kp.TotalAttrs()) if err != nil { return nil, err @@ -446,7 +447,7 @@ func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, erro return d.getSvgsById(ids) } -func (d *SqliteDB) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) { +func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { return d.getRandomIds(kp.TotalAttrs()) } diff --git a/internal/db/sqlite_db_test.go b/internal/db/sqlite_db_test.go index 82a5cac..9f53ef1 100644 --- a/internal/db/sqlite_db_test.go +++ b/internal/db/sqlite_db_test.go @@ -2,7 +2,7 @@ package db import ( "github.com/stretchr/testify/assert" - "go-nkode/internal/api" + "go-nkode/internal/entities" "go-nkode/internal/models" "os" "testing" @@ -24,9 +24,9 @@ func TestNewSqliteDB(t *testing.T) { // } } -func testSignupLoginRenew(t *testing.T, db api.DbAccessor) { +func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { nkodePolicy := models.NewDefaultNKodePolicy() - customerOrig, err := models.NewCustomer(nkodePolicy) + customerOrig, err := entities.NewCustomer(nkodePolicy) assert.NoError(t, err) err = db.WriteNewCustomer(*customerOrig) assert.NoError(t, err) @@ -34,12 +34,12 @@ func testSignupLoginRenew(t *testing.T, db api.DbAccessor) { assert.NoError(t, err) assert.Equal(t, customerOrig, customer) username := "test_user@example.com" - kp := models.KeypadDefault + kp := entities.KeypadDefault passcodeIdx := []int{0, 1, 2, 3} mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) - ui, err := models.NewUserInterface(&kp, mockSvgInterface) + ui, err := entities.NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) - userOrig, err := models.NewUser(*customer, username, passcodeIdx, *ui, kp) + userOrig, err := entities.NewUser(*customer, username, passcodeIdx, *ui, kp) assert.NoError(t, err) err = db.WriteNewUser(*userOrig) assert.NoError(t, err) @@ -52,8 +52,8 @@ func testSignupLoginRenew(t *testing.T, db api.DbAccessor) { } -func testSqliteDBRandomSvgInterface(t *testing.T, db api.DbAccessor) { - kp := models.KeypadMax +func testSqliteDBRandomSvgInterface(t *testing.T, db CustomerUserRepository) { + kp := entities.KeypadMax svgs, err := db.RandomSvgInterface(kp) assert.NoError(t, err) assert.Len(t, svgs, kp.TotalAttrs()) diff --git a/internal/email/queue.go b/internal/email/queue.go index 3480157..597694b 100644 --- a/internal/email/queue.go +++ b/internal/email/queue.go @@ -14,7 +14,7 @@ import ( "time" ) -type EmailClient interface { +type Client interface { SendEmail(Email) error } @@ -103,22 +103,22 @@ func (s *SESClient) SendEmail(email Email) error { return nil } -// EmailQueue represents the email queue with rate limiting -type EmailQueue struct { +// Queue represents the email queue with rate limiting +type Queue struct { stop bool emailQueue chan Email // Email queue rateLimit <-chan time.Time // Rate limiter - client EmailClient // SES client to send emails + client Client // SES client to send emails wg sync.WaitGroup // To wait for all emails to be processed FailedSendCount int } // NewEmailQueue creates a new rate-limited email queue -func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *EmailQueue { +func NewEmailQueue(bufferSize int, emailsPerSecond int, client Client) *Queue { // Create a ticker that ticks every second to limit the rate of sending emails rateLimit := time.Tick(time.Second / time.Duration(emailsPerSecond)) - return &EmailQueue{ + return &Queue{ stop: false, emailQueue: make(chan Email, bufferSize), rateLimit: rateLimit, @@ -128,7 +128,7 @@ func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *Ema } // AddEmail queues a new email to be sent -func (q *EmailQueue) AddEmail(email Email) { +func (q *Queue) AddEmail(email Email) { if q.stop { log.Printf("email %s with subject %s not add. Stopping queue", email.Recipient, email.Subject) return @@ -138,7 +138,7 @@ func (q *EmailQueue) AddEmail(email Email) { } // Start begins processing the email queue with rate limiting -func (q *EmailQueue) Start() { +func (q *Queue) Start() { q.stop = false // Worker goroutine that processes emails from the queue go func() { @@ -151,7 +151,7 @@ func (q *EmailQueue) Start() { } // sendEmail sends an email using the SES client -func (q *EmailQueue) sendEmail(email Email) { +func (q *Queue) sendEmail(email Email) { if err := q.client.SendEmail(email); err != nil { q.FailedSendCount += 1 log.Printf("Failed to send email to %s: %v\n", email.Recipient, err) @@ -159,7 +159,7 @@ func (q *EmailQueue) sendEmail(email Email) { } // Stop stops the queue after all emails have been processed -func (q *EmailQueue) Stop() { +func (q *Queue) Stop() { q.stop = true // Wait for all emails to be processed q.wg.Wait() diff --git a/internal/models/customer.go b/internal/entities/customer.go similarity index 89% rename from internal/models/customer.go rename to internal/entities/customer.go index 3183368..a111cc4 100644 --- a/internal/models/customer.go +++ b/internal/entities/customer.go @@ -1,25 +1,26 @@ -package models +package entities import ( "github.com/google/uuid" "go-nkode/config" + "go-nkode/internal/models" "go-nkode/internal/security" "go-nkode/internal/utils" ) type Customer struct { - Id CustomerId - NKodePolicy NKodePolicy + Id models.CustomerId + NKodePolicy models.NKodePolicy Attributes CustomerAttributes } -func NewCustomer(nkodePolicy NKodePolicy) (*Customer, error) { +func NewCustomer(nkodePolicy models.NKodePolicy) (*Customer, error) { customerAttrs, err := NewCustomerAttributes() if err != nil { return nil, err } customer := Customer{ - Id: CustomerId(uuid.New()), + Id: models.CustomerId(uuid.New()), NKodePolicy: nkodePolicy, Attributes: *customerAttrs, } diff --git a/internal/models/customer_attributes.go b/internal/entities/customer_attributes.go similarity index 99% rename from internal/models/customer_attributes.go rename to internal/entities/customer_attributes.go index 64827d9..4d3eb5d 100644 --- a/internal/models/customer_attributes.go +++ b/internal/entities/customer_attributes.go @@ -1,4 +1,4 @@ -package models +package entities import ( "go-nkode/internal/security" diff --git a/internal/models/customer_test.go b/internal/entities/customer_test.go similarity index 86% rename from internal/models/customer_test.go rename to internal/entities/customer_test.go index 1b15fdb..0e98034 100644 --- a/internal/models/customer_test.go +++ b/internal/entities/customer_test.go @@ -1,7 +1,8 @@ -package models +package entities import ( "github.com/stretchr/testify/assert" + "go-nkode/internal/models" "testing" ) @@ -18,10 +19,10 @@ func testNewCustomerAttributes(t *testing.T) { func testCustomerValidKeyEntry(t *testing.T) { kp := KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 9} - nkodePolicy := NewDefaultNKodePolicy() + nkodePolicy := models.NewDefaultNKodePolicy() customer, err := NewCustomer(nkodePolicy) assert.NoError(t, err) - mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) userEmail := "testing@example.com" @@ -42,10 +43,10 @@ func testCustomerValidKeyEntry(t *testing.T) { func testCustomerIsValidNKode(t *testing.T) { kp := KeypadDimension{AttrsPerKey: 10, NumbOfKeys: 7} - nkodePolicy := NewDefaultNKodePolicy() + nkodePolicy := models.NewDefaultNKodePolicy() customer, err := NewCustomer(nkodePolicy) assert.NoError(t, err) - mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) userEmail := "testing123@example.com" diff --git a/internal/models/keypad_dimension.go b/internal/entities/keypad_dimension.go similarity index 98% rename from internal/models/keypad_dimension.go rename to internal/entities/keypad_dimension.go index 75ccfed..3ffd8e4 100644 --- a/internal/models/keypad_dimension.go +++ b/internal/entities/keypad_dimension.go @@ -1,4 +1,4 @@ -package models +package entities import ( "go-nkode/config" diff --git a/internal/models/test_helper.go b/internal/entities/test_helper.go similarity index 97% rename from internal/models/test_helper.go rename to internal/entities/test_helper.go index 33cea57..2874516 100644 --- a/internal/models/test_helper.go +++ b/internal/entities/test_helper.go @@ -1,4 +1,4 @@ -package models +package entities import ( "errors" diff --git a/internal/models/user.go b/internal/entities/user.go similarity index 92% rename from internal/models/user.go rename to internal/entities/user.go index 86847e3..ec89efe 100644 --- a/internal/models/user.go +++ b/internal/entities/user.go @@ -1,17 +1,18 @@ -package models +package entities import ( "github.com/google/uuid" "go-nkode/config" + "go-nkode/internal/models" "go-nkode/internal/security" "log" ) type User struct { - Id UserId - CustomerId CustomerId - Email UserEmail - EncipheredPasscode EncipheredNKode + Id models.UserId + CustomerId models.CustomerId + Email models.UserEmail + EncipheredPasscode models.EncipheredNKode Kp KeypadDimension CipherKeys UserCipherKeys Interface UserInterface @@ -116,7 +117,7 @@ func ValidKeyEntry(user User, customer Customer, selectedKeys []int) ([]int, err } func NewUser(customer Customer, userEmail string, passcodeIdx []int, ui UserInterface, kp KeypadDimension) (*User, error) { - _, err := ParseEmail(userEmail) + _, err := models.ParseEmail(userEmail) if err != nil { return nil, err } @@ -133,8 +134,8 @@ func NewUser(customer Customer, userEmail string, passcodeIdx []int, ui UserInte return nil, err } newUser := User{ - Id: UserId(uuid.New()), - Email: UserEmail(userEmail), + Id: models.UserId(uuid.New()), + Email: models.UserEmail(userEmail), EncipheredPasscode: *encipheredNKode, CipherKeys: *newKeys, Interface: ui, diff --git a/internal/models/user_cipher_keys.go b/internal/entities/user_cipher_keys.go similarity index 97% rename from internal/models/user_cipher_keys.go rename to internal/entities/user_cipher_keys.go index 7f187a9..8168c61 100644 --- a/internal/models/user_cipher_keys.go +++ b/internal/entities/user_cipher_keys.go @@ -1,9 +1,10 @@ -package models +package entities import ( "crypto/sha256" "errors" "go-nkode/config" + "go-nkode/internal/models" "go-nkode/internal/security" "golang.org/x/crypto/bcrypt" ) @@ -166,7 +167,7 @@ func (u *UserCipherKeys) DecipherMask(mask string, setVals []uint64, passcodeLen return passcodeSet, nil } -func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs CustomerAttributes) (*EncipheredNKode, error) { +func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs CustomerAttributes) (*models.EncipheredNKode, error) { attrVals, err := customerAttrs.AttrValsForKp(*u.Kp) code, err := u.EncipherSaltHashCode(passcodeAttrIdx, attrVals) if err != nil { @@ -185,7 +186,7 @@ func (u *UserCipherKeys) EncipherNKode(passcodeAttrIdx []int, customerAttrs Cust if err != nil { return nil, err } - encipheredCode := EncipheredNKode{ + encipheredCode := models.EncipheredNKode{ Code: code, Mask: mask, } diff --git a/internal/models/user_interface.go b/internal/entities/user_interface.go similarity index 95% rename from internal/models/user_interface.go rename to internal/entities/user_interface.go index d5ba0a3..0a691e8 100644 --- a/internal/models/user_interface.go +++ b/internal/entities/user_interface.go @@ -1,19 +1,20 @@ -package models +package entities import ( "go-nkode/config" + "go-nkode/internal/models" "go-nkode/internal/security" "go-nkode/internal/utils" "log" ) type UserInterface struct { - IdxInterface IdxInterface - SvgId SvgIdInterface + IdxInterface models.IdxInterface + SvgId models.SvgIdInterface Kp *KeypadDimension } -func NewUserInterface(kp *KeypadDimension, svgId SvgIdInterface) (*UserInterface, error) { +func NewUserInterface(kp *KeypadDimension, svgId models.SvgIdInterface) (*UserInterface, error) { idxInterface := security.IdentityArray(kp.TotalAttrs()) userInterface := UserInterface{ IdxInterface: idxInterface, diff --git a/internal/models/user_signup_session.go b/internal/entities/user_signup_session.go similarity index 83% rename from internal/models/user_signup_session.go rename to internal/entities/user_signup_session.go index 26692f8..de46146 100644 --- a/internal/models/user_signup_session.go +++ b/internal/entities/user_signup_session.go @@ -1,8 +1,9 @@ -package models +package entities import ( "github.com/google/uuid" "go-nkode/config" + "go-nkode/internal/models" "go-nkode/internal/security" py "go-nkode/internal/utils" "log" @@ -10,20 +11,20 @@ import ( ) type UserSignSession struct { - Id SessionId - CustomerId CustomerId + Id models.SessionId + CustomerId models.CustomerId LoginUserInterface UserInterface Kp KeypadDimension - SetIdxInterface IdxInterface - ConfirmIdxInterface IdxInterface - SetKeySelection KeySelection - UserEmail UserEmail + SetIdxInterface models.IdxInterface + ConfirmIdxInterface models.IdxInterface + SetKeySelection models.KeySelection + UserEmail models.UserEmail Reset bool Expire int - Colors []RGBColor + Colors []models.RGBColor } -func NewSignupResetSession(userEmail UserEmail, kp KeypadDimension, customerId CustomerId, svgInterface SvgIdInterface, reset bool) (*UserSignSession, error) { +func NewSignupResetSession(userEmail models.UserEmail, kp KeypadDimension, customerId models.CustomerId, svgInterface models.SvgIdInterface, reset bool) (*UserSignSession, error) { loginInterface, err := NewUserInterface(&kp, svgInterface) if err != nil { return nil, err @@ -33,7 +34,7 @@ func NewSignupResetSession(userEmail UserEmail, kp KeypadDimension, customerId C return nil, err } session := UserSignSession{ - Id: SessionId(uuid.New()), + Id: models.SessionId(uuid.New()), CustomerId: customerId, LoginUserInterface: *loginInterface, SetIdxInterface: signupInterface.IdxInterface, @@ -48,7 +49,7 @@ func NewSignupResetSession(userEmail UserEmail, kp KeypadDimension, customerId C return &session, nil } -func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, error) { +func (s *UserSignSession) DeducePasscode(confirmKeyEntry models.KeySelection) ([]int, error) { validEntry := py.All[int](confirmKeyEntry, func(i int) bool { return 0 <= i && i < s.Kp.NumbOfKeys }) @@ -109,7 +110,7 @@ func (s *UserSignSession) DeducePasscode(confirmKeyEntry KeySelection) ([]int, e return passcode, nil } -func (s *UserSignSession) SetUserNKode(keySelection KeySelection) (IdxInterface, error) { +func (s *UserSignSession) SetUserNKode(keySelection models.KeySelection) (models.IdxInterface, error) { validKeySelection := py.All[int](keySelection, func(i int) bool { return 0 <= i && i < s.Kp.NumbOfKeys }) @@ -129,7 +130,7 @@ func (s *UserSignSession) SetUserNKode(keySelection KeySelection) (IdxInterface, return s.ConfirmIdxInterface, nil } -func (s *UserSignSession) getSelectedKeyVals(keySelections KeySelection, userInterface []int) ([][]int, error) { +func (s *UserSignSession) getSelectedKeyVals(keySelections models.KeySelection, userInterface []int) ([][]int, error) { signupKp := s.SignupKeypad() keypadInterface, err := security.ListToMatrix(userInterface, signupKp.AttrsPerKey) if err != nil { @@ -143,7 +144,7 @@ func (s *UserSignSession) getSelectedKeyVals(keySelections KeySelection, userInt return keyVals, nil } -func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, []RGBColor, error) { +func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*UserInterface, []models.RGBColor, error) { // This method randomly drops sets from the base user interface so it is a square and dispersable matrix if kp.IsDispersable() { return nil, nil, config.ErrKeypadIsNotDispersible @@ -170,11 +171,11 @@ func signupInterface(baseUserInterface UserInterface, kp KeypadDimension) (*User setIdxs = setIdxs[:kp.NumbOfKeys] sort.Ints(setIdxs) selectedSets := make([][]int, kp.NumbOfKeys) - selectedColors := make([]RGBColor, kp.NumbOfKeys) + selectedColors := make([]models.RGBColor, kp.NumbOfKeys) for idx, setIdx := range setIdxs { selectedSets[idx] = attrSetView[setIdx] - selectedColors[idx] = SetColors[setIdx] + selectedColors[idx] = models.SetColors[setIdx] } // convert set view back into key view selectedSets, err = security.MatrixTranspose(selectedSets) diff --git a/internal/models/user_test.go b/internal/entities/user_test.go similarity index 94% rename from internal/models/user_test.go rename to internal/entities/user_test.go index 3202a0a..9d93c5c 100644 --- a/internal/models/user_test.go +++ b/internal/entities/user_test.go @@ -1,7 +1,8 @@ -package models +package entities import ( "github.com/stretchr/testify/assert" + "go-nkode/internal/models" py "go-nkode/internal/utils" "testing" ) @@ -64,7 +65,7 @@ func TestUserInterface_RandomShuffle(t *testing.T) { AttrsPerKey: 10, NumbOfKeys: 8, } - mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) userInterfaceCopy := make([]int, len(userInterface.IdxInterface)) @@ -87,7 +88,7 @@ func TestUserInterface_DisperseInterface(t *testing.T) { for idx := 0; idx < 10000; idx++ { kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} - mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) preDispersion, err := userInterface.AttributeAdjacencyGraph() @@ -106,7 +107,7 @@ func TestUserInterface_DisperseInterface(t *testing.T) { func TestUserInterface_PartialInterfaceShuffle(t *testing.T) { kp := KeypadDimension{AttrsPerKey: 7, NumbOfKeys: 10} - mockSvgInterface := make(SvgIdInterface, kp.TotalAttrs()) + mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs()) userInterface, err := NewUserInterface(&kp, mockSvgInterface) assert.NoError(t, err) preShuffle := userInterface.IdxInterface From 69ec9bd08cd613ede603ada72f8a404ddf456f75 Mon Sep 17 00:00:00 2001 From: Donovan Date: Tue, 3 Dec 2024 16:22:03 -0600 Subject: [PATCH 2/4] sqlc generate --- internal/sqlc/db.go | 31 +++ internal/sqlc/models.go | 50 ++++ internal/sqlc/query.sql.go | 478 +++++++++++++++++++++++++++++++++++++ sqlc.yaml | 9 + sqlite/query.sql | 136 +++++++++++ sqlite/schema.sql | 57 +++++ 6 files changed, 761 insertions(+) create mode 100644 internal/sqlc/db.go create mode 100644 internal/sqlc/models.go create mode 100644 internal/sqlc/query.sql.go create mode 100644 sqlc.yaml create mode 100644 sqlite/query.sql create mode 100644 sqlite/schema.sql diff --git a/internal/sqlc/db.go b/internal/sqlc/db.go new file mode 100644 index 0000000..2248616 --- /dev/null +++ b/internal/sqlc/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package sqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go new file mode 100644 index 0000000..d8f68cf --- /dev/null +++ b/internal/sqlc/models.go @@ -0,0 +1,50 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package sqlc + +import ( + "database/sql" +) + +type Customer struct { + ID string + MaxNkodeLen int64 + MinNkodeLen int64 + DistinctSets int64 + DistinctAttributes int64 + LockOut int64 + Expiration int64 + AttributeValues []byte + SetValues []byte + LastRenew string + CreatedAt string +} + +type SvgIcon struct { + ID int64 + Svg string +} + +type User struct { + ID string + Email string + Renew int64 + RefreshToken sql.NullString + CustomerID string + Code string + Mask string + AttributesPerKey int64 + NumberOfKeys int64 + AlphaKey []byte + SetKey []byte + PassKey []byte + MaskKey []byte + Salt []byte + MaxNkodeLen int64 + IdxInterface []byte + SvgIDInterface []byte + LastLogin interface{} + CreatedAt sql.NullString +} diff --git a/internal/sqlc/query.sql.go b/internal/sqlc/query.sql.go new file mode 100644 index 0000000..c11071c --- /dev/null +++ b/internal/sqlc/query.sql.go @@ -0,0 +1,478 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: query.sql + +package sqlc + +import ( + "context" + "database/sql" +) + +const createCustomer = `-- name: CreateCustomer :exec +INSERT INTO customer ( + id + ,max_nkode_len + ,min_nkode_len + ,distinct_sets + ,distinct_attributes + ,lock_out + ,expiration + ,attribute_values + ,set_values + ,last_renew + ,created_at +) +VALUES (?,?,?,?,?,?,?,?,?,?,?) +` + +type CreateCustomerParams struct { + ID string + MaxNkodeLen int64 + MinNkodeLen int64 + DistinctSets int64 + DistinctAttributes int64 + LockOut int64 + Expiration int64 + AttributeValues []byte + SetValues []byte + LastRenew string + CreatedAt string +} + +func (q *Queries) CreateCustomer(ctx context.Context, arg CreateCustomerParams) error { + _, err := q.db.ExecContext(ctx, createCustomer, + arg.ID, + arg.MaxNkodeLen, + arg.MinNkodeLen, + arg.DistinctSets, + arg.DistinctAttributes, + arg.LockOut, + arg.Expiration, + arg.AttributeValues, + arg.SetValues, + arg.LastRenew, + arg.CreatedAt, + ) + return err +} + +const createUser = `-- name: CreateUser :exec +INSERT INTO user ( + id + ,email + ,renew + ,refresh_token + ,customer_id + ,code + ,mask + ,attributes_per_key + ,number_of_keys + ,alpha_key + ,set_key + ,pass_key + ,mask_key + ,salt + ,max_nkode_len + ,idx_interface + ,svg_id_interface + ,created_at +) +VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) +` + +type CreateUserParams struct { + ID string + Email string + Renew int64 + RefreshToken sql.NullString + CustomerID string + Code string + Mask string + AttributesPerKey int64 + NumberOfKeys int64 + AlphaKey []byte + SetKey []byte + PassKey []byte + MaskKey []byte + Salt []byte + MaxNkodeLen int64 + IdxInterface []byte + SvgIDInterface []byte + CreatedAt sql.NullString +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) error { + _, err := q.db.ExecContext(ctx, createUser, + arg.ID, + arg.Email, + arg.Renew, + arg.RefreshToken, + arg.CustomerID, + arg.Code, + arg.Mask, + arg.AttributesPerKey, + arg.NumberOfKeys, + arg.AlphaKey, + arg.SetKey, + arg.PassKey, + arg.MaskKey, + arg.Salt, + arg.MaxNkodeLen, + arg.IdxInterface, + arg.SvgIDInterface, + arg.CreatedAt, + ) + return err +} + +const getCustomer = `-- name: GetCustomer :one +SELECT + max_nkode_len + ,min_nkode_len + ,distinct_sets + ,distinct_attributes + ,lock_out + ,expiration + ,attribute_values + ,set_values +FROM customer +WHERE id = ? +` + +type GetCustomerRow struct { + MaxNkodeLen int64 + MinNkodeLen int64 + DistinctSets int64 + DistinctAttributes int64 + LockOut int64 + Expiration int64 + AttributeValues []byte + SetValues []byte +} + +func (q *Queries) GetCustomer(ctx context.Context, id string) (GetCustomerRow, error) { + row := q.db.QueryRowContext(ctx, getCustomer, id) + var i GetCustomerRow + err := row.Scan( + &i.MaxNkodeLen, + &i.MinNkodeLen, + &i.DistinctSets, + &i.DistinctAttributes, + &i.LockOut, + &i.Expiration, + &i.AttributeValues, + &i.SetValues, + ) + return i, err +} + +const getSvgCount = `-- name: GetSvgCount :one +SELECT COUNT(*) as count FROM svg_icon +` + +func (q *Queries) GetSvgCount(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, getSvgCount) + var count int64 + err := row.Scan(&count) + return count, err +} + +const getSvgId = `-- name: GetSvgId :one +SELECT svg +FROM svg_icon +WHERE id = ? +` + +func (q *Queries) GetSvgId(ctx context.Context, id int64) (string, error) { + row := q.db.QueryRowContext(ctx, getSvgId, id) + var svg string + err := row.Scan(&svg) + return svg, err +} + +const getUser = `-- name: GetUser :one +SELECT + id + ,renew + ,refresh_token + ,code + ,mask + ,attributes_per_key + ,number_of_keys + ,alpha_key + ,set_key + ,pass_key + ,mask_key + ,salt + ,max_nkode_len + ,idx_interface + ,svg_id_interface +FROM user +WHERE user.email = ? AND user.customer_id = ? +` + +type GetUserParams struct { + Email string + CustomerID string +} + +type GetUserRow struct { + ID string + Renew int64 + RefreshToken sql.NullString + Code string + Mask string + AttributesPerKey int64 + NumberOfKeys int64 + AlphaKey []byte + SetKey []byte + PassKey []byte + MaskKey []byte + Salt []byte + MaxNkodeLen int64 + IdxInterface []byte + SvgIDInterface []byte +} + +func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) (GetUserRow, error) { + row := q.db.QueryRowContext(ctx, getUser, arg.Email, arg.CustomerID) + var i GetUserRow + err := row.Scan( + &i.ID, + &i.Renew, + &i.RefreshToken, + &i.Code, + &i.Mask, + &i.AttributesPerKey, + &i.NumberOfKeys, + &i.AlphaKey, + &i.SetKey, + &i.PassKey, + &i.MaskKey, + &i.Salt, + &i.MaxNkodeLen, + &i.IdxInterface, + &i.SvgIDInterface, + ) + return i, err +} + +const getUserRenew = `-- name: GetUserRenew :many +SELECT + id + ,alpha_key + ,set_key + ,attributes_per_key + ,number_of_keys +FROM user +WHERE customer_id = ? +` + +type GetUserRenewRow struct { + ID string + AlphaKey []byte + SetKey []byte + AttributesPerKey int64 + NumberOfKeys int64 +} + +func (q *Queries) GetUserRenew(ctx context.Context, customerID string) ([]GetUserRenewRow, error) { + rows, err := q.db.QueryContext(ctx, getUserRenew, customerID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetUserRenewRow + for rows.Next() { + var i GetUserRenewRow + if err := rows.Scan( + &i.ID, + &i.AlphaKey, + &i.SetKey, + &i.AttributesPerKey, + &i.NumberOfKeys, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const refreshUserPasscode = `-- name: RefreshUserPasscode :exec +UPDATE user +SET + renew = ? + ,code = ? + ,mask = ? + ,alpha_key = ? + ,set_key = ? + ,pass_key = ? + ,mask_key = ? + ,salt = ? +WHERE id = ? +` + +type RefreshUserPasscodeParams struct { + Renew int64 + Code string + Mask string + AlphaKey []byte + SetKey []byte + PassKey []byte + MaskKey []byte + Salt []byte + ID string +} + +func (q *Queries) RefreshUserPasscode(ctx context.Context, arg RefreshUserPasscodeParams) error { + _, err := q.db.ExecContext(ctx, refreshUserPasscode, + arg.Renew, + arg.Code, + arg.Mask, + arg.AlphaKey, + arg.SetKey, + arg.PassKey, + arg.MaskKey, + arg.Salt, + arg.ID, + ) + return err +} + +const renewCustomer = `-- name: RenewCustomer :exec +UPDATE customer +SET attribute_values = ?, set_values = ? +WHERE id = ? +` + +type RenewCustomerParams struct { + AttributeValues []byte + SetValues []byte + ID string +} + +func (q *Queries) RenewCustomer(ctx context.Context, arg RenewCustomerParams) error { + _, err := q.db.ExecContext(ctx, renewCustomer, arg.AttributeValues, arg.SetValues, arg.ID) + return err +} + +const renewUser = `-- name: RenewUser :exec +UPDATE user +SET alpha_key = ?, set_key = ?, renew = ? +WHERE id = ? +` + +type RenewUserParams struct { + AlphaKey []byte + SetKey []byte + Renew int64 + ID string +} + +func (q *Queries) RenewUser(ctx context.Context, arg RenewUserParams) error { + _, err := q.db.ExecContext(ctx, renewUser, + arg.AlphaKey, + arg.SetKey, + arg.Renew, + arg.ID, + ) + return err +} + +const updateUser = `-- name: UpdateUser :exec +UPDATE user +SET renew = ? + ,refresh_token = ? + ,code = ? + ,mask = ? + ,attributes_per_key = ? + ,number_of_keys = ? + ,alpha_key = ? + ,set_key = ? + ,pass_key = ? + ,mask_key = ? + ,salt = ? + ,max_nkode_len = ? + ,idx_interface = ? + ,svg_id_interface = ? +WHERE email = ? AND customer_id = ? +` + +type UpdateUserParams struct { + Renew int64 + RefreshToken sql.NullString + Code string + Mask string + AttributesPerKey int64 + NumberOfKeys int64 + AlphaKey []byte + SetKey []byte + PassKey []byte + MaskKey []byte + Salt []byte + MaxNkodeLen int64 + IdxInterface []byte + SvgIDInterface []byte + Email string + CustomerID string +} + +func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) error { + _, err := q.db.ExecContext(ctx, updateUser, + arg.Renew, + arg.RefreshToken, + arg.Code, + arg.Mask, + arg.AttributesPerKey, + arg.NumberOfKeys, + arg.AlphaKey, + arg.SetKey, + arg.PassKey, + arg.MaskKey, + arg.Salt, + arg.MaxNkodeLen, + arg.IdxInterface, + arg.SvgIDInterface, + arg.Email, + arg.CustomerID, + ) + return err +} + +const updateUserInterface = `-- name: UpdateUserInterface :exec +UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ? +` + +type UpdateUserInterfaceParams struct { + IdxInterface []byte + LastLogin interface{} + ID string +} + +func (q *Queries) UpdateUserInterface(ctx context.Context, arg UpdateUserInterfaceParams) error { + _, err := q.db.ExecContext(ctx, updateUserInterface, arg.IdxInterface, arg.LastLogin, arg.ID) + return err +} + +const updateUserRefreshToken = `-- name: UpdateUserRefreshToken :exec +UPDATE user SET refresh_token = ? WHERE id = ? +` + +type UpdateUserRefreshTokenParams struct { + RefreshToken sql.NullString + ID string +} + +func (q *Queries) UpdateUserRefreshToken(ctx context.Context, arg UpdateUserRefreshTokenParams) error { + _, err := q.db.ExecContext(ctx, updateUserRefreshToken, arg.RefreshToken, arg.ID) + return err +} diff --git a/sqlc.yaml b/sqlc.yaml new file mode 100644 index 0000000..eda576e --- /dev/null +++ b/sqlc.yaml @@ -0,0 +1,9 @@ +version: "2" +sql: + - engine: "sqlite" + queries: "./sqlite/query.sql" + schema: "./sqlite/schema.sql" + gen: + go: + package: "sqlc" + out: "./internal/sqlc" \ No newline at end of file diff --git a/sqlite/query.sql b/sqlite/query.sql new file mode 100644 index 0000000..b3c50cf --- /dev/null +++ b/sqlite/query.sql @@ -0,0 +1,136 @@ +-- name: CreateCustomer :exec +INSERT INTO customer ( + id + ,max_nkode_len + ,min_nkode_len + ,distinct_sets + ,distinct_attributes + ,lock_out + ,expiration + ,attribute_values + ,set_values + ,last_renew + ,created_at +) +VALUES (?,?,?,?,?,?,?,?,?,?,?); + +-- name: CreateUser :exec +INSERT INTO user ( + id + ,email + ,renew + ,refresh_token + ,customer_id + ,code + ,mask + ,attributes_per_key + ,number_of_keys + ,alpha_key + ,set_key + ,pass_key + ,mask_key + ,salt + ,max_nkode_len + ,idx_interface + ,svg_id_interface + ,created_at +) +VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); + +-- name: UpdateUser :exec +UPDATE user +SET renew = ? + ,refresh_token = ? + ,code = ? + ,mask = ? + ,attributes_per_key = ? + ,number_of_keys = ? + ,alpha_key = ? + ,set_key = ? + ,pass_key = ? + ,mask_key = ? + ,salt = ? + ,max_nkode_len = ? + ,idx_interface = ? + ,svg_id_interface = ? +WHERE email = ? AND customer_id = ?; + +-- name: UpdateUserInterface :exec +UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ?; + +-- name: UpdateUserRefreshToken :exec +UPDATE user SET refresh_token = ? WHERE id = ?; + +-- name: RenewCustomer :exec +UPDATE customer +SET attribute_values = ?, set_values = ? +WHERE id = ?; + +-- name: RenewUser :exec +UPDATE user +SET alpha_key = ?, set_key = ?, renew = ? +WHERE id = ?; + +-- name: RefreshUserPasscode :exec +UPDATE user +SET + renew = ? + ,code = ? + ,mask = ? + ,alpha_key = ? + ,set_key = ? + ,pass_key = ? + ,mask_key = ? + ,salt = ? +WHERE id = ?; + +-- name: GetUserRenew :many +SELECT + id + ,alpha_key + ,set_key + ,attributes_per_key + ,number_of_keys +FROM user +WHERE customer_id = ?; + +-- name: GetCustomer :one +SELECT + max_nkode_len + ,min_nkode_len + ,distinct_sets + ,distinct_attributes + ,lock_out + ,expiration + ,attribute_values + ,set_values +FROM customer +WHERE id = ?; + +-- name: GetUser :one +SELECT + id + ,renew + ,refresh_token + ,code + ,mask + ,attributes_per_key + ,number_of_keys + ,alpha_key + ,set_key + ,pass_key + ,mask_key + ,salt + ,max_nkode_len + ,idx_interface + ,svg_id_interface +FROM user +WHERE user.email = ? AND user.customer_id = ?; + +-- name: GetSvgId :one +SELECT svg +FROM svg_icon +WHERE id = ?; + +-- name: GetSvgCount :one +SELECT COUNT(*) as count FROM svg_icon; diff --git a/sqlite/schema.sql b/sqlite/schema.sql new file mode 100644 index 0000000..4e30249 --- /dev/null +++ b/sqlite/schema.sql @@ -0,0 +1,57 @@ +PRAGMA journal_mode=WAL; + + +CREATE TABLE IF NOT EXISTS customer ( + id TEXT NOT NULL PRIMARY KEY + ,max_nkode_len INTEGER NOT NULL + ,min_nkode_len INTEGER NOT NULL + ,distinct_sets INTEGER NOT NULL + ,distinct_attributes INTEGER NOT NULL + ,lock_out INTEGER NOT NULL + ,expiration INTEGER NOT NULL + ,attribute_values BLOB NOT NULL + ,set_values BLOB NOT NULL + ,last_renew TEXT NOT NULL + ,created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS user ( + id TEXT NOT NULL PRIMARY KEY + ,email TEXT NOT NULL +-- first_name TEXT NOT NULL +-- last_name TEXT NOT NULL + ,renew INT NOT NULL + ,refresh_token TEXT + ,customer_id TEXT NOT NULL + +-- Enciphered Passcode + ,code TEXT NOT NULL + ,mask TEXT NOT NULL + +-- Keypad Dimensions + ,attributes_per_key INT NOT NULL + ,number_of_keys INT NOT NULL + +-- User Keys + ,alpha_key BLOB NOT NULL + ,set_key BLOB NOT NULL + ,pass_key BLOB NOT NULL + ,mask_key BLOB NOT NULL + ,salt BLOB NOT NULL + ,max_nkode_len INT NOT NULL + +-- User Interface + ,idx_interface BLOB NOT NULL + ,svg_id_interface BLOB NOT NULL + + ,last_login TEXT NULL + ,created_at TEXT + + ,FOREIGN KEY (customer_id) REFERENCES customer(id) + ,UNIQUE(customer_id, email) +); + +CREATE TABLE IF NOT EXISTS svg_icon ( + id INTEGER PRIMARY KEY AUTOINCREMENT + ,svg TEXT NOT NULL +); From bf587792272d10e0829a8bcb0812a5a76c54a2ae Mon Sep 17 00:00:00 2001 From: Donovan Date: Wed, 4 Dec 2024 10:22:55 -0600 Subject: [PATCH 3/4] refactor sqlite db to support sqlc --- cmd/main.go | 9 +- internal/api/nkode_api.go | 2 +- internal/api/nkode_api_test.go | 5 +- internal/db/customer_user_repository.go | 2 +- internal/db/in_memory_db.go | 2 +- internal/db/sqlite_db.go | 663 +++++++++++------------- internal/db/sqlite_db_test.go | 7 +- internal/email/queue_test.go | 2 +- internal/entities/customer.go | 17 + internal/entities/user.go | 4 +- internal/models/models.go | 10 +- internal/utils/timestamp.go | 7 + 12 files changed, 342 insertions(+), 388 deletions(-) create mode 100644 internal/utils/timestamp.go diff --git a/cmd/main.go b/cmd/main.go index d04297d..534b1d3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -43,15 +43,18 @@ func main() { if dbPath == "" { log.Fatalf("SQLITE_DB=/path/to/nkode.db not set") } - db := db.NewSqliteDB(dbPath) - defer db.CloseDb() + sqlitedb, err := db.NewSqliteDB(dbPath) + if err != nil { + fmt.Errorf("%v", err) + } + defer sqlitedb.Close() sesClient := email.NewSESClient() emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) emailQueue.Start() defer emailQueue.Stop() - nkodeApi := api.NewNKodeAPI(db, emailQueue) + nkodeApi := api.NewNKodeAPI(sqlitedb, emailQueue) AddDefaultCustomer(nkodeApi) handler := api.NKodeHandler{Api: nkodeApi} diff --git a/internal/api/nkode_api.go b/internal/api/nkode_api.go index eeaf9c5..cf0f581 100644 --- a/internal/api/nkode_api.go +++ b/internal/api/nkode_api.go @@ -42,7 +42,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models. if err != nil { return nil, err } - err = n.Db.WriteNewCustomer(*newCustomer) + err = n.Db.CreateCustomer(*newCustomer) if err != nil { return nil, err diff --git a/internal/api/nkode_api_test.go b/internal/api/nkode_api_test.go index ebc217f..8200d1c 100644 --- a/internal/api/nkode_api_test.go +++ b/internal/api/nkode_api_test.go @@ -17,8 +17,9 @@ func TestNKodeAPI(t *testing.T) { dbFile := os.Getenv("TEST_DB") - db2 := db.NewSqliteDB(dbFile) - defer db2.CloseDb() + db2, err := db.NewSqliteDB(dbFile) + assert.NoError(t, err) + defer db2.Close() testNKodeAPI(t, db2) //if _, err := os.Stat(dbFile); err == nil { diff --git a/internal/db/customer_user_repository.go b/internal/db/customer_user_repository.go index 4cccc0b..3b59d51 100644 --- a/internal/db/customer_user_repository.go +++ b/internal/db/customer_user_repository.go @@ -8,7 +8,7 @@ import ( type CustomerUserRepository interface { GetCustomer(models.CustomerId) (*entities.Customer, error) GetUser(models.UserEmail, models.CustomerId) (*entities.User, error) - WriteNewCustomer(entities.Customer) error + CreateCustomer(entities.Customer) error WriteNewUser(entities.User) error UpdateUserNKode(entities.User) error UpdateUserInterface(models.UserId, entities.UserInterface) error diff --git a/internal/db/in_memory_db.go b/internal/db/in_memory_db.go index 05405fb..a73e413 100644 --- a/internal/db/in_memory_db.go +++ b/internal/db/in_memory_db.go @@ -42,7 +42,7 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo return &user, nil } -func (db *InMemoryDb) WriteNewCustomer(customer entities.Customer) error { +func (db *InMemoryDb) CreateCustomer(customer entities.Customer) error { _, exists := db.Customers[customer.Id] if exists { diff --git a/internal/db/sqlite_db.go b/internal/db/sqlite_db.go index 827995d..16b48b4 100644 --- a/internal/db/sqlite_db.go +++ b/internal/db/sqlite_db.go @@ -1,7 +1,9 @@ package db import ( + "context" "database/sql" + "errors" "fmt" "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver @@ -9,432 +11,387 @@ import ( "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" + "go-nkode/internal/sqlc" + "go-nkode/internal/utils" "log" "sync" - "time" ) -type SqliteDB struct { - db *sql.DB - stop bool - writeQueue chan WriteTx - wg sync.WaitGroup -} +const writeBufferSize = 100 +type sqlcGeneric func(*sqlc.Queries, context.Context, any) error + +// WriteTx represents a write transaction type WriteTx struct { ErrChan chan error - Query string - Args []any + Query sqlcGeneric + Args interface{} } -const ( - writeBuffer = 1000 -) +// SqliteDB represents the SQLite database connection and write queue +type SqliteDB struct { + queries *sqlc.Queries + db *sql.DB + writeQueue chan WriteTx + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc +} + +// NewSqliteDB initializes a new SqliteDB instance +func NewSqliteDB(path string) (*SqliteDB, error) { + if path == "" { + return nil, errors.New("database path is required") + } -func NewSqliteDB(path string) *SqliteDB { db, err := sql.Open("sqlite3", path) if err != nil { - log.Fatal("database didn't open ", err) + return nil, fmt.Errorf("failed to open database: %w", err) } - sqldb := SqliteDB{ + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + sqldb := &SqliteDB{ + queries: sqlc.New(db), db: db, - stop: false, - writeQueue: make(chan WriteTx, writeBuffer), + writeQueue: make(chan WriteTx, writeBufferSize), + ctx: ctx, + cancel: cancel, } - go func() { - for writeTx := range sqldb.writeQueue { - writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args) - sqldb.wg.Done() + sqldb.wg.Add(1) + go sqldb.processWriteQueue() + + return sqldb, nil +} + +// processWriteQueue handles write transactions from the queue +func (d *SqliteDB) processWriteQueue() { + defer d.wg.Done() + for { + select { + case <-d.ctx.Done(): + return + case writeTx := <-d.writeQueue: + err := writeTx.Query(d.queries, d.ctx, writeTx.Args) + writeTx.ErrChan <- err } - }() - - return &sqldb + } } -func (d *SqliteDB) CloseDb() { - d.stop = true +func (d *SqliteDB) Close() error { + d.cancel() d.wg.Wait() - if err := d.db.Close(); err != nil { - // If db.Close() returns an error, panic - panic(fmt.Sprintf("Failed to close the database: %v", err)) - } + close(d.writeQueue) + return d.db.Close() } -func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error { - query := ` -INSERT INTO customer ( - id - ,max_nkode_len - ,min_nkode_len - ,distinct_sets - ,distinct_attributes - ,lock_out - ,expiration - ,attribute_values - ,set_values - ,last_renew - ,created_at -) -VALUES (?,?,?,?,?,?,?,?,?,?,?) -` - args := []any{ - uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, - c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, - c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(), +func (d *SqliteDB) CreateCustomer(c entities.Customer) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.CreateCustomerParams) + if !ok { + return fmt.Errorf("invalid argument type: expected CreateCustomerParams") + } + return q.CreateCustomer(ctx, params) } - return d.addWriteTx(query, args) + + return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) } func (d *SqliteDB) WriteNewUser(u entities.User) error { - query := ` -INSERT INTO user ( - id - ,email - ,renew - ,refresh_token - ,customer_id - ,code - ,mask - ,attributes_per_key - ,number_of_keys - ,alpha_key - ,set_key - ,pass_key - ,mask_key - ,salt - ,max_nkode_len - ,idx_interface - ,svg_id_interface - ,created_at -) -VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) -` - var renew int + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.CreateUserParams) + if !ok { + return fmt.Errorf("invalid argument type: expected CreateUserParams") + } + return q.CreateUser(ctx, params) + } + // Use the wrapped function in enqueueWriteTx + + renew := 0 if u.Renew { renew = 1 - } else { - renew = 0 } - - args := []any{ - uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), - u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, - security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), - security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), - u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), - security.IntArrToByteArr(u.Interface.SvgId), timeStamp(), + // Map entities.User to CreateUserParams + params := sqlc.CreateUserParams{ + ID: uuid.UUID(u.Id).String(), + Email: string(u.Email), + Renew: int64(renew), + RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""}, + CustomerID: uuid.UUID(u.CustomerId).String(), + Code: u.EncipheredPasscode.Code, + Mask: u.EncipheredPasscode.Mask, + AttributesPerKey: int64(u.Kp.AttrsPerKey), + NumberOfKeys: int64(u.Kp.NumbOfKeys), + AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), + SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey), + PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey), + MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), + Salt: u.CipherKeys.Salt, + MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen), + IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), + SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), + CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true}, } - - return d.addWriteTx(query, args) + return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserNKode(u entities.User) error { - query := ` -UPDATE user -SET renew = ? - ,refresh_token = ? - ,code = ? - ,mask = ? - ,attributes_per_key = ? - ,number_of_keys = ? - ,alpha_key = ? - ,set_key = ? - ,pass_key = ? - ,mask_key = ? - ,salt = ? - ,max_nkode_len = ? - ,idx_interface = ? - ,svg_id_interface = ? -WHERE email = ? AND customer_id = ? -` - var renew int + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.UpdateUserParams) + if !ok { + return fmt.Errorf("invalid argument type: expected UpdateUserParams") + } + return q.UpdateUser(ctx, params) + } + // Use the wrapped function in enqueueWriteTx + renew := 0 if u.Renew { renew = 1 - } else { - renew = 0 } - args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)} - - return d.addWriteTx(query, args) + params := sqlc.UpdateUserParams{ + Email: string(u.Email), + Renew: int64(renew), + RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""}, + CustomerID: uuid.UUID(u.CustomerId).String(), + Code: u.EncipheredPasscode.Code, + Mask: u.EncipheredPasscode.Mask, + AttributesPerKey: int64(u.Kp.AttrsPerKey), + NumberOfKeys: int64(u.Kp.NumbOfKeys), + AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), + SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey), + PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey), + MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), + Salt: u.CipherKeys.Salt, + MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen), + IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), + SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), + } + return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { - query := ` -UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ? -` - args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()} + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.UpdateUserInterfaceParams) + if !ok { + return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams") + } + return q.UpdateUserInterface(ctx, params) + } + params := sqlc.UpdateUserInterfaceParams{ + IdxInterface: security.IntArrToByteArr(ui.IdxInterface), + LastLogin: utils.TimeStamp(), + ID: uuid.UUID(id).String(), + } - return d.addWriteTx(query, args) + return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { - query := ` -UPDATE user SET refresh_token = ? WHERE id = ? -` - args := []any{refreshToken, uuid.UUID(id).String()} + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.UpdateUserRefreshTokenParams) + if !ok { + return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken") + } + return q.UpdateUserRefreshToken(ctx, params) + } + params := sqlc.UpdateUserRefreshTokenParams{ + RefreshToken: sql.NullString{ + String: refreshToken, + Valid: true, + }, + ID: uuid.UUID(id).String(), + } + return d.enqueueWriteTx(queryFunc, params) +} - return d.addWriteTx(query, args) +func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.RenewCustomerParams) + if !ok { + + } + return q.RenewCustomer(ctx, params) + } + return d.enqueueWriteTx(queryFunc, renewParams) } func (d *SqliteDB) Renew(id models.CustomerId) error { - // TODO: How long does a renew take? - customer, err := d.GetCustomer(id) + setXor, attrXor, err := d.renewCustomer(id) if err != nil { return err } - setXor, attrXor, err := customer.RenewKeys() + customerId := models.CustomerIdToString(id) + userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId) if err != nil { return err } - renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} - // TODO: replace with tx - renewQuery := ` -UPDATE customer -SET attribute_values = ?, set_values = ? -WHERE id = ?; -` - userQuery := ` -SELECT - id - ,alpha_key - ,set_key - ,attributes_per_key - ,number_of_keys -FROM user -WHERE customer_id = ? -` - tx, err := d.db.Begin() - if err != nil { - return err - } - rows, err := tx.Query(userQuery, uuid.UUID(id).String()) - for rows.Next() { - var userId string - var alphaBytes []byte - var setBytes []byte - var attrsPerKey int - var numbOfKeys int - err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys) - if err != nil { - return err + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.RenewUserParams) + if !ok { + return fmt.Errorf("invalid argument type: expected RenewUserParams") } + return q.RenewUser(ctx, params) + } + + for _, row := range userRenewRows { user := entities.User{ - Id: models.UserId{}, + Id: models.UserIdFromString(row.ID), CustomerId: models.CustomerId{}, Email: "", EncipheredPasscode: models.EncipheredNKode{}, Kp: entities.KeypadDimension{ - AttrsPerKey: attrsPerKey, - NumbOfKeys: numbOfKeys, + AttrsPerKey: int(row.AttributesPerKey), + NumbOfKeys: int(row.NumberOfKeys), }, CipherKeys: entities.UserCipherKeys{ - AlphaKey: security.ByteArrToUint64Arr(alphaBytes), - SetKey: security.ByteArrToUint64Arr(setBytes), + AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey), + SetKey: security.ByteArrToUint64Arr(row.SetKey), }, Interface: entities.UserInterface{}, Renew: false, } - err = user.RenewKeys(setXor, attrXor) - if err != nil { + + if err = user.RenewKeys(setXor, attrXor); err != nil { + return err + } + params := sqlc.RenewUserParams{ + AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), + SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey), + Renew: 1, + ID: uuid.UUID(user.Id).String(), + } + if err = d.enqueueWriteTx(queryFunc, params); err != nil { return err } - renewQuery += ` -UPDATE user -SET alpha_key = ?, set_key = ?, renew = ? -WHERE id = ?; -` - renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId) } - renewQuery += ` -` - err = tx.Commit() + return nil +} + +func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) { + customer, err := d.GetCustomer(id) if err != nil { - return err + return nil, nil, err } - return d.addWriteTx(renewQuery, renewArgs) + setXor, attrXor, err := customer.RenewKeys() + if err != nil { + return nil, nil, err + } + + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.RenewCustomerParams) + if !ok { + return fmt.Errorf("invalid argument type: expected RenewCustomerParams") + } + return q.RenewCustomer(ctx, params) + } + params := sqlc.RenewCustomerParams{ + AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals), + SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals), + ID: uuid.UUID(customer.Id).String(), + } + + if err = d.enqueueWriteTx(queryFunc, params); err != nil { + return nil, nil, err + } + return setXor, attrXor, nil } func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { - err := user.RefreshPasscode(passcodeIdx, customerAttr) - if err != nil { + if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil { return err } - query := ` -UPDATE user -SET - renew = ? - ,code = ? - ,mask = ? - ,alpha_key = ? - ,set_key = ? - ,pass_key = ? - ,mask_key = ? - ,salt = ? -WHERE id = ?; -` - args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()} - return d.addWriteTx(query, args) -} -func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { - tx, err := d.db.Begin() - if err != nil { - return nil, err - } - defer func() { - if err != nil { - err = tx.Rollback() - if err != nil { - log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err)) - } + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + params, ok := args.(sqlc.RefreshUserPasscodeParams) + if !ok { + return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams") } - }() - selectCustomer := ` -SELECT - max_nkode_len - ,min_nkode_len - ,distinct_sets - ,distinct_attributes - ,lock_out - ,expiration - ,attribute_values - ,set_values -FROM customer -WHERE id = ? -` - rows, err := tx.Query(selectCustomer, uuid.UUID(id)) + return q.RefreshUserPasscode(ctx, params) + } + params := sqlc.RefreshUserPasscodeParams{ + Renew: 0, + Code: user.EncipheredPasscode.Code, + Mask: user.EncipheredPasscode.Mask, + AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), + SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey), + PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey), + MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), + Salt: user.CipherKeys.Salt, + ID: uuid.UUID(user.Id).String(), + } + return d.enqueueWriteTx(queryFunc, params) +} + +func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { + customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String()) if err != nil { return nil, err } - if !rows.Next() { - log.Printf("no new row for customer %s with err %s", id, rows.Err()) - return nil, config.ErrCustomerDne - } - - var maxNKodeLen int - var minNKodeLen int - var distinctSets int - var distinctAttributes int - var lockOut int - var expiration int - var attributeValues []byte - var setValues []byte - err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues) - if err != nil { - return nil, err - } - customer := entities.Customer{ + return &entities.Customer{ Id: id, NKodePolicy: models.NKodePolicy{ - MaxNkodeLen: maxNKodeLen, - MinNkodeLen: minNKodeLen, - DistinctSets: distinctSets, - DistinctAttributes: distinctAttributes, - LockOut: lockOut, - Expiration: expiration, + MaxNkodeLen: int(customer.MaxNkodeLen), + MinNkodeLen: int(customer.MinNkodeLen), + DistinctSets: int(customer.DistinctSets), + DistinctAttributes: int(customer.DistinctAttributes), + LockOut: int(customer.LockOut), + Expiration: int(customer.Expiration), }, - Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues), - } - if err = tx.Commit(); err != nil { - return nil, err - } - return &customer, nil + Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues), + }, nil } func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { - tx, err := d.db.Begin() + userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{ + Email: string(email), + CustomerID: uuid.UUID(customerId).String(), + }) if err != nil { - return nil, err + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to get user: %w", err) } - userSelect := ` -SELECT - id - ,renew - ,refresh_token - ,code - ,mask - ,attributes_per_key - ,number_of_keys - ,alpha_key - ,set_key - ,pass_key - ,mask_key - ,salt - ,max_nkode_len - ,idx_interface - ,svg_id_interface -FROM user -WHERE user.email = ? AND user.customer_id = ? -` - rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String()) - if !rows.Next() { - return nil, nil - } - var ( - id string - renewVal int - refreshToken string - code string - mask string - attrsPerKey int - numbOfKeys int - alphaKey []byte - setKey []byte - passKey []byte - maskKey []byte - salt []byte - maxNKodeLen int - idxInterface []byte - svgIdInterface []byte - ) - err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) - userId, err := uuid.Parse(id) - if err != nil { - return nil, err + kp := entities.KeypadDimension{ + AttrsPerKey: int(userRow.AttributesPerKey), + NumbOfKeys: int(userRow.NumberOfKeys), } - var renew bool - if renewVal == 0 { - renew = false - } else { + + renew := false + if userRow.Renew == 1 { renew = true } - user := entities.User{ - Id: models.UserId(userId), + Id: models.UserIdFromString(userRow.ID), CustomerId: customerId, Email: email, EncipheredPasscode: models.EncipheredNKode{ - Code: code, - Mask: mask, - }, - Kp: entities.KeypadDimension{ - AttrsPerKey: attrsPerKey, - NumbOfKeys: numbOfKeys, + Code: userRow.Code, + Mask: userRow.Mask, }, + Kp: kp, CipherKeys: entities.UserCipherKeys{ - AlphaKey: security.ByteArrToUint64Arr(alphaKey), - SetKey: security.ByteArrToUint64Arr(setKey), - PassKey: security.ByteArrToUint64Arr(passKey), - MaskKey: security.ByteArrToUint64Arr(maskKey), - Salt: salt, - MaxNKodeLen: maxNKodeLen, - Kp: nil, + AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey), + SetKey: security.ByteArrToUint64Arr(userRow.SetKey), + PassKey: security.ByteArrToUint64Arr(userRow.PassKey), + MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey), + Salt: userRow.Salt, + MaxNKodeLen: int(userRow.MaxNkodeLen), + Kp: &kp, }, Interface: entities.UserInterface{ - IdxInterface: security.ByteArrToIntArr(idxInterface), - SvgId: security.ByteArrToIntArr(svgIdInterface), - Kp: nil, + IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface), + SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface), + Kp: &kp, }, Renew: renew, - RefreshToken: refreshToken, - } - user.Interface.Kp = &user.Kp - user.CipherKeys.Kp = &user.Kp - if err = tx.Commit(); err != nil { - return nil, err + RefreshToken: userRow.RefreshToken.String, } return &user, nil } @@ -456,68 +413,30 @@ func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, } func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { - tx, err := d.db.Begin() - if err != nil { - return nil, err - } - selectId := ` -SELECT svg -FROM svg_icon -WHERE id = ? -` svgs := make([]string, len(ids)) for idx, id := range ids { - rows, err := tx.Query(selectId, id) + svg, err := d.queries.GetSvgId(d.ctx, int64(id)) if err != nil { return nil, err } - if !rows.Next() { - log.Printf("id not found: %d", id) - return nil, config.ErrSvgDne - } - if err = rows.Scan(&svgs[idx]); err != nil { - return nil, err - } - } - if err = tx.Commit(); err != nil { - return nil, err + svgs[idx] = svg } return svgs, nil } -func (d *SqliteDB) writeToDb(query string, args []any) error { - tx, err := d.db.Begin() - if err != nil { - return err +func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error { + select { + case <-d.ctx.Done(): + return errors.New("database is shutting down") + default: } - defer func() { - if err != nil { - err = tx.Rollback() - if err != nil { - log.Fatalf("fatal error: write won't roll back %+v", err) - } - } - }() - if _, err = tx.Exec(query, args...); err != nil { - return err - } - if err = tx.Commit(); err != nil { - return err - } - return nil -} -func (d *SqliteDB) addWriteTx(query string, args []any) error { - if d.stop { - return config.ErrStoppingDatabase - } - errChan := make(chan error) + errChan := make(chan error, 1) writeTx := WriteTx{ - Query: query, + Query: queryFunc, Args: args, ErrChan: errChan, } - d.wg.Add(1) d.writeQueue <- writeTx return <-errChan } @@ -559,7 +478,3 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) { return perm[:count], nil } - -func timeStamp() string { - return time.Now().Format(time.RFC3339) -} diff --git a/internal/db/sqlite_db_test.go b/internal/db/sqlite_db_test.go index 9f53ef1..0c6d11e 100644 --- a/internal/db/sqlite_db_test.go +++ b/internal/db/sqlite_db_test.go @@ -11,8 +11,9 @@ import ( func TestNewSqliteDB(t *testing.T) { dbFile := os.Getenv("TEST_DB") // sql_driver.MakeTables(dbFile) - db := NewSqliteDB(dbFile) - defer db.CloseDb() + db, err := NewSqliteDB(dbFile) + assert.NoError(t, err) + defer db.Close() testSignupLoginRenew(t, db) testSqliteDBRandomSvgInterface(t, db) @@ -28,7 +29,7 @@ func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { nkodePolicy := models.NewDefaultNKodePolicy() customerOrig, err := entities.NewCustomer(nkodePolicy) assert.NoError(t, err) - err = db.WriteNewCustomer(*customerOrig) + err = db.CreateCustomer(*customerOrig) assert.NoError(t, err) customer, err := db.GetCustomer(customerOrig.Id) assert.NoError(t, err) diff --git a/internal/email/queue_test.go b/internal/email/queue_test.go index 525f06a..a0733d0 100644 --- a/internal/email/queue_test.go +++ b/internal/email/queue_test.go @@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) { } queue.AddEmail(email) } - // CloseDb the queue after all emails are processed + // Close the queue after all emails are processed queue.Stop() assert.Equal(t, queue.FailedSendCount, 0) diff --git a/internal/entities/customer.go b/internal/entities/customer.go index a111cc4..5ae8fad 100644 --- a/internal/entities/customer.go +++ b/internal/entities/customer.go @@ -5,6 +5,7 @@ import ( "go-nkode/config" "go-nkode/internal/models" "go-nkode/internal/security" + "go-nkode/internal/sqlc" "go-nkode/internal/utils" ) @@ -83,3 +84,19 @@ func (c *Customer) RenewKeys() ([]uint64, []uint64, error) { } return setXor, attrsXor, nil } + +func (c *Customer) ToSqlcCreateCustomerParams() sqlc.CreateCustomerParams { + return sqlc.CreateCustomerParams{ + ID: uuid.UUID(c.Id).String(), + MaxNkodeLen: int64(c.NKodePolicy.MaxNkodeLen), + MinNkodeLen: int64(c.NKodePolicy.MinNkodeLen), + DistinctSets: int64(c.NKodePolicy.DistinctSets), + DistinctAttributes: int64(c.NKodePolicy.DistinctAttributes), + LockOut: int64(c.NKodePolicy.LockOut), + Expiration: int64(c.NKodePolicy.Expiration), + AttributeValues: c.Attributes.AttrBytes(), + SetValues: c.Attributes.SetBytes(), + LastRenew: utils.TimeStamp(), + CreatedAt: utils.TimeStamp(), + } +} diff --git a/internal/entities/user.go b/internal/entities/user.go index ec89efe..a4ae96f 100644 --- a/internal/entities/user.go +++ b/internal/entities/user.go @@ -37,11 +37,13 @@ func (u *User) RenewKeys(setXor []uint64, attrXor []uint64) error { func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error { setVals, err := customerAttributes.SetValsForKp(u.Kp) + if err != nil { + return err + } newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen) if err != nil { return err } - encipheredPasscode, err := newKeys.EncipherNKode(passcodeAttrIdx, customerAttributes) if err != nil { return err diff --git a/internal/models/models.go b/internal/models/models.go index 87a8832..49863da 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -1,6 +1,7 @@ package models import ( + "fmt" "github.com/google/uuid" "net/mail" "strings" @@ -99,10 +100,17 @@ func CustomerIdToString(customerId CustomerId) string { type SessionId uuid.UUID type UserId uuid.UUID +func UserIdFromString(userId string) UserId { + id, err := uuid.Parse(userId) + if err != nil { + fmt.Errorf("unable to parse user id %+v", err) + } + return UserId(id) +} + func (s *SessionId) String() string { id := uuid.UUID(*s) return id.String() - } type UserEmail string diff --git a/internal/utils/timestamp.go b/internal/utils/timestamp.go new file mode 100644 index 0000000..4de32dc --- /dev/null +++ b/internal/utils/timestamp.go @@ -0,0 +1,7 @@ +package utils + +import "time" + +func TimeStamp() string { + return time.Now().Format(time.RFC3339) +} From 0c5de93a0da243d33be1dee87c2d018be55b4183 Mon Sep 17 00:00:00 2001 From: Donovan Date: Wed, 4 Dec 2024 10:39:09 -0600 Subject: [PATCH 4/4] chagne db failed connection to fatal --- cmd/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/main.go b/cmd/main.go index 534b1d3..c119843 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -45,7 +45,7 @@ func main() { } sqlitedb, err := db.NewSqliteDB(dbPath) if err != nil { - fmt.Errorf("%v", err) + log.Fatalf("%v", err) } defer sqlitedb.Close()