package db import ( "context" "database/sql" "errors" "fmt" "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "go-nkode/config" "go-nkode/internal/entities" "go-nkode/internal/models" "go-nkode/internal/security" "go-nkode/internal/sqlc" "go-nkode/internal/utils" "log" "sync" ) const writeBufferSize = 100 type sqlcGeneric func(*sqlc.Queries, context.Context, any) error // WriteTx represents a write transaction type WriteTx struct { ErrChan chan error Query sqlcGeneric Args interface{} } // SqliteDB represents the SQLite database connection and write queue type SqliteDB struct { queries *sqlc.Queries db *sql.DB writeQueue chan WriteTx wg sync.WaitGroup ctx context.Context cancel context.CancelFunc } // NewSqliteDB initializes a new SqliteDB instance func NewSqliteDB(path string) (*SqliteDB, error) { if path == "" { return nil, errors.New("database path is required") } db, err := sql.Open("sqlite3", path) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } if err := db.Ping(); err != nil { return nil, fmt.Errorf("failed to connect to database: %w", err) } ctx, cancel := context.WithCancel(context.Background()) sqldb := &SqliteDB{ queries: sqlc.New(db), db: db, writeQueue: make(chan WriteTx, writeBufferSize), ctx: ctx, cancel: cancel, } sqldb.wg.Add(1) go sqldb.processWriteQueue() return sqldb, nil } // processWriteQueue handles write transactions from the queue func (d *SqliteDB) processWriteQueue() { defer d.wg.Done() for { select { case <-d.ctx.Done(): return case writeTx := <-d.writeQueue: err := writeTx.Query(d.queries, d.ctx, writeTx.Args) writeTx.ErrChan <- err } } } func (d *SqliteDB) Close() error { d.cancel() d.wg.Wait() close(d.writeQueue) return d.db.Close() } func (d *SqliteDB) CreateCustomer(c entities.Customer) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.CreateCustomerParams) if !ok { return fmt.Errorf("invalid argument type: expected CreateCustomerParams") } return q.CreateCustomer(ctx, params) } return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) } func (d *SqliteDB) WriteNewUser(u entities.User) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.CreateUserParams) if !ok { return fmt.Errorf("invalid argument type: expected CreateUserParams") } return q.CreateUser(ctx, params) } // Use the wrapped function in enqueueWriteTx renew := 0 if u.Renew { renew = 1 } // Map entities.User to CreateUserParams params := sqlc.CreateUserParams{ ID: uuid.UUID(u.Id).String(), Email: string(u.Email), Renew: int64(renew), RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""}, CustomerID: uuid.UUID(u.CustomerId).String(), Code: u.EncipheredPasscode.Code, Mask: u.EncipheredPasscode.Mask, AttributesPerKey: int64(u.Kp.AttrsPerKey), NumberOfKeys: int64(u.Kp.NumbOfKeys), AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey), PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey), MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), Salt: u.CipherKeys.Salt, MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen), IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true}, } return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserNKode(u entities.User) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserParams) if !ok { return fmt.Errorf("invalid argument type: expected UpdateUserParams") } return q.UpdateUser(ctx, params) } // Use the wrapped function in enqueueWriteTx renew := 0 if u.Renew { renew = 1 } params := sqlc.UpdateUserParams{ Email: string(u.Email), Renew: int64(renew), RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""}, CustomerID: uuid.UUID(u.CustomerId).String(), Code: u.EncipheredPasscode.Code, Mask: u.EncipheredPasscode.Mask, AttributesPerKey: int64(u.Kp.AttrsPerKey), NumberOfKeys: int64(u.Kp.NumbOfKeys), AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey), PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey), MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), Salt: u.CipherKeys.Salt, MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen), IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), } return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserInterfaceParams) if !ok { return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams") } return q.UpdateUserInterface(ctx, params) } params := sqlc.UpdateUserInterfaceParams{ IdxInterface: security.IntArrToByteArr(ui.IdxInterface), LastLogin: utils.TimeStamp(), ID: uuid.UUID(id).String(), } return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserRefreshTokenParams) if !ok { return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken") } return q.UpdateUserRefreshToken(ctx, params) } params := sqlc.UpdateUserRefreshTokenParams{ RefreshToken: sql.NullString{ String: refreshToken, Valid: true, }, ID: uuid.UUID(id).String(), } return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.RenewCustomerParams) if !ok { } return q.RenewCustomer(ctx, params) } return d.enqueueWriteTx(queryFunc, renewParams) } func (d *SqliteDB) Renew(id models.CustomerId) error { setXor, attrXor, err := d.renewCustomer(id) if err != nil { return err } customerId := models.CustomerIdToString(id) userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId) if err != nil { return err } queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.RenewUserParams) if !ok { return fmt.Errorf("invalid argument type: expected RenewUserParams") } return q.RenewUser(ctx, params) } for _, row := range userRenewRows { user := entities.User{ Id: models.UserIdFromString(row.ID), CustomerId: models.CustomerId{}, Email: "", EncipheredPasscode: models.EncipheredNKode{}, Kp: entities.KeypadDimension{ AttrsPerKey: int(row.AttributesPerKey), NumbOfKeys: int(row.NumberOfKeys), }, CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey), SetKey: security.ByteArrToUint64Arr(row.SetKey), }, Interface: entities.UserInterface{}, Renew: false, } if err = user.RenewKeys(setXor, attrXor); err != nil { return err } params := sqlc.RenewUserParams{ AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey), Renew: 1, ID: uuid.UUID(user.Id).String(), } if err = d.enqueueWriteTx(queryFunc, params); err != nil { return err } } return nil } func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) { customer, err := d.GetCustomer(id) if err != nil { return nil, nil, err } setXor, attrXor, err := customer.RenewKeys() if err != nil { return nil, nil, err } queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.RenewCustomerParams) if !ok { return fmt.Errorf("invalid argument type: expected RenewCustomerParams") } return q.RenewCustomer(ctx, params) } params := sqlc.RenewCustomerParams{ AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals), SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals), ID: uuid.UUID(customer.Id).String(), } if err = d.enqueueWriteTx(queryFunc, params); err != nil { return nil, nil, err } return setXor, attrXor, nil } func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil { return err } queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.RefreshUserPasscodeParams) if !ok { return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams") } return q.RefreshUserPasscode(ctx, params) } params := sqlc.RefreshUserPasscodeParams{ Renew: 0, Code: user.EncipheredPasscode.Code, Mask: user.EncipheredPasscode.Mask, AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey), PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey), MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), Salt: user.CipherKeys.Salt, ID: uuid.UUID(user.Id).String(), } return d.enqueueWriteTx(queryFunc, params) } func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String()) if err != nil { return nil, err } return &entities.Customer{ Id: id, NKodePolicy: models.NKodePolicy{ MaxNkodeLen: int(customer.MaxNkodeLen), MinNkodeLen: int(customer.MinNkodeLen), DistinctSets: int(customer.DistinctSets), DistinctAttributes: int(customer.DistinctAttributes), LockOut: int(customer.LockOut), Expiration: int(customer.Expiration), }, Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues), }, nil } func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{ Email: string(email), CustomerID: uuid.UUID(customerId).String(), }) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("failed to get user: %w", err) } kp := entities.KeypadDimension{ AttrsPerKey: int(userRow.AttributesPerKey), NumbOfKeys: int(userRow.NumberOfKeys), } renew := false if userRow.Renew == 1 { renew = true } user := entities.User{ Id: models.UserIdFromString(userRow.ID), CustomerId: customerId, Email: email, EncipheredPasscode: models.EncipheredNKode{ Code: userRow.Code, Mask: userRow.Mask, }, Kp: kp, CipherKeys: entities.UserCipherKeys{ AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey), SetKey: security.ByteArrToUint64Arr(userRow.SetKey), PassKey: security.ByteArrToUint64Arr(userRow.PassKey), MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey), Salt: userRow.Salt, MaxNKodeLen: int(userRow.MaxNkodeLen), Kp: &kp, }, Interface: entities.UserInterface{ IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface), SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface), Kp: &kp, }, Renew: renew, RefreshToken: userRow.RefreshToken.String, } return &user, nil } func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { ids, err := d.getRandomIds(kp.TotalAttrs()) if err != nil { return nil, err } return d.getSvgsById(ids) } func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { return d.getRandomIds(kp.TotalAttrs()) } func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) { return d.getSvgsById(idxs) } func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { svgs := make([]string, len(ids)) for idx, id := range ids { svg, err := d.queries.GetSvgId(d.ctx, int64(id)) if err != nil { return nil, err } svgs[idx] = svg } return svgs, nil } func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error { select { case <-d.ctx.Done(): return errors.New("database is shutting down") default: } errChan := make(chan error, 1) writeTx := WriteTx{ Query: queryFunc, Args: args, ErrChan: errChan, } d.writeQueue <- writeTx return <-errChan } func (d *SqliteDB) getRandomIds(count int) ([]int, error) { tx, err := d.db.Begin() if err != nil { log.Print(err) return nil, config.ErrSqliteTx } rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;") if err != nil { log.Print(err) return nil, config.ErrSqliteTx } var tableLen int if !rows.Next() { return nil, config.ErrEmptySvgTable } if err = rows.Scan(&tableLen); err != nil { log.Print(err) return nil, config.ErrSqliteTx } perm, err := security.RandomPermutation(tableLen) if err != nil { return nil, err } for idx := range perm { perm[idx] += 1 } if err = tx.Commit(); err != nil { log.Print(err) return nil, config.ErrSqliteTx } return perm[:count], nil }