refactor sqlite db to support sqlc
This commit is contained in:
@@ -43,15 +43,18 @@ func main() {
|
|||||||
if dbPath == "" {
|
if dbPath == "" {
|
||||||
log.Fatalf("SQLITE_DB=/path/to/nkode.db not set")
|
log.Fatalf("SQLITE_DB=/path/to/nkode.db not set")
|
||||||
}
|
}
|
||||||
db := db.NewSqliteDB(dbPath)
|
sqlitedb, err := db.NewSqliteDB(dbPath)
|
||||||
defer db.CloseDb()
|
if err != nil {
|
||||||
|
fmt.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
defer sqlitedb.Close()
|
||||||
|
|
||||||
sesClient := email.NewSESClient()
|
sesClient := email.NewSESClient()
|
||||||
emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient)
|
emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient)
|
||||||
emailQueue.Start()
|
emailQueue.Start()
|
||||||
defer emailQueue.Stop()
|
defer emailQueue.Stop()
|
||||||
|
|
||||||
nkodeApi := api.NewNKodeAPI(db, emailQueue)
|
nkodeApi := api.NewNKodeAPI(sqlitedb, emailQueue)
|
||||||
AddDefaultCustomer(nkodeApi)
|
AddDefaultCustomer(nkodeApi)
|
||||||
handler := api.NKodeHandler{Api: nkodeApi}
|
handler := api.NKodeHandler{Api: nkodeApi}
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = n.Db.WriteNewCustomer(*newCustomer)
|
err = n.Db.CreateCustomer(*newCustomer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ func TestNKodeAPI(t *testing.T) {
|
|||||||
|
|
||||||
dbFile := os.Getenv("TEST_DB")
|
dbFile := os.Getenv("TEST_DB")
|
||||||
|
|
||||||
db2 := db.NewSqliteDB(dbFile)
|
db2, err := db.NewSqliteDB(dbFile)
|
||||||
defer db2.CloseDb()
|
assert.NoError(t, err)
|
||||||
|
defer db2.Close()
|
||||||
testNKodeAPI(t, db2)
|
testNKodeAPI(t, db2)
|
||||||
|
|
||||||
//if _, err := os.Stat(dbFile); err == nil {
|
//if _, err := os.Stat(dbFile); err == nil {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
type CustomerUserRepository interface {
|
type CustomerUserRepository interface {
|
||||||
GetCustomer(models.CustomerId) (*entities.Customer, error)
|
GetCustomer(models.CustomerId) (*entities.Customer, error)
|
||||||
GetUser(models.UserEmail, models.CustomerId) (*entities.User, error)
|
GetUser(models.UserEmail, models.CustomerId) (*entities.User, error)
|
||||||
WriteNewCustomer(entities.Customer) error
|
CreateCustomer(entities.Customer) error
|
||||||
WriteNewUser(entities.User) error
|
WriteNewUser(entities.User) error
|
||||||
UpdateUserNKode(entities.User) error
|
UpdateUserNKode(entities.User) error
|
||||||
UpdateUserInterface(models.UserId, entities.UserInterface) error
|
UpdateUserInterface(models.UserId, entities.UserInterface) error
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *InMemoryDb) WriteNewCustomer(customer entities.Customer) error {
|
func (db *InMemoryDb) CreateCustomer(customer entities.Customer) error {
|
||||||
_, exists := db.Customers[customer.Id]
|
_, exists := db.Customers[customer.Id]
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
||||||
@@ -9,432 +11,387 @@ import (
|
|||||||
"go-nkode/internal/entities"
|
"go-nkode/internal/entities"
|
||||||
"go-nkode/internal/models"
|
"go-nkode/internal/models"
|
||||||
"go-nkode/internal/security"
|
"go-nkode/internal/security"
|
||||||
|
"go-nkode/internal/sqlc"
|
||||||
|
"go-nkode/internal/utils"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SqliteDB struct {
|
const writeBufferSize = 100
|
||||||
db *sql.DB
|
|
||||||
stop bool
|
|
||||||
writeQueue chan WriteTx
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
|
type sqlcGeneric func(*sqlc.Queries, context.Context, any) error
|
||||||
|
|
||||||
|
// WriteTx represents a write transaction
|
||||||
type WriteTx struct {
|
type WriteTx struct {
|
||||||
ErrChan chan error
|
ErrChan chan error
|
||||||
Query string
|
Query sqlcGeneric
|
||||||
Args []any
|
Args interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
// SqliteDB represents the SQLite database connection and write queue
|
||||||
writeBuffer = 1000
|
type SqliteDB struct {
|
||||||
)
|
queries *sqlc.Queries
|
||||||
|
db *sql.DB
|
||||||
|
writeQueue chan WriteTx
|
||||||
|
wg sync.WaitGroup
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSqliteDB initializes a new SqliteDB instance
|
||||||
|
func NewSqliteDB(path string) (*SqliteDB, error) {
|
||||||
|
if path == "" {
|
||||||
|
return nil, errors.New("database path is required")
|
||||||
|
}
|
||||||
|
|
||||||
func NewSqliteDB(path string) *SqliteDB {
|
|
||||||
db, err := sql.Open("sqlite3", path)
|
db, err := sql.Open("sqlite3", path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("database didn't open ", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
sqldb := SqliteDB{
|
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
sqldb := &SqliteDB{
|
||||||
|
queries: sqlc.New(db),
|
||||||
db: db,
|
db: db,
|
||||||
stop: false,
|
writeQueue: make(chan WriteTx, writeBufferSize),
|
||||||
writeQueue: make(chan WriteTx, writeBuffer),
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
sqldb.wg.Add(1)
|
||||||
for writeTx := range sqldb.writeQueue {
|
go sqldb.processWriteQueue()
|
||||||
writeTx.ErrChan <- sqldb.writeToDb(writeTx.Query, writeTx.Args)
|
|
||||||
sqldb.wg.Done()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return &sqldb
|
return sqldb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) CloseDb() {
|
// processWriteQueue handles write transactions from the queue
|
||||||
d.stop = true
|
func (d *SqliteDB) processWriteQueue() {
|
||||||
|
defer d.wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
return
|
||||||
|
case writeTx := <-d.writeQueue:
|
||||||
|
err := writeTx.Query(d.queries, d.ctx, writeTx.Args)
|
||||||
|
writeTx.ErrChan <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SqliteDB) Close() error {
|
||||||
|
d.cancel()
|
||||||
d.wg.Wait()
|
d.wg.Wait()
|
||||||
if err := d.db.Close(); err != nil {
|
close(d.writeQueue)
|
||||||
// If db.Close() returns an error, panic
|
return d.db.Close()
|
||||||
panic(fmt.Sprintf("Failed to close the database: %v", err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error {
|
func (d *SqliteDB) CreateCustomer(c entities.Customer) error {
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
INSERT INTO customer (
|
params, ok := args.(sqlc.CreateCustomerParams)
|
||||||
id
|
if !ok {
|
||||||
,max_nkode_len
|
return fmt.Errorf("invalid argument type: expected CreateCustomerParams")
|
||||||
,min_nkode_len
|
|
||||||
,distinct_sets
|
|
||||||
,distinct_attributes
|
|
||||||
,lock_out
|
|
||||||
,expiration
|
|
||||||
,attribute_values
|
|
||||||
,set_values
|
|
||||||
,last_renew
|
|
||||||
,created_at
|
|
||||||
)
|
|
||||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
|
||||||
`
|
|
||||||
args := []any{
|
|
||||||
uuid.UUID(c.Id), c.NKodePolicy.MaxNkodeLen, c.NKodePolicy.MinNkodeLen, c.NKodePolicy.DistinctSets,
|
|
||||||
c.NKodePolicy.DistinctAttributes, c.NKodePolicy.LockOut, c.NKodePolicy.Expiration,
|
|
||||||
c.Attributes.AttrBytes(), c.Attributes.SetBytes(), timeStamp(), timeStamp(),
|
|
||||||
}
|
}
|
||||||
return d.addWriteTx(query, args)
|
return q.CreateCustomer(ctx, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) WriteNewUser(u entities.User) error {
|
func (d *SqliteDB) WriteNewUser(u entities.User) error {
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
INSERT INTO user (
|
params, ok := args.(sqlc.CreateUserParams)
|
||||||
id
|
if !ok {
|
||||||
,email
|
return fmt.Errorf("invalid argument type: expected CreateUserParams")
|
||||||
,renew
|
}
|
||||||
,refresh_token
|
return q.CreateUser(ctx, params)
|
||||||
,customer_id
|
}
|
||||||
,code
|
// Use the wrapped function in enqueueWriteTx
|
||||||
,mask
|
|
||||||
,attributes_per_key
|
renew := 0
|
||||||
,number_of_keys
|
|
||||||
,alpha_key
|
|
||||||
,set_key
|
|
||||||
,pass_key
|
|
||||||
,mask_key
|
|
||||||
,salt
|
|
||||||
,max_nkode_len
|
|
||||||
,idx_interface
|
|
||||||
,svg_id_interface
|
|
||||||
,created_at
|
|
||||||
)
|
|
||||||
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
|
||||||
`
|
|
||||||
var renew int
|
|
||||||
if u.Renew {
|
if u.Renew {
|
||||||
renew = 1
|
renew = 1
|
||||||
} else {
|
|
||||||
renew = 0
|
|
||||||
}
|
}
|
||||||
|
// Map entities.User to CreateUserParams
|
||||||
args := []any{
|
params := sqlc.CreateUserParams{
|
||||||
uuid.UUID(u.Id), u.Email, renew, u.RefreshToken, uuid.UUID(u.CustomerId),
|
ID: uuid.UUID(u.Id).String(),
|
||||||
u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys,
|
Email: string(u.Email),
|
||||||
security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
Renew: int64(renew),
|
||||||
security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
|
||||||
u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface),
|
CustomerID: uuid.UUID(u.CustomerId).String(),
|
||||||
security.IntArrToByteArr(u.Interface.SvgId), timeStamp(),
|
Code: u.EncipheredPasscode.Code,
|
||||||
|
Mask: u.EncipheredPasscode.Mask,
|
||||||
|
AttributesPerKey: int64(u.Kp.AttrsPerKey),
|
||||||
|
NumberOfKeys: int64(u.Kp.NumbOfKeys),
|
||||||
|
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
|
||||||
|
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
||||||
|
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
|
||||||
|
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
||||||
|
Salt: u.CipherKeys.Salt,
|
||||||
|
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
|
||||||
|
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
|
||||||
|
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
|
||||||
|
CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true},
|
||||||
}
|
}
|
||||||
|
return d.enqueueWriteTx(queryFunc, params)
|
||||||
return d.addWriteTx(query, args)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) UpdateUserNKode(u entities.User) error {
|
func (d *SqliteDB) UpdateUserNKode(u entities.User) error {
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
UPDATE user
|
params, ok := args.(sqlc.UpdateUserParams)
|
||||||
SET renew = ?
|
if !ok {
|
||||||
,refresh_token = ?
|
return fmt.Errorf("invalid argument type: expected UpdateUserParams")
|
||||||
,code = ?
|
}
|
||||||
,mask = ?
|
return q.UpdateUser(ctx, params)
|
||||||
,attributes_per_key = ?
|
}
|
||||||
,number_of_keys = ?
|
// Use the wrapped function in enqueueWriteTx
|
||||||
,alpha_key = ?
|
renew := 0
|
||||||
,set_key = ?
|
|
||||||
,pass_key = ?
|
|
||||||
,mask_key = ?
|
|
||||||
,salt = ?
|
|
||||||
,max_nkode_len = ?
|
|
||||||
,idx_interface = ?
|
|
||||||
,svg_id_interface = ?
|
|
||||||
WHERE email = ? AND customer_id = ?
|
|
||||||
`
|
|
||||||
var renew int
|
|
||||||
if u.Renew {
|
if u.Renew {
|
||||||
renew = 1
|
renew = 1
|
||||||
} else {
|
|
||||||
renew = 0
|
|
||||||
}
|
}
|
||||||
args := []any{renew, u.RefreshToken, u.EncipheredPasscode.Code, u.EncipheredPasscode.Mask, u.Kp.AttrsPerKey, u.Kp.NumbOfKeys, security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(u.CipherKeys.SetKey), security.Uint64ArrToByteArr(u.CipherKeys.PassKey), security.Uint64ArrToByteArr(u.CipherKeys.MaskKey), u.CipherKeys.Salt, u.CipherKeys.MaxNKodeLen, security.IntArrToByteArr(u.Interface.IdxInterface), security.IntArrToByteArr(u.Interface.SvgId), string(u.Email), uuid.UUID(u.CustomerId)}
|
params := sqlc.UpdateUserParams{
|
||||||
|
Email: string(u.Email),
|
||||||
return d.addWriteTx(query, args)
|
Renew: int64(renew),
|
||||||
|
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
|
||||||
|
CustomerID: uuid.UUID(u.CustomerId).String(),
|
||||||
|
Code: u.EncipheredPasscode.Code,
|
||||||
|
Mask: u.EncipheredPasscode.Mask,
|
||||||
|
AttributesPerKey: int64(u.Kp.AttrsPerKey),
|
||||||
|
NumberOfKeys: int64(u.Kp.NumbOfKeys),
|
||||||
|
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
|
||||||
|
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
|
||||||
|
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
|
||||||
|
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
|
||||||
|
Salt: u.CipherKeys.Salt,
|
||||||
|
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
|
||||||
|
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
|
||||||
|
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
|
||||||
|
}
|
||||||
|
return d.enqueueWriteTx(queryFunc, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error {
|
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error {
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ?
|
params, ok := args.(sqlc.UpdateUserInterfaceParams)
|
||||||
`
|
if !ok {
|
||||||
args := []any{security.IntArrToByteArr(ui.IdxInterface), timeStamp(), uuid.UUID(id).String()}
|
return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams")
|
||||||
|
}
|
||||||
|
return q.UpdateUserInterface(ctx, params)
|
||||||
|
}
|
||||||
|
params := sqlc.UpdateUserInterfaceParams{
|
||||||
|
IdxInterface: security.IntArrToByteArr(ui.IdxInterface),
|
||||||
|
LastLogin: utils.TimeStamp(),
|
||||||
|
ID: uuid.UUID(id).String(),
|
||||||
|
}
|
||||||
|
|
||||||
return d.addWriteTx(query, args)
|
return d.enqueueWriteTx(queryFunc, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
|
func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error {
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
UPDATE user SET refresh_token = ? WHERE id = ?
|
params, ok := args.(sqlc.UpdateUserRefreshTokenParams)
|
||||||
`
|
if !ok {
|
||||||
args := []any{refreshToken, uuid.UUID(id).String()}
|
return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken")
|
||||||
|
}
|
||||||
|
return q.UpdateUserRefreshToken(ctx, params)
|
||||||
|
}
|
||||||
|
params := sqlc.UpdateUserRefreshTokenParams{
|
||||||
|
RefreshToken: sql.NullString{
|
||||||
|
String: refreshToken,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
ID: uuid.UUID(id).String(),
|
||||||
|
}
|
||||||
|
return d.enqueueWriteTx(queryFunc, params)
|
||||||
|
}
|
||||||
|
|
||||||
return d.addWriteTx(query, args)
|
func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error {
|
||||||
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
|
params, ok := args.(sqlc.RenewCustomerParams)
|
||||||
|
if !ok {
|
||||||
|
|
||||||
|
}
|
||||||
|
return q.RenewCustomer(ctx, params)
|
||||||
|
}
|
||||||
|
return d.enqueueWriteTx(queryFunc, renewParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) Renew(id models.CustomerId) error {
|
func (d *SqliteDB) Renew(id models.CustomerId) error {
|
||||||
// TODO: How long does a renew take?
|
setXor, attrXor, err := d.renewCustomer(id)
|
||||||
customer, err := d.GetCustomer(id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
setXor, attrXor, err := customer.RenewKeys()
|
customerId := models.CustomerIdToString(id)
|
||||||
|
userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
renewArgs := []any{security.Uint64ArrToByteArr(customer.Attributes.AttrVals), security.Uint64ArrToByteArr(customer.Attributes.SetVals), uuid.UUID(customer.Id).String()}
|
|
||||||
// TODO: replace with tx
|
|
||||||
renewQuery := `
|
|
||||||
UPDATE customer
|
|
||||||
SET attribute_values = ?, set_values = ?
|
|
||||||
WHERE id = ?;
|
|
||||||
`
|
|
||||||
|
|
||||||
userQuery := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
SELECT
|
params, ok := args.(sqlc.RenewUserParams)
|
||||||
id
|
if !ok {
|
||||||
,alpha_key
|
return fmt.Errorf("invalid argument type: expected RenewUserParams")
|
||||||
,set_key
|
|
||||||
,attributes_per_key
|
|
||||||
,number_of_keys
|
|
||||||
FROM user
|
|
||||||
WHERE customer_id = ?
|
|
||||||
`
|
|
||||||
tx, err := d.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
rows, err := tx.Query(userQuery, uuid.UUID(id).String())
|
return q.RenewUser(ctx, params)
|
||||||
for rows.Next() {
|
|
||||||
var userId string
|
|
||||||
var alphaBytes []byte
|
|
||||||
var setBytes []byte
|
|
||||||
var attrsPerKey int
|
|
||||||
var numbOfKeys int
|
|
||||||
err = rows.Scan(&userId, &alphaBytes, &setBytes, &attrsPerKey, &numbOfKeys)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, row := range userRenewRows {
|
||||||
user := entities.User{
|
user := entities.User{
|
||||||
Id: models.UserId{},
|
Id: models.UserIdFromString(row.ID),
|
||||||
CustomerId: models.CustomerId{},
|
CustomerId: models.CustomerId{},
|
||||||
Email: "",
|
Email: "",
|
||||||
EncipheredPasscode: models.EncipheredNKode{},
|
EncipheredPasscode: models.EncipheredNKode{},
|
||||||
Kp: entities.KeypadDimension{
|
Kp: entities.KeypadDimension{
|
||||||
AttrsPerKey: attrsPerKey,
|
AttrsPerKey: int(row.AttributesPerKey),
|
||||||
NumbOfKeys: numbOfKeys,
|
NumbOfKeys: int(row.NumberOfKeys),
|
||||||
},
|
},
|
||||||
CipherKeys: entities.UserCipherKeys{
|
CipherKeys: entities.UserCipherKeys{
|
||||||
AlphaKey: security.ByteArrToUint64Arr(alphaBytes),
|
AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey),
|
||||||
SetKey: security.ByteArrToUint64Arr(setBytes),
|
SetKey: security.ByteArrToUint64Arr(row.SetKey),
|
||||||
},
|
},
|
||||||
Interface: entities.UserInterface{},
|
Interface: entities.UserInterface{},
|
||||||
Renew: false,
|
Renew: false,
|
||||||
}
|
}
|
||||||
err = user.RenewKeys(setXor, attrXor)
|
|
||||||
if err != nil {
|
if err = user.RenewKeys(setXor, attrXor); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
renewQuery += `
|
params := sqlc.RenewUserParams{
|
||||||
UPDATE user
|
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
|
||||||
SET alpha_key = ?, set_key = ?, renew = ?
|
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
|
||||||
WHERE id = ?;
|
Renew: 1,
|
||||||
`
|
ID: uuid.UUID(user.Id).String(),
|
||||||
renewArgs = append(renewArgs, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), 1, userId)
|
|
||||||
}
|
}
|
||||||
renewQuery += `
|
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
|
||||||
`
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return d.addWriteTx(renewQuery, renewArgs)
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) {
|
||||||
|
customer, err := d.GetCustomer(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
setXor, attrXor, err := customer.RenewKeys()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
|
params, ok := args.(sqlc.RenewCustomerParams)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid argument type: expected RenewCustomerParams")
|
||||||
|
}
|
||||||
|
return q.RenewCustomer(ctx, params)
|
||||||
|
}
|
||||||
|
params := sqlc.RenewCustomerParams{
|
||||||
|
AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals),
|
||||||
|
SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals),
|
||||||
|
ID: uuid.UUID(customer.Id).String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = d.enqueueWriteTx(queryFunc, params); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return setXor, attrXor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
|
func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
|
||||||
err := user.RefreshPasscode(passcodeIdx, customerAttr)
|
if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
query := `
|
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
|
||||||
UPDATE user
|
params, ok := args.(sqlc.RefreshUserPasscodeParams)
|
||||||
SET
|
if !ok {
|
||||||
renew = ?
|
return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams")
|
||||||
,code = ?
|
}
|
||||||
,mask = ?
|
return q.RefreshUserPasscode(ctx, params)
|
||||||
,alpha_key = ?
|
}
|
||||||
,set_key = ?
|
params := sqlc.RefreshUserPasscodeParams{
|
||||||
,pass_key = ?
|
Renew: 0,
|
||||||
,mask_key = ?
|
Code: user.EncipheredPasscode.Code,
|
||||||
,salt = ?
|
Mask: user.EncipheredPasscode.Mask,
|
||||||
WHERE id = ?;
|
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
|
||||||
`
|
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
|
||||||
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
|
PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey),
|
||||||
return d.addWriteTx(query, args)
|
MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey),
|
||||||
|
Salt: user.CipherKeys.Salt,
|
||||||
|
ID: uuid.UUID(user.Id).String(),
|
||||||
|
}
|
||||||
|
return d.enqueueWriteTx(queryFunc, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
|
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
|
||||||
tx, err := d.db.Begin()
|
customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
err = tx.Rollback()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(fmt.Sprintf("Write new user won't roll back %+v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
selectCustomer := `
|
|
||||||
SELECT
|
|
||||||
max_nkode_len
|
|
||||||
,min_nkode_len
|
|
||||||
,distinct_sets
|
|
||||||
,distinct_attributes
|
|
||||||
,lock_out
|
|
||||||
,expiration
|
|
||||||
,attribute_values
|
|
||||||
,set_values
|
|
||||||
FROM customer
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
rows, err := tx.Query(selectCustomer, uuid.UUID(id))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !rows.Next() {
|
return &entities.Customer{
|
||||||
log.Printf("no new row for customer %s with err %s", id, rows.Err())
|
|
||||||
return nil, config.ErrCustomerDne
|
|
||||||
}
|
|
||||||
|
|
||||||
var maxNKodeLen int
|
|
||||||
var minNKodeLen int
|
|
||||||
var distinctSets int
|
|
||||||
var distinctAttributes int
|
|
||||||
var lockOut int
|
|
||||||
var expiration int
|
|
||||||
var attributeValues []byte
|
|
||||||
var setValues []byte
|
|
||||||
err = rows.Scan(&maxNKodeLen, &minNKodeLen, &distinctSets, &distinctAttributes, &lockOut, &expiration, &attributeValues, &setValues)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
customer := entities.Customer{
|
|
||||||
Id: id,
|
Id: id,
|
||||||
NKodePolicy: models.NKodePolicy{
|
NKodePolicy: models.NKodePolicy{
|
||||||
MaxNkodeLen: maxNKodeLen,
|
MaxNkodeLen: int(customer.MaxNkodeLen),
|
||||||
MinNkodeLen: minNKodeLen,
|
MinNkodeLen: int(customer.MinNkodeLen),
|
||||||
DistinctSets: distinctSets,
|
DistinctSets: int(customer.DistinctSets),
|
||||||
DistinctAttributes: distinctAttributes,
|
DistinctAttributes: int(customer.DistinctAttributes),
|
||||||
LockOut: lockOut,
|
LockOut: int(customer.LockOut),
|
||||||
Expiration: expiration,
|
Expiration: int(customer.Expiration),
|
||||||
},
|
},
|
||||||
Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues),
|
Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues),
|
||||||
}
|
}, nil
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &customer, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
|
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
|
||||||
tx, err := d.db.Begin()
|
userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{
|
||||||
|
Email: string(email),
|
||||||
|
CustomerID: uuid.UUID(customerId).String(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
}
|
|
||||||
userSelect := `
|
|
||||||
SELECT
|
|
||||||
id
|
|
||||||
,renew
|
|
||||||
,refresh_token
|
|
||||||
,code
|
|
||||||
,mask
|
|
||||||
,attributes_per_key
|
|
||||||
,number_of_keys
|
|
||||||
,alpha_key
|
|
||||||
,set_key
|
|
||||||
,pass_key
|
|
||||||
,mask_key
|
|
||||||
,salt
|
|
||||||
,max_nkode_len
|
|
||||||
,idx_interface
|
|
||||||
,svg_id_interface
|
|
||||||
FROM user
|
|
||||||
WHERE user.email = ? AND user.customer_id = ?
|
|
||||||
`
|
|
||||||
rows, err := tx.Query(userSelect, string(email), uuid.UUID(customerId).String())
|
|
||||||
if !rows.Next() {
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
var (
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||||
id string
|
|
||||||
renewVal int
|
|
||||||
refreshToken string
|
|
||||||
code string
|
|
||||||
mask string
|
|
||||||
attrsPerKey int
|
|
||||||
numbOfKeys int
|
|
||||||
alphaKey []byte
|
|
||||||
setKey []byte
|
|
||||||
passKey []byte
|
|
||||||
maskKey []byte
|
|
||||||
salt []byte
|
|
||||||
maxNKodeLen int
|
|
||||||
idxInterface []byte
|
|
||||||
svgIdInterface []byte
|
|
||||||
)
|
|
||||||
err = rows.Scan(&id, &renewVal, &refreshToken, &code, &mask, &attrsPerKey, &numbOfKeys, &alphaKey, &setKey, &passKey, &maskKey, &salt, &maxNKodeLen, &idxInterface, &svgIdInterface)
|
|
||||||
|
|
||||||
userId, err := uuid.Parse(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
var renew bool
|
|
||||||
if renewVal == 0 {
|
kp := entities.KeypadDimension{
|
||||||
renew = false
|
AttrsPerKey: int(userRow.AttributesPerKey),
|
||||||
} else {
|
NumbOfKeys: int(userRow.NumberOfKeys),
|
||||||
|
}
|
||||||
|
|
||||||
|
renew := false
|
||||||
|
if userRow.Renew == 1 {
|
||||||
renew = true
|
renew = true
|
||||||
}
|
}
|
||||||
|
|
||||||
user := entities.User{
|
user := entities.User{
|
||||||
Id: models.UserId(userId),
|
Id: models.UserIdFromString(userRow.ID),
|
||||||
CustomerId: customerId,
|
CustomerId: customerId,
|
||||||
Email: email,
|
Email: email,
|
||||||
EncipheredPasscode: models.EncipheredNKode{
|
EncipheredPasscode: models.EncipheredNKode{
|
||||||
Code: code,
|
Code: userRow.Code,
|
||||||
Mask: mask,
|
Mask: userRow.Mask,
|
||||||
},
|
|
||||||
Kp: entities.KeypadDimension{
|
|
||||||
AttrsPerKey: attrsPerKey,
|
|
||||||
NumbOfKeys: numbOfKeys,
|
|
||||||
},
|
},
|
||||||
|
Kp: kp,
|
||||||
CipherKeys: entities.UserCipherKeys{
|
CipherKeys: entities.UserCipherKeys{
|
||||||
AlphaKey: security.ByteArrToUint64Arr(alphaKey),
|
AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey),
|
||||||
SetKey: security.ByteArrToUint64Arr(setKey),
|
SetKey: security.ByteArrToUint64Arr(userRow.SetKey),
|
||||||
PassKey: security.ByteArrToUint64Arr(passKey),
|
PassKey: security.ByteArrToUint64Arr(userRow.PassKey),
|
||||||
MaskKey: security.ByteArrToUint64Arr(maskKey),
|
MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey),
|
||||||
Salt: salt,
|
Salt: userRow.Salt,
|
||||||
MaxNKodeLen: maxNKodeLen,
|
MaxNKodeLen: int(userRow.MaxNkodeLen),
|
||||||
Kp: nil,
|
Kp: &kp,
|
||||||
},
|
},
|
||||||
Interface: entities.UserInterface{
|
Interface: entities.UserInterface{
|
||||||
IdxInterface: security.ByteArrToIntArr(idxInterface),
|
IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface),
|
||||||
SvgId: security.ByteArrToIntArr(svgIdInterface),
|
SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface),
|
||||||
Kp: nil,
|
Kp: &kp,
|
||||||
},
|
},
|
||||||
Renew: renew,
|
Renew: renew,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: userRow.RefreshToken.String,
|
||||||
}
|
|
||||||
user.Interface.Kp = &user.Kp
|
|
||||||
user.CipherKeys.Kp = &user.Kp
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
@@ -456,68 +413,30 @@ func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) {
|
||||||
tx, err := d.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
selectId := `
|
|
||||||
SELECT svg
|
|
||||||
FROM svg_icon
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
svgs := make([]string, len(ids))
|
svgs := make([]string, len(ids))
|
||||||
for idx, id := range ids {
|
for idx, id := range ids {
|
||||||
rows, err := tx.Query(selectId, id)
|
svg, err := d.queries.GetSvgId(d.ctx, int64(id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !rows.Next() {
|
svgs[idx] = svg
|
||||||
log.Printf("id not found: %d", id)
|
|
||||||
return nil, config.ErrSvgDne
|
|
||||||
}
|
|
||||||
if err = rows.Scan(&svgs[idx]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return svgs, nil
|
return svgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SqliteDB) writeToDb(query string, args []any) error {
|
func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error {
|
||||||
tx, err := d.db.Begin()
|
select {
|
||||||
if err != nil {
|
case <-d.ctx.Done():
|
||||||
return err
|
return errors.New("database is shutting down")
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
err = tx.Rollback()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("fatal error: write won't roll back %+v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if _, err = tx.Exec(query, args...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *SqliteDB) addWriteTx(query string, args []any) error {
|
errChan := make(chan error, 1)
|
||||||
if d.stop {
|
|
||||||
return config.ErrStoppingDatabase
|
|
||||||
}
|
|
||||||
errChan := make(chan error)
|
|
||||||
writeTx := WriteTx{
|
writeTx := WriteTx{
|
||||||
Query: query,
|
Query: queryFunc,
|
||||||
Args: args,
|
Args: args,
|
||||||
ErrChan: errChan,
|
ErrChan: errChan,
|
||||||
}
|
}
|
||||||
d.wg.Add(1)
|
|
||||||
d.writeQueue <- writeTx
|
d.writeQueue <- writeTx
|
||||||
return <-errChan
|
return <-errChan
|
||||||
}
|
}
|
||||||
@@ -559,7 +478,3 @@ func (d *SqliteDB) getRandomIds(count int) ([]int, error) {
|
|||||||
|
|
||||||
return perm[:count], nil
|
return perm[:count], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func timeStamp() string {
|
|
||||||
return time.Now().Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ import (
|
|||||||
func TestNewSqliteDB(t *testing.T) {
|
func TestNewSqliteDB(t *testing.T) {
|
||||||
dbFile := os.Getenv("TEST_DB")
|
dbFile := os.Getenv("TEST_DB")
|
||||||
// sql_driver.MakeTables(dbFile)
|
// sql_driver.MakeTables(dbFile)
|
||||||
db := NewSqliteDB(dbFile)
|
db, err := NewSqliteDB(dbFile)
|
||||||
defer db.CloseDb()
|
assert.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
testSignupLoginRenew(t, db)
|
testSignupLoginRenew(t, db)
|
||||||
testSqliteDBRandomSvgInterface(t, db)
|
testSqliteDBRandomSvgInterface(t, db)
|
||||||
@@ -28,7 +29,7 @@ func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) {
|
|||||||
nkodePolicy := models.NewDefaultNKodePolicy()
|
nkodePolicy := models.NewDefaultNKodePolicy()
|
||||||
customerOrig, err := entities.NewCustomer(nkodePolicy)
|
customerOrig, err := entities.NewCustomer(nkodePolicy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = db.WriteNewCustomer(*customerOrig)
|
err = db.CreateCustomer(*customerOrig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
customer, err := db.GetCustomer(customerOrig.Id)
|
customer, err := db.GetCustomer(customerOrig.Id)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
queue.AddEmail(email)
|
queue.AddEmail(email)
|
||||||
}
|
}
|
||||||
// CloseDb the queue after all emails are processed
|
// Close the queue after all emails are processed
|
||||||
queue.Stop()
|
queue.Stop()
|
||||||
|
|
||||||
assert.Equal(t, queue.FailedSendCount, 0)
|
assert.Equal(t, queue.FailedSendCount, 0)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"go-nkode/config"
|
"go-nkode/config"
|
||||||
"go-nkode/internal/models"
|
"go-nkode/internal/models"
|
||||||
"go-nkode/internal/security"
|
"go-nkode/internal/security"
|
||||||
|
"go-nkode/internal/sqlc"
|
||||||
"go-nkode/internal/utils"
|
"go-nkode/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,3 +84,19 @@ func (c *Customer) RenewKeys() ([]uint64, []uint64, error) {
|
|||||||
}
|
}
|
||||||
return setXor, attrsXor, nil
|
return setXor, attrsXor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Customer) ToSqlcCreateCustomerParams() sqlc.CreateCustomerParams {
|
||||||
|
return sqlc.CreateCustomerParams{
|
||||||
|
ID: uuid.UUID(c.Id).String(),
|
||||||
|
MaxNkodeLen: int64(c.NKodePolicy.MaxNkodeLen),
|
||||||
|
MinNkodeLen: int64(c.NKodePolicy.MinNkodeLen),
|
||||||
|
DistinctSets: int64(c.NKodePolicy.DistinctSets),
|
||||||
|
DistinctAttributes: int64(c.NKodePolicy.DistinctAttributes),
|
||||||
|
LockOut: int64(c.NKodePolicy.LockOut),
|
||||||
|
Expiration: int64(c.NKodePolicy.Expiration),
|
||||||
|
AttributeValues: c.Attributes.AttrBytes(),
|
||||||
|
SetValues: c.Attributes.SetBytes(),
|
||||||
|
LastRenew: utils.TimeStamp(),
|
||||||
|
CreatedAt: utils.TimeStamp(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -37,11 +37,13 @@ func (u *User) RenewKeys(setXor []uint64, attrXor []uint64) error {
|
|||||||
|
|
||||||
func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error {
|
func (u *User) RefreshPasscode(passcodeAttrIdx []int, customerAttributes CustomerAttributes) error {
|
||||||
setVals, err := customerAttributes.SetValsForKp(u.Kp)
|
setVals, err := customerAttributes.SetValsForKp(u.Kp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen)
|
newKeys, err := NewUserCipherKeys(&u.Kp, setVals, u.CipherKeys.MaxNKodeLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
encipheredPasscode, err := newKeys.EncipherNKode(passcodeAttrIdx, customerAttributes)
|
encipheredPasscode, err := newKeys.EncipherNKode(passcodeAttrIdx, customerAttributes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -99,10 +100,17 @@ func CustomerIdToString(customerId CustomerId) string {
|
|||||||
type SessionId uuid.UUID
|
type SessionId uuid.UUID
|
||||||
type UserId uuid.UUID
|
type UserId uuid.UUID
|
||||||
|
|
||||||
|
func UserIdFromString(userId string) UserId {
|
||||||
|
id, err := uuid.Parse(userId)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Errorf("unable to parse user id %+v", err)
|
||||||
|
}
|
||||||
|
return UserId(id)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SessionId) String() string {
|
func (s *SessionId) String() string {
|
||||||
id := uuid.UUID(*s)
|
id := uuid.UUID(*s)
|
||||||
return id.String()
|
return id.String()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserEmail string
|
type UserEmail string
|
||||||
|
|||||||
7
internal/utils/timestamp.go
Normal file
7
internal/utils/timestamp.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
func TimeStamp() string {
|
||||||
|
return time.Now().Format(time.RFC3339)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user