refactor code to fewer files; remove unused code

This commit is contained in:
2024-09-21 11:04:09 -05:00
parent b7a4a5cf4c
commit 2b3abb8fb2
29 changed files with 486 additions and 557 deletions

347
core/sqlite_db.go Normal file
View File

@@ -0,0 +1,347 @@
package core
import (
"database/sql"
"errors"
"fmt"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
"go-nkode/util"
"log"
)
type SqliteDB struct {
db *sql.DB
}
func NewSqliteDB(path string) *SqliteDB {
db, err := sql.Open("sqlite3", path)
if err != nil {
log.Fatal("database didn't open ", err)
}
sqldb := SqliteDB{db: db}
return &sqldb
}
func (d *SqliteDB) CloseDb() {
if err := d.db.Close(); err != nil {
// If db.Close() returns an error, panic
panic(fmt.Sprintf("Failed to close the database: %v", err))
}
}
func (d *SqliteDB) WriteNewCustomer(c Customer) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new customer won't roll back %+v", err))
}
}
}()
insertCustomer := `
INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values)
VALUES (?,?,?,?,?,?,?,?,?)
`
_, err = tx.Exec(insertCustomer, uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes())
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (d *SqliteDB) WriteNewUser(u User) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
err = tx.Rollback()
if err != nil {
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
}
}
}()
insertUser := `
INSERT INTO user (id, username, renew, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
var renew int
if u.Renew {
renew = 1
} else {
renew = 0
}
_, err = tx.Exec(insertUser, uuid.UUID(u.Id), u.Username, renew, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(u.CipherKeys.SetKey), util.Uint64ArrToByteArr(u.CipherKeys.PassKey), util.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, util.IntArrToByteArr(u.Interface.IdxInterface), util.IntArrToByteArr(u.Interface.SvgId))
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}
func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
selectCustomer := `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?`
rows, err := d.db.Query(selectCustomer, uuid.UUID(id))
if !rows.Next() {
return nil, errors.New(fmt.Sprintf("no new row for customer %s with err %s", id, rows.Err()))
}
var maxNKodeLen int
var minNKodeLen int
var distinctSets int
var distinctAttributes int
var lockOut int
var expiration int
var attributeValues []byte
var setValues []byte
err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues)
if err != nil {
return nil, err
}
if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for customer %s", id))
}
customer := Customer{
Id: id,
NKodePolicy: NKodePolicy{
MaxNkodeLen: maxNKodeLen,
MinNkodeLen: minNKodeLen,
DistinctSets: distinctSets,
DistinctAttributes: distinctAttributes,
LockOut: lockOut,
Expiration: expiration,
},
Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues),
}
return &customer, nil
}
func (d *SqliteDB) GetUser(username Username, customerId CustomerId) (*User, error) {
userSelect := `
SELECT id, renew, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface FROM user
WHERE user.username = ? AND user.customer_id = ?
`
rows, err := d.db.Query(userSelect, string(username), uuid.UUID(customerId).String())
if !rows.Next() {
return nil, errors.New(fmt.Sprintf("no new rows for user %s of customer %s", string(username), uuid.UUID(customerId).String()))
}
var id string
var renewVal int
var code string
var mask string
var attrsPerKey int
var numbOfKeys int
var alphaKey []byte
var setKey []byte
var passKey []byte
var maskKey []byte
var salt []byte
var maxNKodeLen int
var idxInterface []byte
var svgIdInterface []byte
err = rows.Scan(&id, &renewVal, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
if rows.Next() {
return nil, errors.New(fmt.Sprintf("too many rows for user %s of customer %s", username, customerId))
}
userId, err := uuid.Parse(id)
if err != nil {
return nil, err
}
var renew bool
if renewVal == 0 {
renew = false
} else {
renew = true
}
user := User{
Id: UserId(userId),
CustomerId: customerId,
Username: username,
EncipheredPasscode: EncipheredNKode{
Code: code,
Mask: mask,
},
Kp: KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: UserCipherKeys{
AlphaKey: util.ByteArrToUint64Arr(alphaKey),
SetKey: util.ByteArrToUint64Arr(setKey),
PassKey: util.ByteArrToUint64Arr(passKey),
MaskKey: util.ByteArrToUint64Arr(maskKey),
Salt: salt,
MaxNKodeLen: maxNKodeLen,
Kp: nil,
},
Interface: UserInterface{
IdxInterface: util.ByteArrToIntArr(idxInterface),
SvgId: util.ByteArrToIntArr(svgIdInterface),
Kp: nil,
},
Renew: renew,
}
user.Interface.Kp = &user.Kp
user.CipherKeys.Kp = &user.Kp
return &user, nil
}
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
updateUserInterface := `
UPDATE user SET idx_interface = ? WHERE id = ?
`
_, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String())
return err
}
func (d *SqliteDB) Renew(id CustomerId) error {
customer, err := d.GetCustomer(id)
if err != nil {
return err
}
setXor, attrXor := customer.RenewKeys()
renewArgs := []any{util.Uint64ArrToByteArr(customer.Attributes.AttrVals), util.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
// TODO: replace with tx
renewExec := `
BEGIN TRANSACTION;
UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?;
`
userQuery := `
SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?
`
rows, err := d.db.Query(userQuery, uuid.UUID(id).String())
for rows.Next() {
var userId string
var alphaBytes []byte
var setBytes []byte
var attrsPerKey int
var numbOfKeys int
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
if err != nil {
return err
}
user := User{
Id: UserId{},
CustomerId: CustomerId{},
Username: "",
EncipheredPasscode: EncipheredNKode{},
Kp: KeypadDimension{
AttrsPerKey: attrsPerKey,
NumbOfKeys: numbOfKeys,
},
CipherKeys: UserCipherKeys{
AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
SetKey: util.ByteArrToUint64Arr(setBytes),
},
Interface: UserInterface{},
Renew: false,
}
err = user.RenewKeys(setXor, attrXor)
if err != nil {
return err
}
renewExec += "\nUPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?;"
renewArgs = append(renewArgs, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
}
renewExec += `
COMMIT;
`
_, err = d.db.Exec(renewExec, renewArgs...)
return err
}
func (d *SqliteDB) RefreshUser(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
err := user.RefreshPasscode(passcodeIdx, customerAttr)
if err != nil {
return err
}
updateUser := `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ? WHERE id = ?;
`
_, err = d.db.Exec(updateUser, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), util.Uint64ArrToByteArr(user.CipherKeys.SetKey), util.Uint64ArrToByteArr(user.CipherKeys.PassKey), util.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String())
return err
}
func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) {
ids, err := d.getRandomIds(kp.TotalAttrs())
if err != nil {
return nil, err
}
return d.getSvgsById(ids)
}
func (d *SqliteDB) RandomSvgIdxInterface(kp KeypadDimension) (SvgIdInterface, error) {
return d.getRandomIds(kp.TotalAttrs())
}
func (d *SqliteDB) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) {
return d.getSvgsById(idxs)
}
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
selectId := "SELECT svg FROM svg_icon where id = ?"
svgs := make([]string, len(ids))
for idx, id := range ids {
rows, err := d.db.Query(selectId, id)
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, errors.New(fmt.Sprintf("id not found: %d", id))
}
err = rows.Scan(&svgs[idx])
if err != nil {
return nil, err
}
}
return svgs, nil
}
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
rows, err := d.db.Query("SELECT COUNT(*) as count FROM svg_icon;")
if err != nil {
return nil, err
}
var tableLen int
if !rows.Next() {
return nil, errors.New("empty svg_icon table")
}
err = rows.Scan(&tableLen)
if err != nil {
return nil, err
}
perm, err := util.RandomPermutation(tableLen)
for idx := range perm {
perm[idx] += 1
}
if err != nil {
return nil, err
}
return perm[:count], nil
}