implement sqlite write queue

This commit is contained in:
2024-10-10 15:01:45 -05:00
parent 3574d07997
commit 1e33a81a2c
11 changed files with 225 additions and 192 deletions

View File

@@ -51,6 +51,7 @@ func NewSESClient() SESClient {
} }
func (s *SESClient) SendEmail(email Email) error { func (s *SESClient) SendEmail(email Email) error {
if _, exists := s.ResetCache.Get(email.Recipient); exists { if _, exists := s.ResetCache.Get(email.Recipient); exists {
return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject) return fmt.Errorf("email already sent to %s with subject %s", email.Recipient, email.Subject)
} }
@@ -102,6 +103,7 @@ func (s *SESClient) SendEmail(email Email) error {
// EmailQueue represents the email queue with rate limiting // EmailQueue represents the email queue with rate limiting
type EmailQueue struct { type EmailQueue struct {
stop bool
emailQueue chan Email // Email queue emailQueue chan Email // Email queue
rateLimit <-chan time.Time // Rate limiter rateLimit <-chan time.Time // Rate limiter
client EmailClient // SES client to send emails client EmailClient // SES client to send emails
@@ -115,6 +117,7 @@ func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *Ema
rateLimit := time.Tick(time.Second / time.Duration(emailsPerSecond)) rateLimit := time.Tick(time.Second / time.Duration(emailsPerSecond))
return &EmailQueue{ return &EmailQueue{
stop: false,
emailQueue: make(chan Email, bufferSize), emailQueue: make(chan Email, bufferSize),
rateLimit: rateLimit, rateLimit: rateLimit,
client: client, client: client,
@@ -124,12 +127,17 @@ func NewEmailQueue(bufferSize int, emailsPerSecond int, client EmailClient) *Ema
// AddEmail queues a new email to be sent // AddEmail queues a new email to be sent
func (q *EmailQueue) AddEmail(email Email) { func (q *EmailQueue) AddEmail(email Email) {
if q.stop {
log.Printf("email %s with subject %s not add. Stopping queue", email.Recipient, email.Subject)
return
}
q.wg.Add(1) q.wg.Add(1)
q.emailQueue <- email q.emailQueue <- email
} }
// Start begins processing the email queue with rate limiting // Start begins processing the email queue with rate limiting
func (q *EmailQueue) Start() { func (q *EmailQueue) Start() {
q.stop = false
// Worker goroutine that processes emails from the queue // Worker goroutine that processes emails from the queue
go func() { go func() {
for email := range q.emailQueue { for email := range q.emailQueue {
@@ -150,6 +158,7 @@ func (q *EmailQueue) sendEmail(email Email) {
// Stop stops the queue after all emails have been processed // Stop stops the queue after all emails have been processed
func (q *EmailQueue) Stop() { func (q *EmailQueue) Stop() {
q.stop = true
// Wait for all emails to be processed // Wait for all emails to be processed
q.wg.Wait() q.wg.Wait()
// Close the email queue // Close the email queue

View File

@@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) {
} }
queue.AddEmail(email) queue.AddEmail(email)
} }
// Stop the queue after all emails are processed // CloseDb the queue after all emails are processed
queue.Stop() queue.Stop()
assert.Equal(t, queue.FailedSendCount, 0) assert.Equal(t, queue.FailedSendCount, 0)

View File

@@ -38,6 +38,13 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy NKodePolicy, id *CustomerId) (*
} }
func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId CustomerId, kp KeypadDimension, reset bool) (*GenerateSignupResetInterfaceResp, error) { func (n *NKodeAPI) GenerateSignupResetInterface(userEmail UserEmail, customerId CustomerId, kp KeypadDimension, reset bool) (*GenerateSignupResetInterfaceResp, error) {
user, err := n.Db.GetUser(userEmail, customerId)
if err != nil {
return nil, err
}
if user != nil && !reset {
return nil, fmt.Errorf("user %s already exists", string(userEmail))
}
svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -3,6 +3,7 @@ package core
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go-nkode/util" "go-nkode/util"
"os"
"testing" "testing"
) )
@@ -10,7 +11,7 @@ func TestNKodeAPI(t *testing.T) {
//db1 := NewInMemoryDb() //db1 := NewInMemoryDb()
//testNKodeAPI(t, &db1) //testNKodeAPI(t, &db1)
dbFile := "../test.db" dbFile := os.Getenv("TEST_DB")
db2 := NewSqliteDB(dbFile) db2 := NewSqliteDB(dbFile)
defer db2.CloseDb() defer db2.CloseDb()
@@ -101,5 +102,7 @@ func testNKodeAPI(t *testing.T, db DbAccessor) {
assert.NoError(t, err) assert.NoError(t, err)
_, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection) _, err = nkodeApi.Login(*customerId, userEmail, loginKeySelection)
assert.NoError(t, err) assert.NoError(t, err)
signupResponse, err = nkodeApi.GenerateSignupResetInterface(userEmail, *customerId, keypadSize, false)
assert.Error(t, err)
} }
} }

View File

@@ -24,7 +24,7 @@ type Root struct {
} }
func main() { func main() {
dbPaths := []string{"test.db", "nkode.db"} dbPaths := []string{"/Users/donov/databases/test.db", "/Users/donov/databases/nkode.db"}
outputStr := MakeSvgFiles() outputStr := MakeSvgFiles()
for _, path := range dbPaths { for _, path := range dbPaths {
MakeTables(path) MakeTables(path)
@@ -57,7 +57,7 @@ VALUES (?)
} }
func MakeSvgFiles() string { func MakeSvgFiles() string {
jsonFiles, err := GetAllFiles("./core//sqlite-init/json") jsonFiles, err := GetAllFiles("./core/sqlite-init/json")
if err != nil { if err != nil {
log.Fatalf("Error getting JSON files: %v", err) log.Fatalf("Error getting JSON files: %v", err)
} }
@@ -146,7 +146,7 @@ func MakeTables(dbPath string) {
PRAGMA journal_mode=WAL; PRAGMA journal_mode=WAL;
--PRAGMA busy_timeout = 5000; -- Wait up to 5 seconds --PRAGMA busy_timeout = 5000; -- Wait up to 5 seconds
--PRAGMA synchronous = NORMAL; -- Reduce sync frequency for less locking --PRAGMA synchronous = NORMAL; -- Reduce sync frequency for less locking
--PRAGMA cache_size = -16000; -- Increase cache size (16MB)PRAGMA foreign_keys = ON; --PRAGMA cache_size = -16000; -- Increase cache size (16MB)PRAGMA
CREATE TABLE IF NOT EXISTS customer ( CREATE TABLE IF NOT EXISTS customer (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,

View File

@@ -8,23 +8,50 @@ import (
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
"go-nkode/util" "go-nkode/util"
"log" "log"
"sync"
) )
type SqliteDB struct { type SqliteDB struct {
db *sql.DB db *sql.DB
stop bool
writeQueue chan WriteTx
wg sync.WaitGroup
} }
type WriteTx struct {
ErrChan chan error
Query string
Args []any
}
const (
writeBuffer = 1000
)
func NewSqliteDB(path string) *SqliteDB { func NewSqliteDB(path string) *SqliteDB {
db, err := sql.Open("sqlite3", path) db, err := sql.Open("sqlite3", path)
if err != nil { if err != nil {
log.Fatal("database didn't open ", err) log.Fatal("database didn't open ", err)
} }
sqldb := SqliteDB{db: db} sqldb := SqliteDB{
db: db,
stop: false,
writeQueue: make(chan WriteTx, writeBuffer),
}
go func() {
for writeTx := range sqldb.writeQueue {
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
sqldb.wg.Done()
}
}()
return &sqldb return &sqldb
} }
func (d *SqliteDB) CloseDb() { func (d *SqliteDB) CloseDb() {
d.stop = true
d.wg.Wait()
if err := d.db.Close(); err != nil { if err := d.db.Close(); err != nil {
// If db.Close() returns an error, panic // If db.Close() returns an error, panic
panic(fmt.Sprintf("Failed to close the database: %v", err)) panic(fmt.Sprintf("Failed to close the database: %v", err))
@@ -32,48 +59,16 @@ func (d *SqliteDB) CloseDb() {
} }
func (d *SqliteDB) WriteNewCustomer(c Customer) error { func (d *SqliteDB) WriteNewCustomer(c Customer) error {
tx, err := d.db.Begin() query := `
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new customer won't roll back %+v", err))
}
}
}()
insertCustomer := `
INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values) INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values)
VALUES (?,?,?,?,?,?,?,?,?) VALUES (?,?,?,?,?,?,?,?,?)
` `
_, err = tx.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()) args := []any{uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()}
if err != nil { return d.addWriteTx(query, args)
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
} }
func (d *SqliteDB) WriteNewUser(u User) error { func (d *SqliteDB) WriteNewUser(u User) error {
query := `
tx, err := d.db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
}
}
}()
insertUser := `
INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface) INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
` `
@@ -83,32 +78,14 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
} else { } else {
renew = 0 renew = 0
} }
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
if err != nil { args := []any{uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)}
return err
} return d.addWriteTx(query, args)
err = tx.Commit()
if err != nil {
return err
}
return nil
} }
func (d *SqliteDB) UpdateUserNKode(u User) error { func (d *SqliteDB) UpdateUserNKode(u User) error {
tx, err := d.db.Begin() query := `
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
}
}
}()
updateUser := `
UPDATE user UPDATE user
SET renew = ?, refresh_token = ?, code = ?, mask = ?, attributes_per_key = ?, number_of_keys = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ?, max_nkode_len = ?, idx_interface = ?, svg_id_interface = ? SET renew = ?, refresh_token = ?, code = ?, mask = ?, attributes_per_key = ?, number_of_keys = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ?, max_nkode_len = ?, idx_interface = ?, svg_id_interface = ?
WHERE username = ? AND customer_id = ? WHERE username = ? AND customer_id = ?
@@ -119,18 +96,103 @@ WHERE username = ? AND customer_id = ?
} else { } else {
renew = 0 renew = 0
} }
_, err = tx.Exec(updateUser, renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)) args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
query := `
UPDATE user SET idx_interface = ? WHERE id = ?
`
args := []any{util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
query := `
UPDATE user SET refresh_token = ? WHERE id = ?
`
args := []any{refreshToken, uuid.UUID(id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) Renew(id CustomerId) error {
// TODO: How long does a renew take?
customer, err := d.GetCustomer(id)
if err != nil { if err != nil {
return err return err
} }
setXor, attrXor := customer.RenewKeys()
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx
renewQuery := `
UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?;
`
userQuery := `
SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?
`
tx, err := d.db.Begin()
if err != nil {
return err
}
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
for rows.Next() {
var userId string
var alphaBytes []byte
var setBytes []byte
var attrsPerKey int
var numbOfKeys int
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
if err != nil {
return err
}
user := User{
Id: UserId{},
CustomerId: CustomerId{},
Email: "",
EncipheredPasscode: EncipheredNKode{},
Kp: KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: UserCipherKeys{
AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
SetKey: util.ByteArrToUint64Arr(setBytes),
},
Interface: UserInterface{},
Renew: false,
}
err = user.RenewKeys(setXor, attrXor)
if err != nil {
return err
}
renewQuery += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;"
renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
}
renewQuery += `
`
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return err return err
} }
return nil return d.addWriteTx(renewQuery, renewArgs)
} }
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil {
return err
}
query := `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
`
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) { func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
tx, err := d.db.Begin() tx, err := d.db.Begin()
if err != nil { if err != nil {
@@ -191,7 +253,6 @@ func (d *SqliteDB) GetUser(username UserEmail, customerId CustomerId) (*User, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Commit()
userSelect := ` userSelect := `
SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
WHERE user.username = ? AND user.customer_id = ? WHERE user.username = ? AND user.customer_id = ?
@@ -263,115 +324,13 @@ WHERE user.username = ? AND user.customer_id = ?
} }
user.Interface.Kp = &user.Kp user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp
err = tx.Commit()
if err != nil {
return nil, err
}
return &user, nil return &user, nil
} }
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
updateUserInterface := `
UPDATE user SET idx_interface = ? WHERE id = ?
`
_, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String())
return err
}
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
updateUserRefreshToken := `
UPDATE user SET refresh_token = ? WHERE id = ?
`
_, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String())
return err
}
func (d *SqliteDB) Renew(id CustomerId) error {
customer, err := d.GetCustomer(id)
if err != nil {
return err
}
setXor, attrXor := customer.RenewKeys()
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx
renewExec := `
BEGIN TRANSACTION;
UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?;
`
userQuery := `
SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?
`
tx, err := d.db.Begin()
if err != nil {
return err
}
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
for rows.Next() {
var userId string
var alphaBytes []byte
var setBytes []byte
var attrsPerKey int
var numbOfKeys int
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
if err != nil {
return err
}
user := User{
Id: UserId{},
CustomerId: CustomerId{},
Email: "",
EncipheredPasscode: EncipheredNKode{},
Kp: KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: UserCipherKeys{
AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
SetKey: util.ByteArrToUint64Arr(setBytes),
},
Interface: UserInterface{},
Renew: false,
}
err = user.RenewKeys(setXor, attrXor)
if err != nil {
return err
}
renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;"
renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
}
renewExec += `
COMMIT;
`
err = tx.Commit()
if err != nil {
return err
}
tx, err = d.db.Begin()
if err != nil {
return err
}
_, err = d.db.Exec(renewExec, renewArgs...)
err = tx.Commit()
if err != nil {
return err
}
return err
}
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil {
return err
}
updateUser := `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
`
_, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
return err
}
func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) { func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) {
ids, err := d.getRandomIds(kp.TotalAttrs()) ids, err := d.getRandomIds(kp.TotalAttrs())
if err != nil { if err != nil {
@@ -393,7 +352,6 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Commit()
selectId := "SELECT svg FROM svg_icon where id = ?" selectId := "SELECT svg FROM svg_icon where id = ?"
svgs := make([]string, len(ids)) svgs := make([]string, len(ids))
for idx, id := range ids { for idx, id := range ids {
@@ -409,15 +367,57 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
return nil, err return nil, err
} }
} }
err = tx.Commit()
if err != nil {
return nil, err
}
return svgs, nil return svgs, nil
} }
func (d *SqliteDB) writeToDb(query string, args []any) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write won't roll back %+v", err))
}
}
}()
_, err = tx.Exec(query, args...)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (d *SqliteDB) addWriteTx(query string, args []any) error {
if d.stop {
return errors.New("stopping database")
}
errChan := make(chan error)
writeTx := WriteTx{
Query: query,
Args: args,
ErrChan: errChan,
}
d.wg.Add(1)
d.writeQueue <- writeTx
return <-errChan
}
func (d *SqliteDB) getRandomIds(count int) ([]int, error) { func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
tx, err := d.db.Begin() tx, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Commit()
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
if err != nil { if err != nil {
return nil, err return nil, err
@@ -426,15 +426,20 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
if !rows.Next() { if !rows.Next() {
return nil, errors.New("empty svg_icon table") return nil, errors.New("empty svg_icon table")
} }
err = rows.Scan(&tableLen)
if err = rows.Scan(&tableLen); err != nil {
return nil, err
}
perm, err := util.RandomPermutation(tableLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
perm, err := util.RandomPermutation(tableLen)
for idx := range perm { for idx := range perm {
perm[idx] += 1 perm[idx] += 1
} }
if err != nil {
if err = tx.Commit(); err != nil {
return nil, err return nil, err
} }
return perm[:count], nil return perm[:count], nil

View File

@@ -2,11 +2,12 @@ package core
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"os"
"testing" "testing"
) )
func TestNewSqliteDB(t *testing.T) { func TestNewSqliteDB(t *testing.T) {
dbFile := "../test.db" dbFile := os.Getenv("TEST_DB")
// sql_driver.MakeTables(dbFile) // sql_driver.MakeTables(dbFile)
db := NewSqliteDB(dbFile) db := NewSqliteDB(dbFile)
defer db.CloseDb() defer db.CloseDb()

View File

@@ -3,6 +3,7 @@ package core
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"net/mail" "net/mail"
"strings"
) )
type SetNKodeResp struct { type SetNKodeResp struct {
@@ -101,7 +102,7 @@ func ParseEmail(email string) (UserEmail, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return UserEmail(email), err return UserEmail(strings.ToLower(email)), err
} }

View File

@@ -6,15 +6,20 @@ import (
"go-nkode/core" "go-nkode/core"
"log" "log"
"net/http" "net/http"
"os"
) )
const ( const (
emailQueueBufferSize = 100 emailQueueBufferSize = 100
maxEmailsPerSecond = 13 // SES allows 14 but I don't want to push it maxEmailsPerSecond = 13 // SES allows 14, but I don't want to push it
) )
func main() { func main() {
db := core.NewSqliteDB("nkode.db") dbPath := os.Getenv("SQLITE_DB")
if dbPath == "" {
log.Fatalf("SQLITE_DB=/path/to/nkode.db not set")
}
db := core.NewSqliteDB(dbPath)
defer db.CloseDb() defer db.CloseDb()
sesClient := core.NewSESClient() sesClient := core.NewSESClient()
emailQueue := core.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) emailQueue := core.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient)

View File

@@ -6,8 +6,10 @@ import (
"fmt" "fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go-nkode/core" "go-nkode/core"
"go-nkode/util"
"io" "io"
"net/http" "net/http"
"strings"
"testing" "testing"
) )
@@ -23,12 +25,12 @@ func TestApi(t *testing.T) {
var customerResp core.CreateNewCustomerResp var customerResp core.CreateNewCustomerResp
testApiPost(t, base+core.CreateNewCustomer, newCustomerBody, &customerResp) testApiPost(t, base+core.CreateNewCustomer, newCustomerBody, &customerResp)
username := "test_username@example.com" userEmail := "test_username" + util.GenerateRandomString(12) + "@example.com"
signupInterfaceBody := core.GenerateSignupRestInterfacePost{ signupInterfaceBody := core.GenerateSignupRestInterfacePost{
CustomerId: customerResp.CustomerId, CustomerId: customerResp.CustomerId,
AttrsPerKey: kp.AttrsPerKey, AttrsPerKey: kp.AttrsPerKey,
NumbOfKeys: kp.NumbOfKeys, NumbOfKeys: kp.NumbOfKeys,
UserEmail: username, UserEmail: strings.ToUpper(userEmail), // should be case-insensitive
Reset: false, Reset: false,
} }
var signupInterfaceResp core.GenerateSignupResetInterfaceResp var signupInterfaceResp core.GenerateSignupResetInterfaceResp
@@ -59,7 +61,7 @@ func TestApi(t *testing.T) {
loginInterfaceBody := core.GetLoginInterfacePost{ loginInterfaceBody := core.GetLoginInterfacePost{
CustomerId: customerResp.CustomerId, CustomerId: customerResp.CustomerId,
UserEmail: username, UserEmail: userEmail,
} }
var loginInterfaceResp core.GetLoginInterfaceResp var loginInterfaceResp core.GetLoginInterfaceResp
@@ -70,16 +72,16 @@ func TestApi(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
loginBody := core.LoginPost{ loginBody := core.LoginPost{
CustomerId: customerResp.CustomerId, CustomerId: customerResp.CustomerId,
UserEmail: username, UserEmail: userEmail,
KeySelection: loginKeySelection, KeySelection: loginKeySelection,
} }
var jwtTokens core.AuthenticationTokens var jwtTokens core.AuthenticationTokens
testApiPost(t, base+core.Login, loginBody, &jwtTokens) testApiPost(t, base+core.Login, loginBody, &jwtTokens)
refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken) refreshClaims, err := core.ParseRefreshToken(jwtTokens.RefreshToken)
assert.Equal(t, refreshClaims.Subject, username) assert.Equal(t, refreshClaims.Subject, userEmail)
accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken) accessClaims, err := core.ParseRefreshToken(jwtTokens.AccessToken)
assert.Equal(t, accessClaims.Subject, username) assert.Equal(t, accessClaims.Subject, userEmail)
renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId} renewBody := core.RenewAttributesPost{CustomerId: customerResp.CustomerId}
testApiPost(t, base+core.RenewAttributes, renewBody, nil) testApiPost(t, base+core.RenewAttributes, renewBody, nil)
@@ -87,7 +89,7 @@ func TestApi(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
loginBody = core.LoginPost{ loginBody = core.LoginPost{
CustomerId: customerResp.CustomerId, CustomerId: customerResp.CustomerId,
UserEmail: username, UserEmail: userEmail,
KeySelection: loginKeySelection, KeySelection: loginKeySelection,
} }
@@ -102,7 +104,7 @@ func TestApi(t *testing.T) {
testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken) testApiGet(t, base+core.RefreshToken, &refreshTokenResp, jwtTokens.RefreshToken)
accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken) accessClaims, err = core.ParseAccessToken(refreshTokenResp.AccessToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, accessClaims.Subject, username) assert.Equal(t, accessClaims.Subject, userEmail)
} }
func Unmarshal(t *testing.T, resp *http.Response, data any) { func Unmarshal(t *testing.T, resp *http.Response, data any) {

View File

@@ -255,7 +255,7 @@ func Choice[T any](items []T) T {
// GenerateRandomString creates a random string of a specified length. // GenerateRandomString creates a random string of a specified length.
func GenerateRandomString(length int) string { func GenerateRandomString(length int) string {
charset := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") charset := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
b := make([]rune, length) b := make([]rune, length)
for i := range b { for i := range b {
b[i] = Choice[rune](charset) b[i] = Choice[rune](charset)