idiomatic project structure

This commit is contained in:
2024-11-26 11:31:46 -06:00
parent 1200380341
commit 052f95702d
44 changed files with 717 additions and 481 deletions

564
internal/db/sqlite_db.go Normal file
View File

@@ -0,0 +1,564 @@
package db
import (
"database/sql"
"fmt"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
"go-nkode/config"
"go-nkode/internal/models"
"go-nkode/internal/security"
"log"
"sync"
"time"
)
type SqliteDB struct {
db *sql.DB
stop bool
writeQueue chan WriteTx
wg sync.WaitGroup
}
type WriteTx struct {
ErrChan chan error
Query string
Args []any
}
const (
writeBuffer = 1000
)
func NewSqliteDB(path string) *SqliteDB {
db, err := sql.Open("sqlite3", path)
if err != nil {
log.Fatal("database didn't open ", err)
}
sqldb := SqliteDB{
db: db,
stop: false,
writeQueue: make(chan WriteTx, writeBuffer),
}
go func() {
for writeTx := range sqldb.writeQueue {
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
sqldb.wg.Done()
}
}()
return &sqldb
}
func (d *SqliteDB) CloseDb() {
d.stop = true
d.wg.Wait()
if err := d.db.Close(); err != nil {
// If db.Close() returns an error, panic
panic(fmt.Sprintf("Failed to close the database: %v", err))
}
}
func (d *SqliteDB) WriteNewCustomer(c models.Customer) error {
query := `
INSERT INTO customer (
id
,max_nkode_len
,min_nkode_len
,distinct_sets
,distinct_attributes
,lock_out
,expiration
,attribute_values
,set_values
,last_renew
,created_at
)
VALUES (?,?,?,?,?,?,?,?,?,?,?)
`
args := []any{
uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets,
c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration,
c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(),
}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) WriteNewUser(u models.User) error {
query := `
INSERT INTO user (
id
,email
,renew
,refresh_token
,customer_id
,code
,mask
,attributes_per_key
,number_of_keys
,alpha_key
,set_key
,pass_key
,mask_key
,salt
,max_nkode_len
,idx_interface
,svg_id_interface
,created_at
)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
var renew int
if u.Renew {
renew = 1
} else {
renew = 0
}
args := []any{
uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId),
u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys,
security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface),
security.IntArrToByteArr(u.Interface.SvgId), timeStamp(),
}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) UpdateUserNKode(u models.User) error {
query := `
UPDATE user
SET renew = ?
,refresh_token = ?
,code = ?
,mask = ?
,attributes_per_key = ?
,number_of_keys = ?
,alpha_key = ?
,set_key = ?
,pass_key = ?
,mask_key = ?
,salt = ?
,max_nkode_len = ?
,idx_interface = ?
,svg_id_interface = ?
WHERE email = ? AND customer_id = ?
`
var renew int
if u.Renew {
renew = 1
} else {
renew = 0
}
args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui models.UserInterface) error {
query := `
UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ?
`
args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
query := `
UPDATE user SET refresh_token = ? WHERE id = ?
`
args := []any{refreshToken, uuid.UUID(id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) Renew(id models.CustomerId) error {
// TODO: How long does a renew take?
customer, err := d.GetCustomer(id)
if err != nil {
return err
}
setXor, attrXor, err := customer.RenewKeys()
if err != nil {
return err
}
renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx
renewQuery := `
UPDATE customer
SET attribute_values = ?, set_values = ?
WHERE id = ?;
`
userQuery := `
SELECT
id
,alpha_key
,set_key
,attributes_per_key
,number_of_keys
FROM user
WHERE customer_id = ?
`
tx, err := d.db.Begin()
if err != nil {
return err
}
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
for rows.Next() {
var userId string
var alphaBytes []byte
var setBytes []byte
var attrsPerKey int
var numbOfKeys int
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
if err != nil {
return err
}
user := models.User{
Id: models.UserId{},
CustomerId: models.CustomerId{},
Email: "",
EncipheredPasscode: models.EncipheredNKode{},
Kp: models.KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: models.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(alphaBytes),
SetKey: security.ByteArrToUint64Arr(setBytes),
},
Interface: models.UserInterface{},
Renew: false,
}
err = user.RenewKeys(setXor, attrXor)
if err != nil {
return err
}
renewQuery += `
UPDATE user
SET alpha_key = ?, set_key = ?, renew = ?
WHERE id = ?;
`
renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
}
renewQuery += `
`
err = tx.Commit()
if err != nil {
return err
}
return d.addWriteTx(renewQuery, renewArgs)
}
func (d *SqliteDB) RefreshUserPasscode(user models.User, passcodeIdx []int, customerAttr models.CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil {
return err
}
query := `
UPDATE user
SET
renew = ?
,code = ?
,mask = ?
,alpha_key = ?
,set_key = ?
,pass_key = ?
,mask_key = ?
,salt = ?
WHERE id = ?;
`
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
return d.addWriteTx(query, args)
}
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*models.Customer, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
}
}
}()
selectCustomer := `
SELECT
max_nkode_len
,min_nkode_len
,distinct_sets
,distinct_attributes
,lock_out
,expiration
,attribute_values
,set_values
FROM customer
WHERE id = ?
`
rows, err := tx.Query(selectCustomer, uuid.UUID(id))
if err != nil {
return nil, err
}
if !rows.Next() {
log.Printf("no new row for customer %s with err %s", id, rows.Err())
return nil, config.ErrCustomerDne
}
var maxNKodeLen int
var minNKodeLen int
var distinctSets int
var distinctAttributes int
var lockOut int
var expiration int
var attributeValues []byte
var setValues []byte
err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues)
if err != nil {
return nil, err
}
customer := models.Customer{
Id: id,
NKodePolicy: models.NKodePolicy{
MaxNkodeLen: maxNKodeLen,
MinNkodeLen: minNKodeLen,
DistinctSets: distinctSets,
DistinctAttributes: distinctAttributes,
LockOut: lockOut,
Expiration: expiration,
},
Attributes: models.NewCustomerAttributesFromBytes(attributeValues, setValues),
}
if err = tx.Commit(); err != nil {
return nil, err
}
return &customer, nil
}
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*models.User, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
userSelect := `
SELECT
id
,renew
,refresh_token
,code
,mask
,attributes_per_key
,number_of_keys
,alpha_key
,set_key
,pass_key
,mask_key
,salt
,max_nkode_len
,idx_interface
,svg_id_interface
FROM user
WHERE user.email = ? AND user.customer_id = ?
`
rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String())
if !rows.Next() {
return nil, nil
}
var (
id string
renewVal int
refreshToken string
code string
mask string
attrsPerKey int
numbOfKeys int
alphaKey []byte
setKey []byte
passKey []byte
maskKey []byte
salt []byte
maxNKodeLen int
idxInterface []byte
svgIdInterface []byte
)
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
userId, err := uuid.Parse(id)
if err != nil {
return nil, err
}
var renew bool
if renewVal == 0 {
renew = false
} else {
renew = true
}
user := models.User{
Id: models.UserId(userId),
CustomerId: customerId,
Email: email,
EncipheredPasscode: models.EncipheredNKode{
Code: code,
Mask: mask,
},
Kp: models.KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: models.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(alphaKey),
SetKey: security.ByteArrToUint64Arr(setKey),
PassKey: security.ByteArrToUint64Arr(passKey),
MaskKey: security.ByteArrToUint64Arr(maskKey),
Salt: salt,
MaxNKodeLen: maxNKodeLen,
Kp: nil,
},
Interface: models.UserInterface{
IdxInterface: security.ByteArrToIntArr(idxInterface),
SvgId: security.ByteArrToIntArr(svgIdInterface),
Kp: nil,
},
Renew: renew,
RefreshToken: refreshToken,
}
user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp
if err = tx.Commit(); err != nil {
return nil, err
}
return &user, nil
}
func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) {
ids, err := d.getRandomIds(kp.TotalAttrs())
if err != nil {
return nil, err
}
return d.getSvgsById(ids)
}
func (d *SqliteDB) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) {
return d.getRandomIds(kp.TotalAttrs())
}
func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) {
return d.getSvgsById(idxs)
}
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
selectId := `
SELECT svg
FROM svg_icon
WHERE id = ?
`
svgs := make([]string, len(ids))
for idx, id := range ids {
rows, err := tx.Query(selectId, id)
if err != nil {
return nil, err
}
if !rows.Next() {
log.Printf("id not found: %d", id)
return nil, config.ErrSvgDne
}
if err = rows.Scan(&svgs[idx]); err != nil {
return nil, err
}
}
if err = tx.Commit(); err != nil {
return nil, err
}
return svgs, nil
}
func (d *SqliteDB) writeToDb(query string, args []any) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatalf("fatal error: write won't roll back %+v", err)
}
}
}()
if _, err = tx.Exec(query, args...); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}
func (d *SqliteDB) addWriteTx(query string, args []any) error {
if d.stop {
return config.ErrStoppingDatabase
}
errChan := make(chan error)
writeTx := WriteTx{
Query: query,
Args: args,
ErrChan: errChan,
}
d.wg.Add(1)
d.writeQueue <- writeTx
return <-errChan
}
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
tx, err := d.db.Begin()
if err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
if err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
var tableLen int
if !rows.Next() {
return nil, config.ErrEmptySvgTable
}
if err = rows.Scan(&tableLen); err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
perm, err := security.RandomPermutation(tableLen)
if err != nil {
return nil, err
}
for idx := range perm {
perm[idx] += 1
}
if err = tx.Commit(); err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
return perm[:count], nil
}
func timeStamp() string {
return time.Now().Format(time.RFC3339)
}