package core

import (
	"database/sql"
	"errors"
	"fmt"
	"log"

	"github.com/google/uuid"
	_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
	"go-nkode/util"
)

// SqliteDB stores customers, users and SVG icons in a SQLite database.
type SqliteDB struct {
	db *sql.DB
}

// NewSqliteDB opens the SQLite database at path with the go-sqlite3 driver.
// It terminates the process via log.Fatal if sql.Open reports an error.
// sql.Open may not actually connect, so some path problems only surface on
// first use.
func NewSqliteDB(path string) *SqliteDB {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		log.Fatal("database didn't open ", err)
	}
	return &SqliteDB{db: db}
}

// CloseDb closes the underlying database handle and panics if that fails.
func (d *SqliteDB) CloseDb() {
	if err := d.db.Close(); err != nil {
		panic(fmt.Sprintf("Failed to close the database: %v", err))
	}
}

// withTx runs fn inside a transaction. The transaction is committed when fn
// returns nil and rolled back otherwise. A failed rollback is reported
// alongside fn's error rather than terminating the process. A failed Commit
// is returned as-is: database/sql already considers the transaction finished
// at that point, so rolling back again would only yield sql.ErrTxDone.
func (d *SqliteDB) withTx(fn func(tx *sql.Tx) error) error {
	tx, err := d.db.Begin()
	if err != nil {
		return err
	}
	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}
	return tx.Commit()
}

// WriteNewCustomer inserts c, its nKode policy and its attribute/set values
// into the customer table inside a single transaction.
func (d *SqliteDB) WriteNewCustomer(c Customer) error {
	const insertCustomer = `
INSERT INTO customer (id, max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values)
VALUES (?,?,?,?,?,?,?,?,?)
`
	return d.withTx(func(tx *sql.Tx) error {
		_, err := tx.Exec(insertCustomer,
			uuid.UUID(c.Id),
			c.NKodePolicy.MaxNkodeLen,
			c.NKodePolicy.MinNkodeLen,
			c.NKodePolicy.DistinctSets,
			c.NKodePolicy.DistinctAttributes,
			c.NKodePolicy.LockOut,
			c.NKodePolicy.Expiration,
			c.Attributes.AttrBytes(),
			c.Attributes.SetBytes(),
		)
		return err
	})
}

