package db import ( "database/sql" "fmt" "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "go-nkode/config" "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "log" "sync" "time" ) type SqliteDB struct { db *sql.DB stop bool writeQueue chan WriteTx wg sync.WaitGroup } type WriteTx struct { ErrChan chan error Query string Args []any } const ( writeBuffer = 1000 ) func NewSqliteDB(path string) *SqliteDB { db, err := sql.Open("sqlite3", path) if err != nil { log.Fatal("database didn't open ", err) } sqldb := SqliteDB{ db: db, stop: false, writeQueue: make(chan WriteTx, writeBuffer), } go func() { for writeTx := range sqldb.writeQueue { writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args) sqldb.wg.Done() } }() return &sqldb } func (d *SqliteDB) CloseDb() { d.stop = true d.wg.Wait() if err := d.db.Close(); err != nil { // If db.Close() returns an error, panic panic(fmt.Sprintf("Failed to close the database: %v", err)) } } func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error { query := ` INSERT INTO customer ( id ,max_nkode_len ,min_nkode_len ,distinct_sets ,distinct_attributes ,lock_out ,expiration ,attribute_values ,set_values ,last_renew ,created_at ) VALUES (?,?,?,?,?,?,?,?,?,?,?) ` args := []any{ uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets, c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration, c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(), } return d.addWriteTx(query, args) } func (d *SqliteDB) WriteNewUser(u entities.User) error { query := ` INSERT INTO user ( id ,email ,renew ,refresh_token ,customer_id ,code ,mask ,attributes_per_key ,number_of_keys ,alpha_key ,set_key ,pass_key ,mask_key ,salt ,max_nkode_len ,idx_interface ,svg_id_interface ,created_at ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) ` var renew int if u.Renew { renew = 1 } else { renew = 0 } args := []any{ uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId), u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), timeStamp(), } return d.addWriteTx(query, args) } func (d *SqliteDB) UpdateUserNKode(u entities.User) error { query := ` UPDATE user SET renew = ? ,refresh_token = ? ,code = ? ,mask = ? ,attributes_per_key = ? ,number_of_keys = ? ,alpha_key = ? ,set_key = ? ,pass_key = ? ,mask_key = ? ,salt = ? ,max_nkode_len = ? ,idx_interface = ? ,svg_id_interface = ? WHERE email = ? AND customer_id = ? ` var renew int if u.Renew { renew = 1 } else { renew = 0 } args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)} return d.addWriteTx(query, args) } func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { query := ` UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ? ` args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()} return d.addWriteTx(query, args) } func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { query := ` UPDATE user SET refresh_token = ? WHERE id = ? ` args := []any{refreshToken, uuid.UUID(id).String()} return d.addWriteTx(query, args) } func (d *SqliteDB) Renew(id models.CustomerId) error { // TODO: How long does a renew take? customer, err := d.GetCustomer(id) if err != nil { return err } setXor, attrXor, err := customer.RenewKeys() if err != nil { return err } renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()} // TODO: replace with tx renewQuery := ` UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?; ` userQuery := ` SELECT id ,alpha_key ,set_key ,attributes_per_key ,number_of_keys FROM user WHERE customer_id = ? ` tx, err := d.db.Begin() if err != nil { return err } rows, err := tx.Query(userQuery, uuid.UUID(id).String()) for rows.Next() { var userId string var alphaBytes []byte var setBytes []byte var attrsPerKey int var numbOfKeys int err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys) if err != nil { return err } user := entities.User{ Id: models.UserId{}, CustomerId: models.CustomerId{}, Email: "", EncipheredPasscode: models.EncipheredNKode{}, Kp: entities.KeypadDimension{ AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(alphaBytes), SetKey: security.ByteArrToUint64Arr(setBytes), }, Interface: entities.UserInterface{}, Renew: false, } err = user.RenewKeys(setXor, attrXor) if err != nil { return err } renewQuery += ` UPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?; ` renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId) } renewQuery += ` ` err = tx.Commit() if err != nil { return err } return d.addWriteTx(renewQuery, renewArgs) } func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { err := user.RefreshPasscode(passcodeIdx, customerAttr) if err != nil { return err } query := ` UPDATE user SET renew = ? ,code = ? ,mask = ? ,alpha_key = ? ,set_key = ? ,pass_key = ? ,mask_key = ? ,salt = ? WHERE id = ?; ` args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()} return d.addWriteTx(query, args) } func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { tx, err := d.db.Begin() if err != nil { return nil, err } defer func() { if err != nil { err = tx.Rollback() if err != nil { log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err)) } } }() selectCustomer := ` SELECT max_nkode_len ,min_nkode_len ,distinct_sets ,distinct_attributes ,lock_out ,expiration ,attribute_values ,set_values FROM customer WHERE id = ? ` rows, err := tx.Query(selectCustomer, uuid.UUID(id)) if err != nil { return nil, err } if !rows.Next() { log.Printf("no new row for customer %s with err %s", id, rows.Err()) return nil, config.ErrCustomerDne } var maxNKodeLen int var minNKodeLen int var distinctSets int var distinctAttributes int var lockOut int var expiration int var attributeValues []byte var setValues []byte err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues) if err != nil { return nil, err } customer := entities.Customer{ Id: id, NKodePolicy: models.NKodePolicy{ MaxNkodeLen: maxNKodeLen, MinNkodeLen: minNKodeLen, DistinctSets: distinctSets, DistinctAttributes: distinctAttributes, LockOut: lockOut, Expiration: expiration, }, Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues), } if err = tx.Commit(); err != nil { return nil, err } return &customer, nil } func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { tx, err := d.db.Begin() if err != nil { return nil, err } userSelect := ` SELECT id ,renew ,refresh_token ,code ,mask ,attributes_per_key ,number_of_keys ,alpha_key ,set_key ,pass_key ,mask_key ,salt ,max_nkode_len ,idx_interface ,svg_id_interface FROM user WHERE user.email = ? AND user.customer_id = ? ` rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String()) if !rows.Next() { return nil, nil } var ( id string renewVal int refreshToken string code string mask string attrsPerKey int numbOfKeys int alphaKey []byte setKey []byte passKey []byte maskKey []byte salt []byte maxNKodeLen int idxInterface []byte svgIdInterface []byte ) err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface) userId, err := uuid.Parse(id) if err != nil { return nil, err } var renew bool if renewVal == 0 { renew = false } else { renew = true } user := entities.User{ Id: models.UserId(userId), CustomerId: customerId, Email: email, EncipheredPasscode: models.EncipheredNKode{ Code: code, Mask: mask, }, Kp: entities.KeypadDimension{ AttrsPerKey: attrsPerKey, NumbOfKeys: numbOfKeys, }, CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(alphaKey), SetKey: security.ByteArrToUint64Arr(setKey), PassKey: security.ByteArrToUint64Arr(passKey), MaskKey: security.ByteArrToUint64Arr(maskKey), Salt: salt, MaxNKodeLen: maxNKodeLen, Kp: nil, }, Interface: entities.UserInterface{ IdxInterface: security.ByteArrToIntArr(idxInterface), SvgId: security.ByteArrToIntArr(svgIdInterface), Kp: nil, }, Renew: renew, RefreshToken: refreshToken, } user.Interface.Kp = &user.Kp user.CipherKeys.Kp = &user.Kp if err = tx.Commit(); err != nil { return nil, err } return &user, nil } func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { ids, err := d.getRandomIds(kp.TotalAttrs()) if err != nil { return nil, err } return d.getSvgsById(ids) } func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { return d.getRandomIds(kp.TotalAttrs()) } func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) { return d.getSvgsById(idxs) } func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { tx, err := d.db.Begin() if err != nil { return nil, err } selectId := ` SELECT svg FROM svg_icon WHERE id = ? ` svgs := make([]string, len(ids)) for idx, id := range ids { rows, err := tx.Query(selectId, id) if err != nil { return nil, err } if !rows.Next() { log.Printf("id not found: %d", id) return nil, config.ErrSvgDne } if err = rows.Scan(&svgs[idx]); err != nil { return nil, err } } if err = tx.Commit(); err != nil { return nil, err } return svgs, nil } func (d *SqliteDB) writeToDb(query string, args []any) error { tx, err := d.db.Begin() if err != nil { return err } defer func() { if err != nil { err = tx.Rollback() if err != nil { log.Fatalf("fatal error: write won't roll back %+v", err) } } }() if _, err = tx.Exec(query, args...); err != nil { return err } if err = tx.Commit(); err != nil { return err } return nil } func (d *SqliteDB) addWriteTx(query string, args []any) error { if d.stop { return config.ErrStoppingDatabase } errChan := make(chan error) writeTx := WriteTx{ Query: query, Args: args, ErrChan: errChan, } d.wg.Add(1) d.writeQueue <- writeTx return <-errChan } func (d *SqliteDB) getRandomIds(count int) ([]int, error) { tx, err := d.db.Begin() if err != nil { log.Print(err) return nil, config.ErrSqliteTx } rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") if err != nil { log.Print(err) return nil, config.ErrSqliteTx } var tableLen int if !rows.Next() { return nil, config.ErrEmptySvgTable } if err = rows.Scan(&tableLen); err != nil { log.Print(err) return nil, config.ErrSqliteTx } perm, err := security.RandomPermutation(tableLen) if err != nil { return nil, err } for idx := range perm { perm[idx] += 1 } if err = tx.Commit(); err != nil { log.Print(err) return nil, config.ErrSqliteTx } return perm[:count], nil } func timeStamp() string { return time.Now().Format(time.RFC3339) }