implement sqlite write queue
This commit is contained in:
@@ -8,23 +8,50 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
||||
"go-nkode/util"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SqliteDB struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
stop bool
|
||||
writeQueue chan WriteTx
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type WriteTx struct {
|
||||
ErrChan chan error
|
||||
Query string
|
||||
Args []any
|
||||
}
|
||||
|
||||
const (
|
||||
writeBuffer = 1000
|
||||
)
|
||||
|
||||
func NewSqliteDB(path string) *SqliteDB {
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
log.Fatal("database didn't open ", err)
|
||||
}
|
||||
sqldb := SqliteDB{db: db}
|
||||
sqldb := SqliteDB{
|
||||
db: db,
|
||||
stop: false,
|
||||
writeQueue: make(chan WriteTx, writeBuffer),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for writeTx := range sqldb.writeQueue {
|
||||
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
|
||||
sqldb.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
return &sqldb
|
||||
}
|
||||
|
||||
func (d *SqliteDB) CloseDb() {
|
||||
d.stop = true
|
||||
d.wg.Wait()
|
||||
if err := d.db.Close(); err != nil {
|
||||
// If db.Close() returns an error, panic
|
||||
panic(fmt.Sprintf("Failed to close the database: %v", err))
|
||||
@@ -32,48 +59,16 @@ func (d *SqliteDB) CloseDb() {
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewCustomer(c Customer) error {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write new customer won't roll back %+v", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
insertCustomer := `
|
||||
query := `
|
||||
INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values)
|
||||
VALUES (?,?,?,?,?,?,?,?,?)
|
||||
`
|
||||
_, err = tx.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
args := []any{uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes()}
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewUser(u User) error {
|
||||
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
insertUser := `
|
||||
query := `
|
||||
INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
`
|
||||
@@ -83,32 +78,14 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
} else {
|
||||
renew = 0
|
||||
}
|
||||
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
args := []any{uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId)}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserNKode(u User) error {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
updateUser := `
|
||||
query := `
|
||||
UPDATE user
|
||||
SET renew = ?, refresh_token = ?, code = ?, mask = ?, attributes_per_key = ?, number_of_keys = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ?, max_nkode_len = ?, idx_interface = ?, svg_id_interface = ?
|
||||
WHERE username = ? AND customer_id = ?
|
||||
@@ -119,18 +96,103 @@ WHERE username = ? AND customer_id = ?
|
||||
} else {
|
||||
renew = 0
|
||||
}
|
||||
_, err = tx.Exec(updateUser, renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId))
|
||||
args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
|
||||
query := `
|
||||
UPDATE user SET idx_interface = ? WHERE id = ?
|
||||
`
|
||||
args := []any{util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String()}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
|
||||
query := `
|
||||
UPDATE user SET refresh_token = ? WHERE id = ?
|
||||
`
|
||||
args := []any{refreshToken, uuid.UUID(id).String()}
|
||||
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) Renew(id CustomerId) error {
|
||||
// TODO: How long does a renew take?
|
||||
customer, err := d.GetCustomer(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setXor, attrXor := customer.RenewKeys()
|
||||
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
|
||||
// TODO: replace with tx
|
||||
renewQuery := `
|
||||
UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?;
|
||||
`
|
||||
|
||||
userQuery := `
|
||||
SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?
|
||||
`
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
|
||||
for rows.Next() {
|
||||
var userId string
|
||||
var alphaBytes []byte
|
||||
var setBytes []byte
|
||||
var attrsPerKey int
|
||||
var numbOfKeys int
|
||||
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := User{
|
||||
Id: UserId{},
|
||||
CustomerId: CustomerId{},
|
||||
Email: "",
|
||||
EncipheredPasscode: EncipheredNKode{},
|
||||
Kp: KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: UserCipherKeys{
|
||||
AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
|
||||
SetKey: util.ByteArrToUint64Arr(setBytes),
|
||||
},
|
||||
Interface: UserInterface{},
|
||||
Renew: false,
|
||||
}
|
||||
err = user.RenewKeys(setXor, attrXor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewQuery += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;"
|
||||
renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
|
||||
}
|
||||
renewQuery += `
|
||||
`
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return d.addWriteTx(renewQuery, renewArgs)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
|
||||
err := user.RefreshPasscode(passcodeIdx, customerAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query := `
|
||||
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
|
||||
`
|
||||
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
@@ -191,7 +253,6 @@ func (d *SqliteDB) GetUser(username UserEmail, customerId CustomerId) (*User, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Commit()
|
||||
userSelect := `
|
||||
SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
|
||||
WHERE user.username = ? AND user.customer_id = ?
|
||||
@@ -263,115 +324,13 @@ WHERE user.username = ? AND user.customer_id = ?
|
||||
}
|
||||
user.Interface.Kp = &user.Kp
|
||||
user.CipherKeys.Kp = &user.Kp
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
|
||||
updateUserInterface := `
|
||||
UPDATE user SET idx_interface = ? WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
|
||||
updateUserRefreshToken := `
|
||||
UPDATE user SET refresh_token = ? WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SqliteDB) Renew(id CustomerId) error {
|
||||
customer, err := d.GetCustomer(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setXor, attrXor := customer.RenewKeys()
|
||||
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
|
||||
// TODO: replace with tx
|
||||
renewExec := `
|
||||
BEGIN TRANSACTION;
|
||||
|
||||
UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?;
|
||||
`
|
||||
|
||||
userQuery := `
|
||||
SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?
|
||||
`
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
|
||||
for rows.Next() {
|
||||
var userId string
|
||||
var alphaBytes []byte
|
||||
var setBytes []byte
|
||||
var attrsPerKey int
|
||||
var numbOfKeys int
|
||||
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := User{
|
||||
Id: UserId{},
|
||||
CustomerId: CustomerId{},
|
||||
Email: "",
|
||||
EncipheredPasscode: EncipheredNKode{},
|
||||
Kp: KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: UserCipherKeys{
|
||||
AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
|
||||
SetKey: util.ByteArrToUint64Arr(setBytes),
|
||||
},
|
||||
Interface: UserInterface{},
|
||||
Renew: false,
|
||||
}
|
||||
err = user.RenewKeys(setXor, attrXor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;"
|
||||
renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
|
||||
}
|
||||
renewExec += `
|
||||
COMMIT;
|
||||
`
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err = d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = d.db.Exec(renewExec, renewArgs...)
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
|
||||
err := user.RefreshPasscode(passcodeIdx, customerAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updateUser := `
|
||||
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
|
||||
`
|
||||
_, err = d.db.Exec(updateUser, user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) {
|
||||
ids, err := d.getRandomIds(kp.TotalAttrs())
|
||||
if err != nil {
|
||||
@@ -393,7 +352,6 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Commit()
|
||||
selectId := "SELECT svg FROM svg_icon where id = ?"
|
||||
svgs := make([]string, len(ids))
|
||||
for idx, id := range ids {
|
||||
@@ -409,15 +367,57 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return svgs, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) writeToDb(query string, args []any) error {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Sprintf("Write won't roll back %+v", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
_, err = tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) addWriteTx(query string, args []any) error {
|
||||
if d.stop {
|
||||
return errors.New("stopping database")
|
||||
}
|
||||
errChan := make(chan error)
|
||||
writeTx := WriteTx{
|
||||
Query: query,
|
||||
Args: args,
|
||||
ErrChan: errChan,
|
||||
}
|
||||
d.wg.Add(1)
|
||||
d.writeQueue <- writeTx
|
||||
return <-errChan
|
||||
}
|
||||
|
||||
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Commit()
|
||||
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -426,15 +426,20 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
||||
if !rows.Next() {
|
||||
return nil, errors.New("empty svg_icon table")
|
||||
}
|
||||
err = rows.Scan(&tableLen)
|
||||
|
||||
if err = rows.Scan(&tableLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perm, err := util.RandomPermutation(tableLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
perm, err := util.RandomPermutation(tableLen)
|
||||
for idx := range perm {
|
||||
perm[idx] += 1
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return perm[:count], nil
|
||||
|
||||
Reference in New Issue
Block a user