diff --git a/core/email_queue.go b/core/email_queue.go index aa8ec54..57fc116 100644 --- a/core/email_queue.go +++ b/core/email_queue.go @@ -51,6 +51,7 @@ func NewSESClient() SESClient { } func (s *SESClient) SendEmail(email Email) error { + if _, exists := s.ResetCache.Get(email.Recipient); exists { return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject) } @@ -102,6 +103,7 @@ func (s *SESClient) SendEmail(email Email) error { // EmailQueue represents the email queue with rate limiting type EmailQueue struct { + stop bool emailQueue chan Email // Email queue rateLimit <-chan time.Time // Rate limiter client EmailClient // SES client to send emails @@ -115,6 +117,7 @@ func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *Ema rateLimit := time.Tick(time.Second / time.Duration(emailsPerSecond)) return &EmailQueue{ + stop: false, emailQueue: make(chan Email, bufferSize), rateLimit: rateLimit, client: client, @@ -124,12 +127,17 @@ func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *Ema // AddEmail queues a new email to be sent func (q *EmailQueue) AddEmail(email Email) { + if q.stop { + log.Printf("email %s with subject %s not add. Stopping queue", email.Recipient, email.Subject) + return + } q.wg.Add(1) q.emailQueue <- email } // Start begins processing the email queue with rate limiting func (q *EmailQueue) Start() { + q.stop = false // Worker goroutine that processes emails from the queue go func() { for email := range q.emailQueue { @@ -150,6 +158,7 @@ func (q *EmailQueue) sendEmail(email Email) { // Stop stops the queue after all emails have been processed func (q *EmailQueue) Stop() { + q.stop = true // Wait for all emails to be processed q.wg.Wait() // Close the email queue diff --git a/core/email_queue_test.go b/core/email_queue_test.go index 43aa4a1..ff0ab49 100644 --- a/core/email_queue_test.go +++ b/core/email_queue_test.go @@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) { } queue.AddEmail(email) } - // Stop the queue after all emails are processed + // CloseDb the queue after all emails are processed queue.Stop() assert.Equal(t, queue.FailedSendCount, 0) diff --git a/core/nkode_api.go b/core/nkode_api.go index e7df9bb..10e863d 100644 --- a/core/nkode_api.go +++ b/core/nkode_api.go @@ -38,6 +38,13 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy NKodePolicy, id *CustomerId) (* } func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId CustomerId, kp KeypadDimension, reset bool) (*GenerateSignupResetInterfaceResp, error) { + user, err := n.Db.GetUser(userEmail, customerId) + if err != nil { + return nil, err + } + if user != nil && !reset { + return nil, fmt.Errorf("user %s already exists", string(userEmail)) + } svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) if err != nil { return nil, err diff --git a/core/nkode_api_test.go b/core/nkode_api_test.go index 2d82750..0313a9e 100644 --- a/core/nkode_api_test.go +++ b/core/nkode_api_test.go @@ -3,6 +3,7 @@ package core import ( "github.com/stretchr/testify/assert" "go-nkode/util" + "os" "testing" ) @@ -10,7 +11,7 @@ func TestNKodeAPI(t *testing.T) { //db1 := NewInMemoryDb() //testNKodeAPI(t, &db1) - dbFile := "../test.db" + dbFile := os.Getenv("TEST_DB") db2 := NewSqliteDB(dbFile) defer db2.CloseDb() @@ -101,5 +102,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) { assert.NoError(t, err) _, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection) assert.NoError(t, err) + signupResponse, err = nkodeApi.GenerateSignupResetInterface(userEmail, *customerId, keypadSize, false) + assert.Error(t, err) } } diff --git a/core/sqlite-init/sqlite_init.go b/core/sqlite-init/sqlite_init.go index 6070714..9850478 100644 --- a/core/sqlite-init/sqlite_init.go +++ b/core/sqlite-init/sqlite_init.go @@ -24,7 +24,7 @@ type Root struct { } func main() { - dbPaths := []string{"test.db", "nkode.db"} + dbPaths := []string{"/Users/donov/databases/test.db", "/Users/donov/databases/nkode.db"} outputStr := MakeSvgFiles() for _, path := range dbPaths { MakeTables(path) @@ -57,7 +57,7 @@ VALUES (?) } func MakeSvgFiles() string { - jsonFiles, err := GetAllFiles("./core//sqlite-init/json") + jsonFiles, err := GetAllFiles("./core/sqlite-init/json") if err != nil { log.Fatalf("Error getting JSON files: %v", err) } @@ -146,7 +146,7 @@ func MakeTables(dbPath string) { PRAGMA journal_mode=WAL; --PRAGMA busy_timeout = 5000; -- Wait up to 5 seconds --PRAGMA synchronous = NORMAL; -- Reduce sync frequency for less locking ---PRAGMA cache_size = -16000; -- Increase cache size (16MB)PRAGMA foreign_keys = ON; +--PRAGMA cache_size = -16000; -- Increase cache size (16MB)PRAGMA CREATE TABLE IF NOT EXISTS customer ( id TEXT NOT NULL PRIMARY KEY, diff --git a/core/sqlite_db.go b/core/sqlite_db.go index 496ad4f..170fd98 100644 --- a/core/sqlite_db.go +++ b/core/sqlite_db.go @@ -8,23 +8,50 @@ import ( _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "go-nkode/util" "log" + "sync" ) type SqliteDB struct { - db *sql.DB + db *sql.DB + stop bool + writeQueue chan WriteTx + wg sync.WaitGroup } +type WriteTx struct { + ErrChan chan error + Query string + Args []any +} + +const ( + writeBuffer = 1000 +) + func NewSqliteDB(path string) *SqliteDB { db, err := sql.Open("sqlite3", path) if err != nil { log.Fatal("database didn't open ", err) } - sqldb := SqliteDB{db: db} + sqldb := SqliteDB{ + db: db, + stop: false, + writeQueue: make(chan WriteTx, writeBuffer), + } + + go func() { + for writeTx := range sqldb.writeQueue { + writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args) + sqldb.wg.Done() + } + }() return &sqldb } func (d *SqliteDB) CloseDb() { + d.stop = true + d.wg.Wait() if err := d.db.Close(); err != nil { // If db.Close() returns an error, panic panic(fmt.Sprintf("Failed to close the database: %v", err)) @@ -32,48 +59,16 @@ func (d *SqliteDB) CloseDb() { } func (d *SqliteDB) WriteNewCustomer(c Customer) error { - tx, err := d.db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - err = tx.Rollback() - if err != nil { - log.Fatal(fmt.Sprintf("Write new customer won't roll back %+v", err)) - } - } - }() - insertCustomer := ` + query := ` INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values) VALUES (?,?,?,?,?,?,?,?,?) ` - _, err = tx.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()) - if err != nil { - return err - } - err = tx.Commit() - if err != nil { - return err - } - return nil + args := []any{uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()} + return d.addWriteTx(query, args) } func (d *SqliteDB) WriteNewUser(u User) error { - - tx, err := d.db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - err = tx.Rollback() - if err != nil { - log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err)) - } - } - }() - insertUser := ` + query := ` INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` @@ -83,32 +78,14 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) } else { renew = 0 } - _, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)) - if err != nil { - return err - } - err = tx.Commit() - if err != nil { - return err - } - return nil + args := []any{uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)} + + return d.addWriteTx(query, args) } func (d *SqliteDB) UpdateUserNKode(u User) error { - tx, err := d.db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - err = tx.Rollback() - if err != nil { - log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err)) - } - } - }() - updateUser := ` + query := ` UPDATE user SET renew = ?, refresh_token = ?, code = ?, mask = ?, attributes_per_key = ?, number_of_keys = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ?, max_nkode_len = ?, idx_interface = ?, svg_id_interface = ? WHERE username = ? AND customer_id = ? @@ -119,18 +96,103 @@ WHERE username = ? AND customer_id = ? } else { renew = 0 } - _, err = tx.Exec(updateUser, renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)) + args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)} + return d.addWriteTx(query, args) +} + +func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error { + query := ` +UPDATE user SET idx_interface = ? WHERE id = ? +` + args := []any{util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()} + + return d.addWriteTx(query, args) +} + +func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error { + query := ` +UPDATE user SET refresh_token = ? WHERE id = ? +` + args := []any{refreshToken, uuid.UUID(id).String()} + + return d.addWriteTx(query, args) +} + +func (d *SqliteDB) Renew(id CustomerId) error { + // TODO: How long does a renew take? + customer, err := d.GetCustomer(id) if err != nil { return err } + setXor, attrXor := customer.RenewKeys() + renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} + // TODO: replace with tx + renewQuery := ` +UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?; +` + + userQuery := ` +SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ? +` + tx, err := d.db.Begin() + if err != nil { + return err + } + rows, err := tx.Query(userQuery, uuid.UUID(id).String()) + for rows.Next() { + var userId string + var alphaBytes []byte + var setBytes []byte + var attrsPerKey int + var numbOfKeys int + err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys) + if err != nil { + return err + } + user := User{ + Id: UserId{}, + CustomerId: CustomerId{}, + Email: "", + EncipheredPasscode: EncipheredNKode{}, + Kp: KeypadDimension{ + AttrsPerKey: attrsPerKey, + NumbOfKeys: numbOfKeys, + }, + CipherKeys: UserCipherKeys{ + AlphaKey: util.ByteArrToUint64Arr(alphaBytes), + SetKey: util.ByteArrToUint64Arr(setBytes), + }, + Interface: UserInterface{}, + Renew: false, + } + err = user.RenewKeys(setXor, attrXor) + if err != nil { + return err + } + renewQuery += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;" + renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId) + } + renewQuery += ` +` err = tx.Commit() if err != nil { return err } - return nil + return d.addWriteTx(renewQuery, renewArgs) } +func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error { + err := user.RefreshPasscode(passcodeIdx, customerAttr) + if err != nil { + return err + } + query := ` +UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; +` + args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()} + return d.addWriteTx(query, args) +} func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { tx, err := d.db.Begin() if err != nil { @@ -191,7 +253,6 @@ func (d *SqliteDB) GetUser(username UserEmail, customerId CustomerId) (*User, er if err != nil { return nil, err } - defer tx.Commit() userSelect := ` SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user WHERE user.username = ? AND user.customer_id = ? @@ -263,115 +324,13 @@ WHERE user.username = ? AND user.customer_id = ? } user.Interface.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp - + err = tx.Commit() + if err != nil { + return nil, err + } return &user, nil } -func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error { - updateUserInterface := ` -UPDATE user SET idx_interface = ? WHERE id = ? -` - _, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()) - - return err -} - -func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error { - updateUserRefreshToken := ` -UPDATE user SET refresh_token = ? WHERE id = ? -` - _, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String()) - - return err -} - -func (d *SqliteDB) Renew(id CustomerId) error { - customer, err := d.GetCustomer(id) - if err != nil { - return err - } - setXor, attrXor := customer.RenewKeys() - renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} - // TODO: replace with tx - renewExec := ` -BEGIN TRANSACTION; - -UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?; -` - - userQuery := ` -SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ? -` - tx, err := d.db.Begin() - if err != nil { - return err - } - rows, err := tx.Query(userQuery, uuid.UUID(id).String()) - for rows.Next() { - var userId string - var alphaBytes []byte - var setBytes []byte - var attrsPerKey int - var numbOfKeys int - err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys) - if err != nil { - return err - } - user := User{ - Id: UserId{}, - CustomerId: CustomerId{}, - Email: "", - EncipheredPasscode: EncipheredNKode{}, - Kp: KeypadDimension{ - AttrsPerKey: attrsPerKey, - NumbOfKeys: numbOfKeys, - }, - CipherKeys: UserCipherKeys{ - AlphaKey: util.ByteArrToUint64Arr(alphaBytes), - SetKey: util.ByteArrToUint64Arr(setBytes), - }, - Interface: UserInterface{}, - Renew: false, - } - err = user.RenewKeys(setXor, attrXor) - if err != nil { - return err - } - renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;" - renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId) - } - renewExec += ` -COMMIT; -` - err = tx.Commit() - if err != nil { - return err - } - - tx, err = d.db.Begin() - if err != nil { - return err - } - _, err = d.db.Exec(renewExec, renewArgs...) - err = tx.Commit() - if err != nil { - return err - } - return err -} - -func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error { - err := user.RefreshPasscode(passcodeIdx, customerAttr) - if err != nil { - return err - } - updateUser := ` -UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?; -` - _, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()) - return err -} - func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) { ids, err := d.getRandomIds(kp.TotalAttrs()) if err != nil { @@ -393,7 +352,6 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { if err != nil { return nil, err } - defer tx.Commit() selectId := "SELECT svg FROM svg_icon where id = ?" svgs := make([]string, len(ids)) for idx, id := range ids { @@ -409,15 +367,57 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { return nil, err } } + err = tx.Commit() + if err != nil { + return nil, err + } return svgs, nil } +func (d *SqliteDB) writeToDb(query string, args []any) error { + tx, err := d.db.Begin() + if err != nil { + return err + } + defer func() { + if err != nil { + err = tx.Rollback() + if err != nil { + log.Fatal(fmt.Sprintf("Write won't roll back %+v", err)) + } + } + }() + _, err = tx.Exec(query, args...) + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + return nil +} + +func (d *SqliteDB) addWriteTx(query string, args []any) error { + if d.stop { + return errors.New("stopping database") + } + errChan := make(chan error) + writeTx := WriteTx{ + Query: query, + Args: args, + ErrChan: errChan, + } + d.wg.Add(1) + d.writeQueue <- writeTx + return <-errChan +} + func (d *SqliteDB) getRandomIds(count int) ([]int, error) { tx, err := d.db.Begin() if err != nil { return nil, err } - defer tx.Commit() rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") if err != nil { return nil, err @@ -426,15 +426,20 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) { if !rows.Next() { return nil, errors.New("empty svg_icon table") } - err = rows.Scan(&tableLen) + + if err = rows.Scan(&tableLen); err != nil { + return nil, err + } + + perm, err := util.RandomPermutation(tableLen) if err != nil { return nil, err } - perm, err := util.RandomPermutation(tableLen) for idx := range perm { perm[idx] += 1 } - if err != nil { + + if err = tx.Commit(); err != nil { return nil, err } return perm[:count], nil diff --git a/core/sqlite_db_test.go b/core/sqlite_db_test.go index a9fac67..eabfe23 100644 --- a/core/sqlite_db_test.go +++ b/core/sqlite_db_test.go @@ -2,11 +2,12 @@ package core import ( "github.com/stretchr/testify/assert" + "os" "testing" ) func TestNewSqliteDB(t *testing.T) { - dbFile := "../test.db" + dbFile := os.Getenv("TEST_DB") // sql_driver.MakeTables(dbFile) db := NewSqliteDB(dbFile) defer db.CloseDb() diff --git a/core/type.go b/core/type.go index f7debdc..372a318 100644 --- a/core/type.go +++ b/core/type.go @@ -3,6 +3,7 @@ package core import ( "github.com/google/uuid" "net/mail" + "strings" ) type SetNKodeResp struct { @@ -101,7 +102,7 @@ func ParseEmail(email string) (UserEmail, error) { if err != nil { return "", err } - return UserEmail(email), err + return UserEmail(strings.ToLower(email)), err } diff --git a/main.go b/main.go index ecbf4e4..742829a 100644 --- a/main.go +++ b/main.go @@ -6,15 +6,20 @@ import ( "go-nkode/core" "log" "net/http" + "os" ) const ( emailQueueBufferSize = 100 - maxEmailsPerSecond = 13 // SES allows 14 but I don't want to push it + maxEmailsPerSecond = 13 // SES allows 14, but I don't want to push it ) func main() { - db := core.NewSqliteDB("nkode.db") + dbPath := os.Getenv("SQLITE_DB") + if dbPath == "" { + log.Fatalf("SQLITE_DB=/path/to/nkode.db not set") + } + db := core.NewSqliteDB(dbPath) defer db.CloseDb() sesClient := core.NewSESClient() emailQueue := core.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) diff --git a/main_test.go b/main_test.go index a510ee5..f7460ee 100644 --- a/main_test.go +++ b/main_test.go @@ -6,8 +6,10 @@ import ( "fmt" "github.com/stretchr/testify/assert" "go-nkode/core" + "go-nkode/util" "io" "net/http" + "strings" "testing" ) @@ -23,12 +25,12 @@ func TestApi(t *testing.T) { var customerResp core.CreateNewCustomerResp testApiPost(t, base+core.CreateNewCustomer, newCustomerBody, &customerResp) - username := "test_username@example.com" + userEmail := "test_username" + util.GenerateRandomString(12) + "@example.com" signupInterfaceBody := core.GenerateSignupRestInterfacePost{ CustomerId: customerResp.CustomerId, AttrsPerKey: kp.AttrsPerKey, NumbOfKeys: kp.NumbOfKeys, - UserEmail: username, + UserEmail: strings.ToUpper(userEmail), // should be case-insensitive Reset: false, } var signupInterfaceResp core.GenerateSignupResetInterfaceResp @@ -59,7 +61,7 @@ func TestApi(t *testing.T) { loginInterfaceBody := core.GetLoginInterfacePost{ CustomerId: customerResp.CustomerId, - UserEmail: username, + UserEmail: userEmail, } var loginInterfaceResp core.GetLoginInterfaceResp @@ -70,16 +72,16 @@ func TestApi(t *testing.T) { assert.NoError(t, err) loginBody := core.LoginPost{ CustomerId: customerResp.CustomerId, - UserEmail: username, + UserEmail: userEmail, KeySelection: loginKeySelection, } var jwtTokens core.AuthenticationTokens testApiPost(t, base+core.Login, loginBody, &jwtTokens) refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken) - assert.Equal(t, refreshClaims.Subject, username) + assert.Equal(t, refreshClaims.Subject, userEmail) accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken) - assert.Equal(t, accessClaims.Subject, username) + assert.Equal(t, accessClaims.Subject, userEmail) renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} testApiPost(t, base+core.RenewAttributes, renewBody, nil) @@ -87,7 +89,7 @@ func TestApi(t *testing.T) { assert.NoError(t, err) loginBody = core.LoginPost{ CustomerId: customerResp.CustomerId, - UserEmail: username, + UserEmail: userEmail, KeySelection: loginKeySelection, } @@ -102,7 +104,7 @@ func TestApi(t *testing.T) { testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken) accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken) assert.NoError(t, err) - assert.Equal(t, accessClaims.Subject, username) + assert.Equal(t, accessClaims.Subject, userEmail) } func Unmarshal(t *testing.T, resp *http.Response, data any) { diff --git a/util/util.go b/util/util.go index 39dca6e..208fa43 100644 --- a/util/util.go +++ b/util/util.go @@ -255,7 +255,7 @@ func Choice[T any](items []T) T { // GenerateRandomString creates a random string of a specified length. func GenerateRandomString(length int) string { - charset := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + charset := []rune("abcdefghijklmnopqrstuvwxyz0123456789") b := make([]rune, length) for i := range b { b[i] = Choice[rune](charset)