idiomatic project structure
This commit is contained in:
564
internal/db/sqlite_db.go
Normal file
564
internal/db/sqlite_db.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
||||
"go-nkode/config"
|
||||
"go-nkode/internal/models"
|
||||
"go-nkode/internal/security"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SqliteDB struct {
|
||||
db *sql.DB
|
||||
stop bool
|
||||
writeQueue chan WriteTx
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type WriteTx struct {
|
||||
ErrChan chan error
|
||||
Query string
|
||||
Args []any
|
||||
}
|
||||
|
||||
const (
|
||||
writeBuffer = 1000
|
||||
)
|
||||
|
||||
func NewSqliteDB(path string) *SqliteDB {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
log.Fatal("database didn't open ", err)
|
||||
}
|
||||
sqldb := SqliteDB{
|
||||
db: db,
|
||||
stop: false,
|
||||
writeQueue: make(chan WriteTx, writeBuffer),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for writeTx := range sqldb.writeQueue {
|
||||
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
|
||||
sqldb.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
return &sqldb
|
||||
}
|
||||
|
||||
func (d *SqliteDB) CloseDb() {
|
||||
d.stop = true
|
||||
d.wg.Wait()
|
||||
if err := d.db.Close(); err != nil {
|
||||
// If db.Close() returns an error, panic
|
||||
panic(fmt.Sprintf("Failed to close the database: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewCustomer(c models.Customer) error {
|
||||
query := `
|
||||
INSERT INTO customer (
|
||||
id
|
||||
,max_nkode_len
|
||||
,min_nkode_len
|
||||
,distinct_sets
|
||||
,distinct_attributes
|
||||
,lock_out
|
||||
,expiration
|
||||
,attribute_values
|
||||
,set_values
|
||||
,last_renew
|
||||
,created_at
|
||||
)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
`
|
||||
args := []any{
|
||||
uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets,
|
||||
c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration,
|
||||
c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(),
|
||||
}
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewUser(u models.User) error {
|
||||
query := `
|
||||
INSERT INTO user (
|
||||
id
|
||||
,email
|
||||
,renew
|
||||
,refresh_token
|
||||
,customer_id
|
||||
,code
|
||||
,mask
|
||||
,attributes_per_key
|
||||
,number_of_keys
|
||||
,alpha_key
|
||||
,set_key
|
||||
,pass_key
|
||||
,mask_key
|
||||
,salt
|
||||
,max_nkode_len
|
||||
,idx_interface
|
||||
,svg_id_interface
|
||||
,created_at
|
||||
)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
`
|
||||
var renew int
|
||||
if u.Renew {
|
||||
renew = 1
|
||||
} else {
|
||||
renew = 0
|
||||
}
|
||||
|
||||
args := []any{
|
||||
uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId),
|
||||
u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys,
|
||||
security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
||||
security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
||||
u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface),
|
||||
security.IntArrToByteArr(u.Interface.SvgId), timeStamp(),
|
||||
}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserNKode(u models.User) error {
|
||||
query := `
|
||||
UPDATE user
|
||||
SET renew = ?
|
||||
,refresh_token = ?
|
||||
,code = ?
|
||||
,mask = ?
|
||||
,attributes_per_key = ?
|
||||
,number_of_keys = ?
|
||||
,alpha_key = ?
|
||||
,set_key = ?
|
||||
,pass_key = ?
|
||||
,mask_key = ?
|
||||
,salt = ?
|
||||
,max_nkode_len = ?
|
||||
,idx_interface = ?
|
||||
,svg_id_interface = ?
|
||||
WHERE email = ? AND customer_id = ?
|
||||
`
|
||||
var renew int
|
||||
if u.Renew {
|
||||
renew = 1
|
||||
} else {
|
||||
renew = 0
|
||||
}
|
||||
args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui models.UserInterface) error {
|
||||
query := `
|
||||
UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ?
|
||||
`
|
||||
args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
|
||||
query := `
|
||||
UPDATE user SET refresh_token = ? WHERE id = ?
|
||||
`
|
||||
args := []any{refreshToken, uuid.UUID(id).String()}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) Renew(id models.CustomerId) error {
|
||||
// TODO: How long does a renew take?
|
||||
customer, err := d.GetCustomer(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setXor, attrXor, err := customer.RenewKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
|
||||
// TODO: replace with tx
|
||||
renewQuery := `
|
||||
UPDATE customer
|
||||
SET attribute_values = ?, set_values = ?
|
||||
WHERE id = ?;
|
||||
`
|
||||
|
||||
userQuery := `
|
||||
SELECT
|
||||
id
|
||||
,alpha_key
|
||||
,set_key
|
||||
,attributes_per_key
|
||||
,number_of_keys
|
||||
FROM user
|
||||
WHERE customer_id = ?
|
||||
`
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
|
||||
for rows.Next() {
|
||||
var userId string
|
||||
var alphaBytes []byte
|
||||
var setBytes []byte
|
||||
var attrsPerKey int
|
||||
var numbOfKeys int
|
||||
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := models.User{
|
||||
Id: models.UserId{},
|
||||
CustomerId: models.CustomerId{},
|
||||
Email: "",
|
||||
EncipheredPasscode: models.EncipheredNKode{},
|
||||
Kp: models.KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: models.UserCipherKeys{
|
||||
AlphaKey: security.ByteArrToUint64Arr(alphaBytes),
|
||||
SetKey: security.ByteArrToUint64Arr(setBytes),
|
||||
},
|
||||
Interface: models.UserInterface{},
|
||||
Renew: false,
|
||||
}
|
||||
err = user.RenewKeys(setXor, attrXor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewQuery += `
|
||||
UPDATE user
|
||||
SET alpha_key = ?, set_key = ?, renew = ?
|
||||
WHERE id = ?;
|
||||
`
|
||||
renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
|
||||
}
|
||||
renewQuery += `
|
||||
`
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.addWriteTx(renewQuery, renewArgs)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RefreshUserPasscode(user models.User, passcodeIdx []int, customerAttr models.CustomerAttributes) error {
|
||||
err := user.RefreshPasscode(passcodeIdx, customerAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query := `
|
||||
UPDATE user
|
||||
SET
|
||||
renew = ?
|
||||
,code = ?
|
||||
,mask = ?
|
||||
,alpha_key = ?
|
||||
,set_key = ?
|
||||
,pass_key = ?
|
||||
,mask_key = ?
|
||||
,salt = ?
|
||||
WHERE id = ?;
|
||||
`
|
||||
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*models.Customer, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
selectCustomer := `
|
||||
SELECT
|
||||
max_nkode_len
|
||||
,min_nkode_len
|
||||
,distinct_sets
|
||||
,distinct_attributes
|
||||
,lock_out
|
||||
,expiration
|
||||
,attribute_values
|
||||
,set_values
|
||||
FROM customer
|
||||
WHERE id = ?
|
||||
`
|
||||
rows, err := tx.Query(selectCustomer, uuid.UUID(id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
log.Printf("no new row for customer %s with err %s", id, rows.Err())
|
||||
return nil, config.ErrCustomerDne
|
||||
}
|
||||
|
||||
var maxNKodeLen int
|
||||
var minNKodeLen int
|
||||
var distinctSets int
|
||||
var distinctAttributes int
|
||||
var lockOut int
|
||||
var expiration int
|
||||
var attributeValues []byte
|
||||
var setValues []byte
|
||||
err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customer := models.Customer{
|
||||
Id: id,
|
||||
NKodePolicy: models.NKodePolicy{
|
||||
MaxNkodeLen: maxNKodeLen,
|
||||
MinNkodeLen: minNKodeLen,
|
||||
DistinctSets: distinctSets,
|
||||
DistinctAttributes: distinctAttributes,
|
||||
LockOut: lockOut,
|
||||
Expiration: expiration,
|
||||
},
|
||||
Attributes: models.NewCustomerAttributesFromBytes(attributeValues, setValues),
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &customer, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*models.User, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSelect := `
|
||||
SELECT
|
||||
id
|
||||
,renew
|
||||
,refresh_token
|
||||
,code
|
||||
,mask
|
||||
,attributes_per_key
|
||||
,number_of_keys
|
||||
,alpha_key
|
||||
,set_key
|
||||
,pass_key
|
||||
,mask_key
|
||||
,salt
|
||||
,max_nkode_len
|
||||
,idx_interface
|
||||
,svg_id_interface
|
||||
FROM user
|
||||
WHERE user.email = ? AND user.customer_id = ?
|
||||
`
|
||||
rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String())
|
||||
if !rows.Next() {
|
||||
return nil, nil
|
||||
}
|
||||
var (
|
||||
id string
|
||||
renewVal int
|
||||
refreshToken string
|
||||
code string
|
||||
mask string
|
||||
attrsPerKey int
|
||||
numbOfKeys int
|
||||
alphaKey []byte
|
||||
setKey []byte
|
||||
passKey []byte
|
||||
maskKey []byte
|
||||
salt []byte
|
||||
maxNKodeLen int
|
||||
idxInterface []byte
|
||||
svgIdInterface []byte
|
||||
)
|
||||
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
|
||||
|
||||
userId, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var renew bool
|
||||
if renewVal == 0 {
|
||||
renew = false
|
||||
} else {
|
||||
renew = true
|
||||
}
|
||||
|
||||
user := models.User{
|
||||
Id: models.UserId(userId),
|
||||
CustomerId: customerId,
|
||||
Email: email,
|
||||
EncipheredPasscode: models.EncipheredNKode{
|
||||
Code: code,
|
||||
Mask: mask,
|
||||
},
|
||||
Kp: models.KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: models.UserCipherKeys{
|
||||
AlphaKey: security.ByteArrToUint64Arr(alphaKey),
|
||||
SetKey: security.ByteArrToUint64Arr(setKey),
|
||||
PassKey: security.ByteArrToUint64Arr(passKey),
|
||||
MaskKey: security.ByteArrToUint64Arr(maskKey),
|
||||
Salt: salt,
|
||||
MaxNKodeLen: maxNKodeLen,
|
||||
Kp: nil,
|
||||
},
|
||||
Interface: models.UserInterface{
|
||||
IdxInterface: security.ByteArrToIntArr(idxInterface),
|
||||
SvgId: security.ByteArrToIntArr(svgIdInterface),
|
||||
Kp: nil,
|
||||
},
|
||||
Renew: renew,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
user.Interface.Kp = &user.Kp
|
||||
user.CipherKeys.Kp = &user.Kp
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) {
|
||||
ids, err := d.getRandomIds(kp.TotalAttrs())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.getSvgsById(ids)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) {
|
||||
return d.getRandomIds(kp.TotalAttrs())
|
||||
}
|
||||
|
||||
func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) {
|
||||
return d.getSvgsById(idxs)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectId := `
|
||||
SELECT svg
|
||||
FROM svg_icon
|
||||
WHERE id = ?
|
||||
`
|
||||
svgs := make([]string, len(ids))
|
||||
for idx, id := range ids {
|
||||
rows, err := tx.Query(selectId, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rows.Next() {
|
||||
log.Printf("id not found: %d", id)
|
||||
return nil, config.ErrSvgDne
|
||||
}
|
||||
if err = rows.Scan(&svgs[idx]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return svgs, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) writeToDb(query string, args []any) error {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatalf("fatal error: write won't roll back %+v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
if _, err = tx.Exec(query, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) addWriteTx(query string, args []any) error {
|
||||
if d.stop {
|
||||
return config.ErrStoppingDatabase
|
||||
}
|
||||
errChan := make(chan error)
|
||||
writeTx := WriteTx{
|
||||
Query: query,
|
||||
Args: args,
|
||||
ErrChan: errChan,
|
||||
}
|
||||
d.wg.Add(1)
|
||||
d.writeQueue <- writeTx
|
||||
return <-errChan
|
||||
}
|
||||
|
||||
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return nil, config.ErrSqliteTx
|
||||
}
|
||||
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return nil, config.ErrSqliteTx
|
||||
}
|
||||
var tableLen int
|
||||
if !rows.Next() {
|
||||
return nil, config.ErrEmptySvgTable
|
||||
}
|
||||
|
||||
if err = rows.Scan(&tableLen); err != nil {
|
||||
log.Print(err)
|
||||
return nil, config.ErrSqliteTx
|
||||
}
|
||||
perm, err := security.RandomPermutation(tableLen)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for idx := range perm {
|
||||
perm[idx] += 1
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
log.Print(err)
|
||||
return nil, config.ErrSqliteTx
|
||||
}
|
||||
|
||||
return perm[:count], nil
|
||||
}
|
||||
|
||||
func timeStamp() string {
|
||||
return time.Now().Format(time.RFC3339)
|
||||
}
|
||||
Reference in New Issue
Block a user