481 lines
14 KiB
Go
481 lines
14 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/google/uuid"
|
|
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
|
"go-nkode/config"
|
|
"go-nkode/internal/entities"
|
|
"go-nkode/internal/models"
|
|
"go-nkode/internal/security"
|
|
"go-nkode/internal/sqlc"
|
|
"go-nkode/internal/utils"
|
|
"log"
|
|
"sync"
|
|
)
|
|
|
|
// writeBufferSize is the capacity of the channel feeding the writer goroutine:
// up to this many writes can be queued before enqueuers start blocking.
const (
	writeBufferSize = 100
)
|
|
|
|
// sqlcGeneric is the shape of one queued write: it runs a single sqlc query
// against q using ctx. Parameters arrive untyped in args, and the function
// itself is responsible for type-asserting them.
type sqlcGeneric func(q *sqlc.Queries, ctx context.Context, args any) error
|
|
|
|
// WriteTx represents a write transaction
|
|
type WriteTx struct {
|
|
ErrChan chan error
|
|
Query sqlcGeneric
|
|
Args interface{}
|
|
}
|
|
|
|
// SqliteDB represents the SQLite database connection and write queue.
//
// Writes are funnelled through writeQueue and executed one at a time by a
// single background goroutine (processWriteQueue); reads are issued directly
// through queries or db. Create it with NewSqliteDB and release it with Close.
type SqliteDB struct {
	// queries is the sqlc query set bound to db.
	queries *sqlc.Queries
	// db is the underlying database handle; Close closes it.
	db *sql.DB
	// writeQueue carries pending writes to the writer goroutine (capacity
	// writeBufferSize, set in NewSqliteDB).
	writeQueue chan WriteTx
	// wg tracks the writer goroutine so Close can wait for it to exit.
	wg sync.WaitGroup
	// ctx is the DB's lifetime signal: the writer goroutine stops when it is
	// done, and it is passed to queries issued through d.queries.
	// NOTE(review): Go convention discourages storing a Context in a struct;
	// here it serves as the shutdown signal.
	ctx context.Context
	// cancel cancels ctx; it is invoked by Close.
	cancel context.CancelFunc
}
|
|
|
|
// NewSqliteDB initializes a new SqliteDB instance
|
|
func NewSqliteDB(path string) (*SqliteDB, error) {
|
|
if path == "" {
|
|
return nil, errors.New("database path is required")
|
|
}
|
|
|
|
db, err := sql.Open("sqlite3", path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
sqldb := &SqliteDB{
|
|
queries: sqlc.New(db),
|
|
db: db,
|
|
writeQueue: make(chan WriteTx, writeBufferSize),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
sqldb.wg.Add(1)
|
|
go sqldb.processWriteQueue()
|
|
|
|
return sqldb, nil
|
|
}
|
|
|
|
// processWriteQueue handles write transactions from the queue
|
|
func (d *SqliteDB) processWriteQueue() {
|
|
defer d.wg.Done()
|
|
for {
|
|
select {
|
|
case <-d.ctx.Done():
|
|
return
|
|
case writeTx := <-d.writeQueue:
|
|
err := writeTx.Query(d.queries, d.ctx, writeTx.Args)
|
|
writeTx.ErrChan <- err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *SqliteDB) Close() error {
|
|
d.cancel()
|
|
d.wg.Wait()
|
|
close(d.writeQueue)
|
|
return d.db.Close()
|
|
}
|
|
|
|
func (d *SqliteDB) CreateCustomer(c entities.Customer) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.CreateCustomerParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected CreateCustomerParams")
|
|
}
|
|
return q.CreateCustomer(ctx, params)
|
|
}
|
|
|
|
return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams())
|
|
}
|
|
|
|
func (d *SqliteDB) WriteNewUser(u entities.User) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.CreateUserParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected CreateUserParams")
|
|
}
|
|
return q.CreateUser(ctx, params)
|
|
}
|
|
// Use the wrapped function in enqueueWriteTx
|
|
|
|
renew := 0
|
|
if u.Renew {
|
|
renew = 1
|
|
}
|
|
// Map entities.User to CreateUserParams
|
|
params := sqlc.CreateUserParams{
|
|
ID: uuid.UUID(u.Id).String(),
|
|
Email: string(u.Email),
|
|
Renew: int64(renew),
|
|
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
|
|
CustomerID: uuid.UUID(u.CustomerId).String(),
|
|
Code: u.EncipheredPasscode.Code,
|
|
Mask: u.EncipheredPasscode.Mask,
|
|
AttributesPerKey: int64(u.Kp.AttrsPerKey),
|
|
NumberOfKeys: int64(u.Kp.NumbOfKeys),
|
|
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
|
|
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
|
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
|
|
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
|
Salt: u.CipherKeys.Salt,
|
|
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
|
|
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
|
|
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
|
|
CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true},
|
|
}
|
|
return d.enqueueWriteTx(queryFunc, params)
|
|
}
|
|
|
|
func (d *SqliteDB) UpdateUserNKode(u entities.User) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.UpdateUserParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected UpdateUserParams")
|
|
}
|
|
return q.UpdateUser(ctx, params)
|
|
}
|
|
// Use the wrapped function in enqueueWriteTx
|
|
renew := 0
|
|
if u.Renew {
|
|
renew = 1
|
|
}
|
|
params := sqlc.UpdateUserParams{
|
|
Email: string(u.Email),
|
|
Renew: int64(renew),
|
|
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
|
|
CustomerID: uuid.UUID(u.CustomerId).String(),
|
|
Code: u.EncipheredPasscode.Code,
|
|
Mask: u.EncipheredPasscode.Mask,
|
|
AttributesPerKey: int64(u.Kp.AttrsPerKey),
|
|
NumberOfKeys: int64(u.Kp.NumbOfKeys),
|
|
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
|
|
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
|
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
|
|
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
|
Salt: u.CipherKeys.Salt,
|
|
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
|
|
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
|
|
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
|
|
}
|
|
return d.enqueueWriteTx(queryFunc, params)
|
|
}
|
|
|
|
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.UpdateUserInterfaceParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams")
|
|
}
|
|
return q.UpdateUserInterface(ctx, params)
|
|
}
|
|
params := sqlc.UpdateUserInterfaceParams{
|
|
IdxInterface: security.IntArrToByteArr(ui.IdxInterface),
|
|
LastLogin: utils.TimeStamp(),
|
|
ID: uuid.UUID(id).String(),
|
|
}
|
|
|
|
return d.enqueueWriteTx(queryFunc, params)
|
|
}
|
|
|
|
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.UpdateUserRefreshTokenParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken")
|
|
}
|
|
return q.UpdateUserRefreshToken(ctx, params)
|
|
}
|
|
params := sqlc.UpdateUserRefreshTokenParams{
|
|
RefreshToken: sql.NullString{
|
|
String: refreshToken,
|
|
Valid: true,
|
|
},
|
|
ID: uuid.UUID(id).String(),
|
|
}
|
|
return d.enqueueWriteTx(queryFunc, params)
|
|
}
|
|
|
|
func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error {
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.RenewCustomerParams)
|
|
if !ok {
|
|
|
|
}
|
|
return q.RenewCustomer(ctx, params)
|
|
}
|
|
return d.enqueueWriteTx(queryFunc, renewParams)
|
|
}
|
|
|
|
func (d *SqliteDB) Renew(id models.CustomerId) error {
|
|
setXor, attrXor, err := d.renewCustomer(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
customerId := models.CustomerIdToString(id)
|
|
userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.RenewUserParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected RenewUserParams")
|
|
}
|
|
return q.RenewUser(ctx, params)
|
|
}
|
|
|
|
for _, row := range userRenewRows {
|
|
user := entities.User{
|
|
Id: models.UserIdFromString(row.ID),
|
|
CustomerId: models.CustomerId{},
|
|
Email: "",
|
|
EncipheredPasscode: models.EncipheredNKode{},
|
|
Kp: entities.KeypadDimension{
|
|
AttrsPerKey: int(row.AttributesPerKey),
|
|
NumbOfKeys: int(row.NumberOfKeys),
|
|
},
|
|
CipherKeys: entities.UserCipherKeys{
|
|
AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey),
|
|
SetKey: security.ByteArrToUint64Arr(row.SetKey),
|
|
},
|
|
Interface: entities.UserInterface{},
|
|
Renew: false,
|
|
}
|
|
|
|
if err = user.RenewKeys(setXor, attrXor); err != nil {
|
|
return err
|
|
}
|
|
params := sqlc.RenewUserParams{
|
|
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
|
|
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
|
|
Renew: 1,
|
|
ID: uuid.UUID(user.Id).String(),
|
|
}
|
|
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) {
|
|
customer, err := d.GetCustomer(id)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
setXor, attrXor, err := customer.RenewKeys()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.RenewCustomerParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected RenewCustomerParams")
|
|
}
|
|
return q.RenewCustomer(ctx, params)
|
|
}
|
|
params := sqlc.RenewCustomerParams{
|
|
AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals),
|
|
SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals),
|
|
ID: uuid.UUID(customer.Id).String(),
|
|
}
|
|
|
|
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return setXor, attrXor, nil
|
|
}
|
|
|
|
func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
|
|
if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil {
|
|
return err
|
|
}
|
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
|
params, ok := args.(sqlc.RefreshUserPasscodeParams)
|
|
if !ok {
|
|
return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams")
|
|
}
|
|
return q.RefreshUserPasscode(ctx, params)
|
|
}
|
|
params := sqlc.RefreshUserPasscodeParams{
|
|
Renew: 0,
|
|
Code: user.EncipheredPasscode.Code,
|
|
Mask: user.EncipheredPasscode.Mask,
|
|
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
|
|
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
|
|
PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey),
|
|
MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey),
|
|
Salt: user.CipherKeys.Salt,
|
|
ID: uuid.UUID(user.Id).String(),
|
|
}
|
|
return d.enqueueWriteTx(queryFunc, params)
|
|
}
|
|
|
|
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
|
|
customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &entities.Customer{
|
|
Id: id,
|
|
NKodePolicy: models.NKodePolicy{
|
|
MaxNkodeLen: int(customer.MaxNkodeLen),
|
|
MinNkodeLen: int(customer.MinNkodeLen),
|
|
DistinctSets: int(customer.DistinctSets),
|
|
DistinctAttributes: int(customer.DistinctAttributes),
|
|
LockOut: int(customer.LockOut),
|
|
Expiration: int(customer.Expiration),
|
|
},
|
|
Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues),
|
|
}, nil
|
|
}
|
|
|
|
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
|
|
userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{
|
|
Email: string(email),
|
|
CustomerID: uuid.UUID(customerId).String(),
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
kp := entities.KeypadDimension{
|
|
AttrsPerKey: int(userRow.AttributesPerKey),
|
|
NumbOfKeys: int(userRow.NumberOfKeys),
|
|
}
|
|
|
|
renew := false
|
|
if userRow.Renew == 1 {
|
|
renew = true
|
|
}
|
|
user := entities.User{
|
|
Id: models.UserIdFromString(userRow.ID),
|
|
CustomerId: customerId,
|
|
Email: email,
|
|
EncipheredPasscode: models.EncipheredNKode{
|
|
Code: userRow.Code,
|
|
Mask: userRow.Mask,
|
|
},
|
|
Kp: kp,
|
|
CipherKeys: entities.UserCipherKeys{
|
|
AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey),
|
|
SetKey: security.ByteArrToUint64Arr(userRow.SetKey),
|
|
PassKey: security.ByteArrToUint64Arr(userRow.PassKey),
|
|
MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey),
|
|
Salt: userRow.Salt,
|
|
MaxNKodeLen: int(userRow.MaxNkodeLen),
|
|
Kp: &kp,
|
|
},
|
|
Interface: entities.UserInterface{
|
|
IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface),
|
|
SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface),
|
|
Kp: &kp,
|
|
},
|
|
Renew: renew,
|
|
RefreshToken: userRow.RefreshToken.String,
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) {
|
|
ids, err := d.getRandomIds(kp.TotalAttrs())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.getSvgsById(ids)
|
|
}
|
|
|
|
func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) {
|
|
return d.getRandomIds(kp.TotalAttrs())
|
|
}
|
|
|
|
func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) {
|
|
return d.getSvgsById(idxs)
|
|
}
|
|
|
|
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
|
svgs := make([]string, len(ids))
|
|
for idx, id := range ids {
|
|
svg, err := d.queries.GetSvgId(d.ctx, int64(id))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
svgs[idx] = svg
|
|
}
|
|
return svgs, nil
|
|
}
|
|
|
|
func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error {
|
|
select {
|
|
case <-d.ctx.Done():
|
|
return errors.New("database is shutting down")
|
|
default:
|
|
}
|
|
|
|
errChan := make(chan error, 1)
|
|
writeTx := WriteTx{
|
|
Query: queryFunc,
|
|
Args: args,
|
|
ErrChan: errChan,
|
|
}
|
|
d.writeQueue <- writeTx
|
|
return <-errChan
|
|
}
|
|
|
|
func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
|
tx, err := d.db.Begin()
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil, config.ErrSqliteTx
|
|
}
|
|
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil, config.ErrSqliteTx
|
|
}
|
|
var tableLen int
|
|
if !rows.Next() {
|
|
return nil, config.ErrEmptySvgTable
|
|
}
|
|
|
|
if err = rows.Scan(&tableLen); err != nil {
|
|
log.Print(err)
|
|
return nil, config.ErrSqliteTx
|
|
}
|
|
perm, err := security.RandomPermutation(tableLen)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for idx := range perm {
|
|
perm[idx] += 1
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
log.Print(err)
|
|
return nil, config.ErrSqliteTx
|
|
}
|
|
|
|
return perm[:count], nil
|
|
}
|