// WriteNewUser inserts u into the user table inside a single transaction.
// Key arrays and interface index arrays are serialized to byte slices, and
// the Renew flag is stored as 0/1.
func (d *SqliteDB) WriteNewUser(u User) error {
	const insertUser = `
INSERT INTO user (id, username, renew, refresh_token, customer_id, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
	renew := 0
	if u.Renew {
		renew = 1
	}
	return d.withTx(func(tx *sql.Tx) error {
		_, err := tx.Exec(insertUser,
			uuid.UUID(u.Id),
			u.Email,
			renew,
			u.RefreshToken,
			uuid.UUID(u.CustomerId),
			u.EncipheredPasscode.Code,
			u.EncipheredPasscode.Mask,
			u.Kp.AttrsPerKey,
			u.Kp.NumbOfKeys,
			util.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
			util.Uint64ArrToByteArr(u.CipherKeys.SetKey),
			util.Uint64ArrToByteArr(u.CipherKeys.PassKey),
			util.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
			u.CipherKeys.Salt,
			u.CipherKeys.MaxNKodeLen,
			util.IntArrToByteArr(u.Interface.IdxInterface),
			util.IntArrToByteArr(u.Interface.SvgId),
		)
		return err
	})
}

// GetCustomer loads the customer with the given id. It returns an error if
// no row matches or if more than one row matches.
func (d *SqliteDB) GetCustomer(id CustomerId) (*Customer, error) {
	const selectCustomer = `SELECT max_nkode_len, min_nkode_len, distinct_sets, distinct_attributes, lock_out, expiration, attribute_values, set_values FROM customer WHERE id = ?`

	rows, err := d.db.Query(selectCustomer, uuid.UUID(id))
	if err != nil {
		return nil, err
	}
	// Always release the pooled connection, including on early returns.
	defer rows.Close()

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("no new row for customer %s: %w", uuid.UUID(id), err)
		}
		return nil, fmt.Errorf("no new row for customer %s", uuid.UUID(id))
	}

	var maxNKodeLen int
	var minNKodeLen int
	var distinctSets int
	var distinctAttributes int
	var lockOut int
	var expiration int
	var attributeValues []byte
	var setValues []byte
	if err := rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues); err != nil {
		return nil, err
	}
	if rows.Next() {
		return nil, fmt.Errorf("too many rows for customer %s", uuid.UUID(id))
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	customer := Customer{
		Id: id,
		NKodePolicy: NKodePolicy{
			MaxNkodeLen:        maxNKodeLen,
			MinNkodeLen:        minNKodeLen,
			DistinctSets:       distinctSets,
			DistinctAttributes: distinctAttributes,
			LockOut:            lockOut,
			Expiration:         expiration,
		},
		Attributes: NewCustomerAttributesFromBytes(attributeValues, setValues),
	}
	return &customer, nil
}

// GetUser loads the user identified by username within customerId. It
// returns an error if no row matches or if more than one row matches. The
// returned user's CipherKeys.Kp and Interface.Kp both point at its own Kp.
func (d *SqliteDB) GetUser(username Email, customerId CustomerId) (*User, error) {
	const userSelect = `
SELECT id, renew, refresh_token, code, mask, attributes_per_key, number_of_keys, alpha_key, set_key, pass_key, mask_key, salt, max_nkode_len, idx_interface, svg_id_interface
FROM user
WHERE user.username = ? AND user.customer_id = ?
`
	customerUUID := uuid.UUID(customerId).String()

	rows, err := d.db.Query(userSelect, string(username), customerUUID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("no new rows for user %s of customer %s: %w", string(username), customerUUID, err)
		}
		return nil, fmt.Errorf("no new rows for user %s of customer %s", string(username), customerUUID)
	}

	var id string
	var renewVal int
	var refreshToken string
	var code string
	var mask string
	var attrsPerKey int
	var numbOfKeys int
	var alphaKey []byte
	var setKey []byte
	var passKey []byte
	var maskKey []byte
	var salt []byte
	var maxNKodeLen int
	var idxInterface []byte
	var svgIdInterface []byte
	if err := rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface); err != nil {
		return nil, err
	}
	if rows.Next() {
		return nil, fmt.Errorf("too many rows for user %s of customer %s", string(username), customerUUID)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	userId, err := uuid.Parse(id)
	if err != nil {
		return nil, err
	}

	user := User{
		Id:         UserId(userId),
		CustomerId: customerId,
		Email:      username,
		EncipheredPasscode: EncipheredNKode{
			Code: code,
			Mask: mask,
		},
		Kp: KeypadDimension{
			AttrsPerKey: attrsPerKey,
			NumbOfKeys:  numbOfKeys,
		},
		CipherKeys: UserCipherKeys{
			AlphaKey:    util.ByteArrToUint64Arr(alphaKey),
			SetKey:      util.ByteArrToUint64Arr(setKey),
			PassKey:     util.ByteArrToUint64Arr(passKey),
			MaskKey:     util.ByteArrToUint64Arr(maskKey),
			Salt:        salt,
			MaxNKodeLen: maxNKodeLen,
			Kp:          nil,
		},
		Interface: UserInterface{
			IdxInterface: util.ByteArrToIntArr(idxInterface),
			SvgId:        util.ByteArrToIntArr(svgIdInterface),
			Kp:           nil,
		},
		Renew:        renewVal != 0,
		RefreshToken: refreshToken,
	}
	// Point the nested keypad references at the user's own Kp field.
	user.Interface.Kp = &user.Kp
	user.CipherKeys.Kp = &user.Kp
	return &user, nil
}

// UpdateUserInterface overwrites only the idx_interface column for user id.
func (d *SqliteDB) UpdateUserInterface(id UserId, ui UserInterface) error {
	const updateUserInterface = `
UPDATE user SET idx_interface = ?
WHERE id = ?
`
	_, err := d.db.Exec(updateUserInterface, util.IntArrToByteArr(ui.IdxInterface), uuid.UUID(id).String())
	return err
}

// UpdateUserRefreshToken stores refreshToken as the refresh_token of user id.
func (d *SqliteDB) UpdateUserRefreshToken(id UserId, refreshToken string) error {
	const updateUserRefreshToken = `
UPDATE user SET refresh_token = ?
WHERE id = ?
`
	_, err := d.db.Exec(updateUserRefreshToken, refreshToken, uuid.UUID(id).String())
	return err
}

// Renew renews the keys of customer id and of every one of its users.
//
// customer.RenewKeys rotates the customer's values and returns the XOR
// deltas. Each user's alpha/set keys are then re-derived with those deltas,
// and the user is flagged with renew = 1. All database writes happen in one
// transaction, so either every row is updated or none is.
func (d *SqliteDB) Renew(id CustomerId) error {
	customer, err := d.GetCustomer(id)
	if err != nil {
		return err
	}
	setXor, attrXor := customer.RenewKeys()

	const (
		updateCustomer = `UPDATE customer SET attribute_values = ?, set_values = ? WHERE id = ?`
		selectUsers    = `SELECT id, alpha_key, set_key, attributes_per_key, number_of_keys FROM user WHERE customer_id = ?`
		updateUser     = `UPDATE user SET alpha_key = ?, set_key = ?, renew = ? WHERE id = ?`
	)

	return d.withTx(func(tx *sql.Tx) error {
		type renewedKeys struct {
			userID   string
			alphaKey []byte
			setKey   []byte
		}

		// Read and re-derive every user's keys first. The cursor is
		// exhausted (and thus closed) before any UPDATE is issued.
		rows, err := tx.Query(selectUsers, uuid.UUID(id).String())
		if err != nil {
			return err
		}
		defer rows.Close()

		var renewed []renewedKeys
		for rows.Next() {
			var userID string
			var alphaBytes []byte
			var setBytes []byte
			var attrsPerKey int
			var numbOfKeys int
			if err := rows.Scan(&userID, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys); err != nil {
				return err
			}
			user := User{
				Kp: KeypadDimension{
					AttrsPerKey: attrsPerKey,
					NumbOfKeys:  numbOfKeys,
				},
				CipherKeys: UserCipherKeys{
					AlphaKey: util.ByteArrToUint64Arr(alphaBytes),
					SetKey:   util.ByteArrToUint64Arr(setBytes),
				},
			}
			if err := user.RenewKeys(setXor, attrXor); err != nil {
				return err
			}
			renewed = append(renewed, renewedKeys{
				userID:   userID,
				alphaKey: util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
				setKey:   util.Uint64ArrToByteArr(user.CipherKeys.SetKey),
			})
		}
		if err := rows.Err(); err != nil {
			return err
		}

		if _, err := tx.Exec(updateCustomer,
			util.Uint64ArrToByteArr(customer.Attributes.AttrVals),
			util.Uint64ArrToByteArr(customer.Attributes.SetVals),
			uuid.UUID(customer.Id).String(),
		); err != nil {
			return err
		}

		stmt, err := tx.Prepare(updateUser)
		if err != nil {
			return err
		}
		defer stmt.Close()
		for _, r := range renewed {
			if _, err := stmt.Exec(r.alphaKey, r.setKey, 1, r.userID); err != nil {
				return err
			}
		}
		return nil
	})
}

// RefreshUserPasscode re-enciphers user's passcode via user.RefreshPasscode
// and persists the new passcode and cipher keys. It also resets renew to 0.
// user is received by value, so the caller's copy is not modified.
func (d *SqliteDB) RefreshUserPasscode(user User, passcodeIdx []int, customerAttr CustomerAttributes) error {
	if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil {
		return err
	}
	const updateUser = `
UPDATE user SET renew = ?, code = ?, mask = ?, alpha_key = ?, set_key = ?, pass_key = ?, mask_key = ?, salt = ?
WHERE id = ?;
`
	// Exactly one argument per placeholder, in SQL order. The previous
	// version passed a stray RefreshToken first, which database/sql rejects
	// as an argument-count mismatch.
	_, err := d.db.Exec(updateUser,
		0,
		user.EncipheredPasscode.Code,
		user.EncipheredPasscode.Mask,
		util.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
		util.Uint64ArrToByteArr(user.CipherKeys.SetKey),
		util.Uint64ArrToByteArr(user.CipherKeys.PassKey),
		util.Uint64ArrToByteArr(user.CipherKeys.MaskKey),
		user.CipherKeys.Salt,
		uuid.UUID(user.Id).String(),
	)
	return err
}

// RandomSvgInterface returns kp.TotalAttrs() distinct, randomly chosen SVG
// documents from the svg_icon table.
func (d *SqliteDB) RandomSvgInterface(kp KeypadDimension) ([]string, error) {
	ids, err := d.getRandomIds(kp.TotalAttrs())
	if err != nil {
		return nil, err
	}
	return d.getSvgsById(ids)
}

// RandomSvgIdxInterface returns kp.TotalAttrs() distinct, randomly chosen
// svg_icon ids.
func (d *SqliteDB) RandomSvgIdxInterface(kp KeypadDimension) (SvgIdInterface, error) {
	return d.getRandomIds(kp.TotalAttrs())
}

// GetSvgStringInterface returns the SVG documents for idxs, in order.
func (d *SqliteDB) GetSvgStringInterface(idxs SvgIdInterface) ([]string, error) {
	return d.getSvgsById(idxs)
}

// getSvgsById looks up the svg column for each id, preserving order. It
// returns an "id not found" error for the first id with no matching row.
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
	const selectId = "SELECT svg FROM svg_icon where id = ?"
	svgs := make([]string, len(ids))
	for idx, id := range ids {
		// QueryRow releases its connection after Scan, so no rows leak per id.
		err := d.db.QueryRow(selectId, id).Scan(&svgs[idx])
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("id not found: %d", id)
		}
		if err != nil {
			return nil, err
		}
	}
	return svgs, nil
}

// getRandomIds returns count distinct ids drawn at random from 1..N, where N
// is the row count of svg_icon. It returns an error if count is negative or
// exceeds N.
// NOTE(review): this assumes svg_icon ids are exactly 1..N with no gaps, and
// that util.RandomPermutation(N) returns a permutation of 0..N-1. Confirm.
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
	var tableLen int
	if err := d.db.QueryRow("SELECT COUNT(*) as count FROM svg_icon;").Scan(&tableLen); err != nil {
		return nil, err
	}
	switch {
	case count < 0:
		return nil, fmt.Errorf("invalid svg id count %d", count)
	case count > 0 && tableLen == 0:
		return nil, errors.New("empty svg_icon table")
	case count > tableLen:
		return nil, fmt.Errorf("requested %d svg ids but svg_icon has only %d rows", count, tableLen)
	}

	perm, err := util.RandomPermutation(tableLen)
	if err != nil {
		return nil, err
	}
	if count > len(perm) {
		return nil, fmt.Errorf("random permutation has %d entries, need %d", len(perm), count)
	}
	// Shift 0-based permutation values onto 1-based row ids.
	for idx := range perm {
		perm[idx]++
	}
	return perm[:count], nil
}