refactor sqlite db to support sqlc

This commit is contained in:
2024-12-04 10:22:55 -06:00
parent 69ec9bd08c
commit bf58779227
12 changed files with 342 additions and 388 deletions

View File

@@ -43,15 +43,18 @@ func main() {
if dbPath == "" { if dbPath == "" {
log.Fatalf("SQLITE_DB=/path/to/nkode.db not set") log.Fatalf("SQLITE_DB=/path/to/nkode.db not set")
} }
db := db.NewSqliteDB(dbPath) sqlitedb, err := db.NewSqliteDB(dbPath)
defer db.CloseDb() if err != nil {
fmt.Errorf("%v", err)
}
defer sqlitedb.Close()
sesClient := email.NewSESClient() sesClient := email.NewSESClient()
emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient)
emailQueue.Start() emailQueue.Start()
defer emailQueue.Stop() defer emailQueue.Stop()
nkodeApi := api.NewNKodeAPI(db, emailQueue) nkodeApi := api.NewNKodeAPI(sqlitedb, emailQueue)
AddDefaultCustomer(nkodeApi) AddDefaultCustomer(nkodeApi)
handler := api.NKodeHandler{Api: nkodeApi} handler := api.NKodeHandler{Api: nkodeApi}

View File

@@ -42,7 +42,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models.
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = n.Db.WriteNewCustomer(*newCustomer) err = n.Db.CreateCustomer(*newCustomer)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -17,8 +17,9 @@ func TestNKodeAPI(t *testing.T) {
dbFile := os.Getenv("TEST_DB") dbFile := os.Getenv("TEST_DB")
db2 := db.NewSqliteDB(dbFile) db2, err := db.NewSqliteDB(dbFile)
defer db2.CloseDb() assert.NoError(t, err)
defer db2.Close()
testNKodeAPI(t, db2) testNKodeAPI(t, db2)
//if _, err := os.Stat(dbFile); err == nil { //if _, err := os.Stat(dbFile); err == nil {

View File

@@ -8,7 +8,7 @@ import (
type CustomerUserRepository interface { type CustomerUserRepository interface {
GetCustomer(models.CustomerId) (*entities.Customer, error) GetCustomer(models.CustomerId) (*entities.Customer, error)
GetUser(models.UserEmail, models.CustomerId) (*entities.User, error) GetUser(models.UserEmail, models.CustomerId) (*entities.User, error)
WriteNewCustomer(entities.Customer) error CreateCustomer(entities.Customer) error
WriteNewUser(entities.User) error WriteNewUser(entities.User) error
UpdateUserNKode(entities.User) error UpdateUserNKode(entities.User) error
UpdateUserInterface(models.UserId, entities.UserInterface) error UpdateUserInterface(models.UserId, entities.UserInterface) error

View File

@@ -42,7 +42,7 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo
return &user, nil return &user, nil
} }
func (db *InMemoryDb) WriteNewCustomer(customer entities.Customer) error { func (db *InMemoryDb) CreateCustomer(customer entities.Customer) error {
_, exists := db.Customers[customer.Id] _, exists := db.Customers[customer.Id]
if exists { if exists {

View File

@@ -1,7 +1,9 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
@@ -9,432 +11,387 @@ import (
"go-nkode/internal/entities" "go-nkode/internal/entities"
"go-nkode/internal/models" "go-nkode/internal/models"
"go-nkode/internal/security" "go-nkode/internal/security"
"go-nkode/internal/sqlc"
"go-nkode/internal/utils"
"log" "log"
"sync" "sync"
"time"
) )
type SqliteDB struct { const writeBufferSize = 100
db *sql.DB
stop bool
writeQueue chan WriteTx
wg sync.WaitGroup
}
type sqlcGeneric func(*sqlc.Queries, context.Context, any) error
// WriteTx represents a write transaction
type WriteTx struct { type WriteTx struct {
ErrChan chan error ErrChan chan error
Query string Query sqlcGeneric
Args []any Args interface{}
} }
const ( // SqliteDB represents the SQLite database connection and write queue
writeBuffer = 1000 type SqliteDB struct {
) queries *sqlc.Queries
db *sql.DB
writeQueue chan WriteTx
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
// NewSqliteDB initializes a new SqliteDB instance
func NewSqliteDB(path string) (*SqliteDB, error) {
if path == "" {
return nil, errors.New("database path is required")
}
func NewSqliteDB(path string) *SqliteDB {
db, err := sql.Open("sqlite3", path) db, err := sql.Open("sqlite3", path)
if err != nil { if err != nil {
log.Fatal("database didn't open ", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
sqldb := SqliteDB{
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
sqldb := &SqliteDB{
queries: sqlc.New(db),
db: db, db: db,
stop: false, writeQueue: make(chan WriteTx, writeBufferSize),
writeQueue: make(chan WriteTx, writeBuffer), ctx: ctx,
cancel: cancel,
} }
go func() { sqldb.wg.Add(1)
for writeTx := range sqldb.writeQueue { go sqldb.processWriteQueue()
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
sqldb.wg.Done() return sqldb, nil
}
// processWriteQueue handles write transactions from the queue
func (d *SqliteDB) processWriteQueue() {
defer d.wg.Done()
for {
select {
case <-d.ctx.Done():
return
case writeTx := <-d.writeQueue:
err := writeTx.Query(d.queries, d.ctx, writeTx.Args)
writeTx.ErrChan <- err
} }
}() }
return &sqldb
} }
func (d *SqliteDB) CloseDb() { func (d *SqliteDB) Close() error {
d.stop = true d.cancel()
d.wg.Wait() d.wg.Wait()
if err := d.db.Close(); err != nil { close(d.writeQueue)
// If db.Close() returns an error, panic return d.db.Close()
panic(fmt.Sprintf("Failed to close the database: %v", err))
}
} }
func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error { func (d *SqliteDB) CreateCustomer(c entities.Customer) error {
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
INSERT INTO customer ( params, ok := args.(sqlc.CreateCustomerParams)
id if !ok {
,max_nkode_len return fmt.Errorf("invalid argument type: expected CreateCustomerParams")
,min_nkode_len }
,distinct_sets return q.CreateCustomer(ctx, params)
,distinct_attributes
,lock_out
,expiration
,attribute_values
,set_values
,last_renew
,created_at
)
VALUES (?,?,?,?,?,?,?,?,?,?,?)
`
args := []any{
uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets,
c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration,
c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(),
} }
return d.addWriteTx(query, args)
return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams())
} }
func (d *SqliteDB) WriteNewUser(u entities.User) error { func (d *SqliteDB) WriteNewUser(u entities.User) error {
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
INSERT INTO user ( params, ok := args.(sqlc.CreateUserParams)
id if !ok {
,email return fmt.Errorf("invalid argument type: expected CreateUserParams")
,renew }
,refresh_token return q.CreateUser(ctx, params)
,customer_id }
,code // Use the wrapped function in enqueueWriteTx
,mask
,attributes_per_key renew := 0
,number_of_keys
,alpha_key
,set_key
,pass_key
,mask_key
,salt
,max_nkode_len
,idx_interface
,svg_id_interface
,created_at
)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
var renew int
if u.Renew { if u.Renew {
renew = 1 renew = 1
} else {
renew = 0
} }
// Map entities.User to CreateUserParams
args := []any{ params := sqlc.CreateUserParams{
uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), ID: uuid.UUID(u.Id).String(),
u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, Email: string(u.Email),
security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), Renew: int64(renew),
security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), CustomerID: uuid.UUID(u.CustomerId).String(),
security.IntArrToByteArr(u.Interface.SvgId), timeStamp(), Code: u.EncipheredPasscode.Code,
Mask: u.EncipheredPasscode.Mask,
AttributesPerKey: int64(u.Kp.AttrsPerKey),
NumberOfKeys: int64(u.Kp.NumbOfKeys),
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
Salt: u.CipherKeys.Salt,
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true},
} }
return d.enqueueWriteTx(queryFunc, params)
return d.addWriteTx(query, args)
} }
func (d *SqliteDB) UpdateUserNKode(u entities.User) error { func (d *SqliteDB) UpdateUserNKode(u entities.User) error {
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
UPDATE user params, ok := args.(sqlc.UpdateUserParams)
SET renew = ? if !ok {
,refresh_token = ? return fmt.Errorf("invalid argument type: expected UpdateUserParams")
,code = ? }
,mask = ? return q.UpdateUser(ctx, params)
,attributes_per_key = ? }
,number_of_keys = ? // Use the wrapped function in enqueueWriteTx
,alpha_key = ? renew := 0
,set_key = ?
,pass_key = ?
,mask_key = ?
,salt = ?
,max_nkode_len = ?
,idx_interface = ?
,svg_id_interface = ?
WHERE email = ? AND customer_id = ?
`
var renew int
if u.Renew { if u.Renew {
renew = 1 renew = 1
} else {
renew = 0
} }
args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)} params := sqlc.UpdateUserParams{
Email: string(u.Email),
return d.addWriteTx(query, args) Renew: int64(renew),
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
CustomerID: uuid.UUID(u.CustomerId).String(),
Code: u.EncipheredPasscode.Code,
Mask: u.EncipheredPasscode.Mask,
AttributesPerKey: int64(u.Kp.AttrsPerKey),
NumberOfKeys: int64(u.Kp.NumbOfKeys),
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
Salt: u.CipherKeys.Salt,
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
}
return d.enqueueWriteTx(queryFunc, params)
} }
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error {
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ? params, ok := args.(sqlc.UpdateUserInterfaceParams)
` if !ok {
args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()} return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams")
}
return q.UpdateUserInterface(ctx, params)
}
params := sqlc.UpdateUserInterfaceParams{
IdxInterface: security.IntArrToByteArr(ui.IdxInterface),
LastLogin: utils.TimeStamp(),
ID: uuid.UUID(id).String(),
}
return d.addWriteTx(query, args) return d.enqueueWriteTx(queryFunc, params)
} }
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
UPDATE user SET refresh_token = ? WHERE id = ? params, ok := args.(sqlc.UpdateUserRefreshTokenParams)
` if !ok {
args := []any{refreshToken, uuid.UUID(id).String()} return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken")
}
return q.UpdateUserRefreshToken(ctx, params)
}
params := sqlc.UpdateUserRefreshTokenParams{
RefreshToken: sql.NullString{
String: refreshToken,
Valid: true,
},
ID: uuid.UUID(id).String(),
}
return d.enqueueWriteTx(queryFunc, params)
}
return d.addWriteTx(query, args) func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RenewCustomerParams)
if !ok {
}
return q.RenewCustomer(ctx, params)
}
return d.enqueueWriteTx(queryFunc, renewParams)
} }
func (d *SqliteDB) Renew(id models.CustomerId) error { func (d *SqliteDB) Renew(id models.CustomerId) error {
// TODO: How long does a renew take? setXor, attrXor, err := d.renewCustomer(id)
customer, err := d.GetCustomer(id)
if err != nil { if err != nil {
return err return err
} }
setXor, attrXor, err := customer.RenewKeys() customerId := models.CustomerIdToString(id)
userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId)
if err != nil { if err != nil {
return err return err
} }
renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx
renewQuery := `
UPDATE customer
SET attribute_values = ?, set_values = ?
WHERE id = ?;
`
userQuery := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
SELECT params, ok := args.(sqlc.RenewUserParams)
id if !ok {
,alpha_key return fmt.Errorf("invalid argument type: expected RenewUserParams")
,set_key
,attributes_per_key
,number_of_keys
FROM user
WHERE customer_id = ?
`
tx, err := d.db.Begin()
if err != nil {
return err
}
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
for rows.Next() {
var userId string
var alphaBytes []byte
var setBytes []byte
var attrsPerKey int
var numbOfKeys int
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
if err != nil {
return err
} }
return q.RenewUser(ctx, params)
}
for _, row := range userRenewRows {
user := entities.User{ user := entities.User{
Id: models.UserId{}, Id: models.UserIdFromString(row.ID),
CustomerId: models.CustomerId{}, CustomerId: models.CustomerId{},
Email: "", Email: "",
EncipheredPasscode: models.EncipheredNKode{}, EncipheredPasscode: models.EncipheredNKode{},
Kp: entities.KeypadDimension{ Kp: entities.KeypadDimension{
AttrsPerKey: attrsPerKey, AttrsPerKey: int(row.AttributesPerKey),
NumbOfKeys: numbOfKeys, NumbOfKeys: int(row.NumberOfKeys),
}, },
CipherKeys: entities.UserCipherKeys{ CipherKeys: entities.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(alphaBytes), AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey),
SetKey: security.ByteArrToUint64Arr(setBytes), SetKey: security.ByteArrToUint64Arr(row.SetKey),
}, },
Interface: entities.UserInterface{}, Interface: entities.UserInterface{},
Renew: false, Renew: false,
} }
err = user.RenewKeys(setXor, attrXor)
if err != nil { if err = user.RenewKeys(setXor, attrXor); err != nil {
return err
}
params := sqlc.RenewUserParams{
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
Renew: 1,
ID: uuid.UUID(user.Id).String(),
}
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
return err return err
} }
renewQuery += `
UPDATE user
SET alpha_key = ?, set_key = ?, renew = ?
WHERE id = ?;
`
renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
} }
renewQuery += ` return nil
` }
err = tx.Commit()
func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) {
customer, err := d.GetCustomer(id)
if err != nil { if err != nil {
return err return nil, nil, err
} }
return d.addWriteTx(renewQuery, renewArgs) setXor, attrXor, err := customer.RenewKeys()
if err != nil {
return nil, nil, err
}
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RenewCustomerParams)
if !ok {
return fmt.Errorf("invalid argument type: expected RenewCustomerParams")
}
return q.RenewCustomer(ctx, params)
}
params := sqlc.RenewCustomerParams{
AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals),
SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals),
ID: uuid.UUID(customer.Id).String(),
}
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
return nil, nil, err
}
return setXor, attrXor, nil
} }
func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr) if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil {
if err != nil {
return err return err
} }
query := ` queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
UPDATE user params, ok := args.(sqlc.RefreshUserPasscodeParams)
SET if !ok {
renew = ? return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams")
,code = ?
,mask = ?
,alpha_key = ?
,set_key = ?
,pass_key = ?
,mask_key = ?
,salt = ?
WHERE id = ?;
`
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
}
} }
}() return q.RefreshUserPasscode(ctx, params)
selectCustomer := ` }
SELECT params := sqlc.RefreshUserPasscodeParams{
max_nkode_len Renew: 0,
,min_nkode_len Code: user.EncipheredPasscode.Code,
,distinct_sets Mask: user.EncipheredPasscode.Mask,
,distinct_attributes AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
,lock_out SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
,expiration PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey),
,attribute_values MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey),
,set_values Salt: user.CipherKeys.Salt,
FROM customer ID: uuid.UUID(user.Id).String(),
WHERE id = ? }
` return d.enqueueWriteTx(queryFunc, params)
rows, err := tx.Query(selectCustomer, uuid.UUID(id)) }
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !rows.Next() { return &entities.Customer{
log.Printf("no new row for customer %s with err %s", id, rows.Err())
return nil, config.ErrCustomerDne
}
var maxNKodeLen int
var minNKodeLen int
var distinctSets int
var distinctAttributes int
var lockOut int
var expiration int
var attributeValues []byte
var setValues []byte
err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues)
if err != nil {
return nil, err
}
customer := entities.Customer{
Id: id, Id: id,
NKodePolicy: models.NKodePolicy{ NKodePolicy: models.NKodePolicy{
MaxNkodeLen: maxNKodeLen, MaxNkodeLen: int(customer.MaxNkodeLen),
MinNkodeLen: minNKodeLen, MinNkodeLen: int(customer.MinNkodeLen),
DistinctSets: distinctSets, DistinctSets: int(customer.DistinctSets),
DistinctAttributes: distinctAttributes, DistinctAttributes: int(customer.DistinctAttributes),
LockOut: lockOut, LockOut: int(customer.LockOut),
Expiration: expiration, Expiration: int(customer.Expiration),
}, },
Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues), Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues),
} }, nil
if err = tx.Commit(); err != nil {
return nil, err
}
return &customer, nil
} }
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
tx, err := d.db.Begin() userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{
Email: string(email),
CustomerID: uuid.UUID(customerId).String(),
})
if err != nil { if err != nil {
return nil, err if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to get user: %w", err)
} }
userSelect := `
SELECT
id
,renew
,refresh_token
,code
,mask
,attributes_per_key
,number_of_keys
,alpha_key
,set_key
,pass_key
,mask_key
,salt
,max_nkode_len
,idx_interface
,svg_id_interface
FROM user
WHERE user.email = ? AND user.customer_id = ?
`
rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String())
if !rows.Next() {
return nil, nil
}
var (
id string
renewVal int
refreshToken string
code string
mask string
attrsPerKey int
numbOfKeys int
alphaKey []byte
setKey []byte
passKey []byte
maskKey []byte
salt []byte
maxNKodeLen int
idxInterface []byte
svgIdInterface []byte
)
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
userId, err := uuid.Parse(id) kp := entities.KeypadDimension{
if err != nil { AttrsPerKey: int(userRow.AttributesPerKey),
return nil, err NumbOfKeys: int(userRow.NumberOfKeys),
} }
var renew bool
if renewVal == 0 { renew := false
renew = false if userRow.Renew == 1 {
} else {
renew = true renew = true
} }
user := entities.User{ user := entities.User{
Id: models.UserId(userId), Id: models.UserIdFromString(userRow.ID),
CustomerId: customerId, CustomerId: customerId,
Email: email, Email: email,
EncipheredPasscode: models.EncipheredNKode{ EncipheredPasscode: models.EncipheredNKode{
Code: code, Code: userRow.Code,
Mask: mask, Mask: userRow.Mask,
},
Kp: entities.KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
}, },
Kp: kp,
CipherKeys: entities.UserCipherKeys{ CipherKeys: entities.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(alphaKey), AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey),
SetKey: security.ByteArrToUint64Arr(setKey), SetKey: security.ByteArrToUint64Arr(userRow.SetKey),
PassKey: security.ByteArrToUint64Arr(passKey), PassKey: security.ByteArrToUint64Arr(userRow.PassKey),
MaskKey: security.ByteArrToUint64Arr(maskKey), MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey),
Salt: salt, Salt: userRow.Salt,
MaxNKodeLen: maxNKodeLen, MaxNKodeLen: int(userRow.MaxNkodeLen),
Kp: nil, Kp: &kp,
}, },
Interface: entities.UserInterface{ Interface: entities.UserInterface{
IdxInterface: security.ByteArrToIntArr(idxInterface), IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface),
SvgId: security.ByteArrToIntArr(svgIdInterface), SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface),
Kp: nil, Kp: &kp,
}, },
Renew: renew, Renew: renew,
RefreshToken: refreshToken, RefreshToken: userRow.RefreshToken.String,
}
user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp
if err = tx.Commit(); err != nil {
return nil, err
} }
return &user, nil return &user, nil
} }
@@ -456,68 +413,30 @@ func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string,
} }
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
selectId := `
SELECT svg
FROM svg_icon
WHERE id = ?
`
svgs := make([]string, len(ids)) svgs := make([]string, len(ids))
for idx, id := range ids { for idx, id := range ids {
rows, err := tx.Query(selectId, id) svg, err := d.queries.GetSvgId(d.ctx, int64(id))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !rows.Next() { svgs[idx] = svg
log.Printf("id not found: %d", id)
return nil, config.ErrSvgDne
}
if err = rows.Scan(&svgs[idx]); err != nil {
return nil, err
}
}
if err = tx.Commit(); err != nil {
return nil, err
} }
return svgs, nil return svgs, nil
} }
func (d *SqliteDB) writeToDb(query string, args []any) error { func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error {
tx, err := d.db.Begin() select {
if err != nil { case <-d.ctx.Done():
return err return errors.New("database is shutting down")
default:
} }
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatalf("fatal error: write won't roll back %+v", err)
}
}
}()
if _, err = tx.Exec(query, args...); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
func (d *SqliteDB) addWriteTx(query string, args []any) error { errChan := make(chan error, 1)
if d.stop {
return config.ErrStoppingDatabase
}
errChan := make(chan error)
writeTx := WriteTx{ writeTx := WriteTx{
Query: query, Query: queryFunc,
Args: args, Args: args,
ErrChan: errChan, ErrChan: errChan,
} }
d.wg.Add(1)
d.writeQueue <- writeTx d.writeQueue <- writeTx
return <-errChan return <-errChan
} }
@@ -559,7 +478,3 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
return perm[:count], nil return perm[:count], nil
} }
func timeStamp() string {
return time.Now().Format(time.RFC3339)
}

View File

@@ -11,8 +11,9 @@ import (
func TestNewSqliteDB(t *testing.T) { func TestNewSqliteDB(t *testing.T) {
dbFile := os.Getenv("TEST_DB") dbFile := os.Getenv("TEST_DB")
// sql_driver.MakeTables(dbFile) // sql_driver.MakeTables(dbFile)
db := NewSqliteDB(dbFile) db, err := NewSqliteDB(dbFile)
defer db.CloseDb() assert.NoError(t, err)
defer db.Close()
testSignupLoginRenew(t, db) testSignupLoginRenew(t, db)
testSqliteDBRandomSvgInterface(t, db) testSqliteDBRandomSvgInterface(t, db)
@@ -28,7 +29,7 @@ func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) {
nkodePolicy := models.NewDefaultNKodePolicy() nkodePolicy := models.NewDefaultNKodePolicy()
customerOrig, err := entities.NewCustomer(nkodePolicy) customerOrig, err := entities.NewCustomer(nkodePolicy)
assert.NoError(t, err) assert.NoError(t, err)
err = db.WriteNewCustomer(*customerOrig) err = db.CreateCustomer(*customerOrig)
assert.NoError(t, err) assert.NoError(t, err)
customer, err := db.GetCustomer(customerOrig.Id) customer, err := db.GetCustomer(customerOrig.Id)
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) {
} }
queue.AddEmail(email) queue.AddEmail(email)
} }
// CloseDb the queue after all emails are processed // Close the queue after all emails are processed
queue.Stop() queue.Stop()
assert.Equal(t, queue.FailedSendCount, 0) assert.Equal(t, queue.FailedSendCount, 0)

View File

@@ -5,6 +5,7 @@ import (
"go-nkode/config" "go-nkode/config"
"go-nkode/internal/models" "go-nkode/internal/models"
"go-nkode/internal/security" "go-nkode/internal/security"
"go-nkode/internal/sqlc"
"go-nkode/internal/utils" "go-nkode/internal/utils"
) )
@@ -83,3 +84,19 @@ func (c *Customer) RenewKeys() ([]uint64, []uint64, error) {
} }
return setXor, attrsXor, nil return setXor, attrsXor, nil
} }
func (c *Customer) ToSqlcCreateCustomerParams() sqlc.CreateCustomerParams {
return sqlc.CreateCustomerParams{
ID: uuid.UUID(c.Id).String(),
MaxNkodeLen: int64(c.NKodePolicy.MaxNkodeLen),
MinNkodeLen: int64(c.NKodePolicy.MinNkodeLen),
DistinctSets: int64(c.NKodePolicy.DistinctSets),
DistinctAttributes: int64(c.NKodePolicy.DistinctAttributes),
LockOut: int64(c.NKodePolicy.LockOut),
Expiration: int64(c.NKodePolicy.Expiration),
AttributeValues: c.Attributes.AttrBytes(),
SetValues: c.Attributes.SetBytes(),
LastRenew: utils.TimeStamp(),
CreatedAt: utils.TimeStamp(),
}
}

View File

@@ -37,11 +37,13 @@ func (u *User) RenewKeys(setXor []uint64, attrXor []uint64) error {
func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error { func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error {
setVals, err := customerAttributes.SetValsForKp(u.Kp) setVals, err := customerAttributes.SetValsForKp(u.Kp)
if err != nil {
return err
}
newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen) newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen)
if err != nil { if err != nil {
return err return err
} }
encipheredPasscode, err := newKeys.EncipherNKode(passcodeAttrIdx, customerAttributes) encipheredPasscode, err := newKeys.EncipherNKode(passcodeAttrIdx, customerAttributes)
if err != nil { if err != nil {
return err return err

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"fmt"
"github.com/google/uuid" "github.com/google/uuid"
"net/mail" "net/mail"
"strings" "strings"
@@ -99,10 +100,17 @@ func CustomerIdToString(customerId CustomerId) string {
type SessionId uuid.UUID type SessionId uuid.UUID
type UserId uuid.UUID type UserId uuid.UUID
func UserIdFromString(userId string) UserId {
id, err := uuid.Parse(userId)
if err != nil {
fmt.Errorf("unable to parse user id %+v", err)
}
return UserId(id)
}
func (s *SessionId) String() string { func (s *SessionId) String() string {
id := uuid.UUID(*s) id := uuid.UUID(*s)
return id.String() return id.String()
} }
type UserEmail string type UserEmail string

View File

@@ -0,0 +1,7 @@
package utils
import "time"
func TimeStamp() string {
return time.Now().Format(time.RFC3339)
}