migrate nkode-core

This commit is contained in:
2025-01-21 13:18:46 -06:00
parent 4dbb4c48c8
commit 1f10af0081
38 changed files with 4167 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
package repository
import (
"git.infra.nkode.tech/dkelly/nkode-core/entities"
)
type CustomerUserRepository interface {
GetCustomer(entities.CustomerId) (*entities.Customer, error)
GetUser(entities.UserEmail, entities.CustomerId) (*entities.User, error)
CreateCustomer(entities.Customer) error
WriteNewUser(entities.User) error
UpdateUserNKode(entities.User) error
UpdateUserInterface(entities.UserId, entities.UserInterface) error
UpdateUserRefreshToken(entities.UserId, string) error
Renew(entities.CustomerId) error
RefreshUserPasscode(entities.User, []int, entities.CustomerAttributes) error
RandomSvgInterface(entities.KeypadDimension) ([]string, error)
RandomSvgIdxInterface(entities.KeypadDimension) (entities.SvgIdInterface, error)
GetSvgStringInterface(entities.SvgIdInterface) ([]string, error)
}

View File

@@ -0,0 +1,401 @@
package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"git.infra.nkode.tech/dkelly/nkode-core/config"
"git.infra.nkode.tech/dkelly/nkode-core/entities"
"git.infra.nkode.tech/dkelly/nkode-core/security"
"git.infra.nkode.tech/dkelly/nkode-core/sqlc"
"git.infra.nkode.tech/dkelly/nkode-core/utils"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"log"
)
type SqliteRepository struct {
Queue *sqlc.Queue
ctx context.Context
}
func NewSqliteRepository(queue *sqlc.Queue, ctx context.Context) SqliteRepository {
return SqliteRepository{
Queue: queue,
ctx: ctx,
}
}
func (d *SqliteRepository) CreateCustomer(c entities.Customer) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.CreateCustomerParams)
if !ok {
return fmt.Errorf("invalid argument type: expected CreateCustomerParams")
}
return q.CreateCustomer(ctx, params)
}
return d.Queue.EnqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams())
}
func (d *SqliteRepository) WriteNewUser(u entities.User) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.CreateUserParams)
if !ok {
return fmt.Errorf("invalid argument type: expected CreateUserParams")
}
return q.CreateUser(ctx, params)
}
// Use the wrapped function in EnqueueWriteTx
renew := 0
if u.Renew {
renew = 1
}
// Map entities.User to CreateUserParams
params := sqlc.CreateUserParams{
ID: uuid.UUID(u.Id).String(),
Email: string(u.Email),
Renew: int64(renew),
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
CustomerID: uuid.UUID(u.CustomerId).String(),
Code: u.EncipheredPasscode.Code,
Mask: u.EncipheredPasscode.Mask,
AttributesPerKey: int64(u.Kp.AttrsPerKey),
NumberOfKeys: int64(u.Kp.NumbOfKeys),
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
Salt: u.CipherKeys.Salt,
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true},
}
return d.Queue.EnqueueWriteTx(queryFunc, params)
}
func (d *SqliteRepository) UpdateUserNKode(u entities.User) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.UpdateUserParams)
if !ok {
return fmt.Errorf("invalid argument type: expected UpdateUserParams")
}
return q.UpdateUser(ctx, params)
}
// Use the wrapped function in EnqueueWriteTx
renew := 0
if u.Renew {
renew = 1
}
params := sqlc.UpdateUserParams{
Email: string(u.Email),
Renew: int64(renew),
RefreshToken: sql.NullString{String: u.RefreshToken, Valid: u.RefreshToken != ""},
CustomerID: uuid.UUID(u.CustomerId).String(),
Code: u.EncipheredPasscode.Code,
Mask: u.EncipheredPasscode.Mask,
AttributesPerKey: int64(u.Kp.AttrsPerKey),
NumberOfKeys: int64(u.Kp.NumbOfKeys),
AlphaKey: security.Uint64ArrToByteArr(u.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(u.CipherKeys.SetKey),
PassKey: security.Uint64ArrToByteArr(u.CipherKeys.PassKey),
MaskKey: security.Uint64ArrToByteArr(u.CipherKeys.MaskKey),
Salt: u.CipherKeys.Salt,
MaxNkodeLen: int64(u.CipherKeys.MaxNKodeLen),
IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface),
SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId),
}
return d.Queue.EnqueueWriteTx(queryFunc, params)
}
func (d *SqliteRepository) UpdateUserInterface(id entities.UserId, ui entities.UserInterface) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.UpdateUserInterfaceParams)
if !ok {
return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams")
}
return q.UpdateUserInterface(ctx, params)
}
params := sqlc.UpdateUserInterfaceParams{
IdxInterface: security.IntArrToByteArr(ui.IdxInterface),
LastLogin: utils.TimeStamp(),
ID: uuid.UUID(id).String(),
}
return d.Queue.EnqueueWriteTx(queryFunc, params)
}
func (d *SqliteRepository) UpdateUserRefreshToken(id entities.UserId, refreshToken string) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.UpdateUserRefreshTokenParams)
if !ok {
return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken")
}
return q.UpdateUserRefreshToken(ctx, params)
}
params := sqlc.UpdateUserRefreshTokenParams{
RefreshToken: sql.NullString{
String: refreshToken,
Valid: true,
},
ID: uuid.UUID(id).String(),
}
return d.Queue.EnqueueWriteTx(queryFunc, params)
}
func (d *SqliteRepository) RenewCustomer(renewParams sqlc.RenewCustomerParams) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RenewCustomerParams)
if !ok {
}
return q.RenewCustomer(ctx, params)
}
return d.Queue.EnqueueWriteTx(queryFunc, renewParams)
}
func (d *SqliteRepository) Renew(id entities.CustomerId) error {
setXor, attrXor, err := d.renewCustomer(id)
if err != nil {
return err
}
customerId := entities.CustomerIdToString(id)
userRenewRows, err := d.Queue.Queries.GetUserRenew(d.ctx, customerId)
if err != nil {
return err
}
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RenewUserParams)
if !ok {
return fmt.Errorf("invalid argument type: expected RenewUserParams")
}
return q.RenewUser(ctx, params)
}
for _, row := range userRenewRows {
user := entities.User{
Id: entities.UserIdFromString(row.ID),
CustomerId: entities.CustomerId{},
Email: "",
EncipheredPasscode: entities.EncipheredNKode{},
Kp: entities.KeypadDimension{
AttrsPerKey: int(row.AttributesPerKey),
NumbOfKeys: int(row.NumberOfKeys),
},
CipherKeys: entities.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(row.AlphaKey),
SetKey: security.ByteArrToUint64Arr(row.SetKey),
},
Interface: entities.UserInterface{},
Renew: false,
}
if err = user.RenewKeys(setXor, attrXor); err != nil {
return err
}
params := sqlc.RenewUserParams{
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
Renew: 1,
ID: uuid.UUID(user.Id).String(),
}
if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil {
return err
}
}
return nil
}
func (d *SqliteRepository) renewCustomer(id entities.CustomerId) ([]uint64, []uint64, error) {
customer, err := d.GetCustomer(id)
if err != nil {
return nil, nil, err
}
setXor, attrXor, err := customer.RenewKeys()
if err != nil {
return nil, nil, err
}
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RenewCustomerParams)
if !ok {
return fmt.Errorf("invalid argument type: expected RenewCustomerParams")
}
return q.RenewCustomer(ctx, params)
}
params := sqlc.RenewCustomerParams{
AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals),
SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals),
ID: uuid.UUID(customer.Id).String(),
}
if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil {
return nil, nil, err
}
return setXor, attrXor, nil
}
func (d *SqliteRepository) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil {
return err
}
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(sqlc.RefreshUserPasscodeParams)
if !ok {
return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams")
}
return q.RefreshUserPasscode(ctx, params)
}
params := sqlc.RefreshUserPasscodeParams{
Renew: 0,
Code: user.EncipheredPasscode.Code,
Mask: user.EncipheredPasscode.Mask,
AlphaKey: security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey),
SetKey: security.Uint64ArrToByteArr(user.CipherKeys.SetKey),
PassKey: security.Uint64ArrToByteArr(user.CipherKeys.PassKey),
MaskKey: security.Uint64ArrToByteArr(user.CipherKeys.MaskKey),
Salt: user.CipherKeys.Salt,
ID: uuid.UUID(user.Id).String(),
}
return d.Queue.EnqueueWriteTx(queryFunc, params)
}
func (d *SqliteRepository) GetCustomer(id entities.CustomerId) (*entities.Customer, error) {
customer, err := d.Queue.Queries.GetCustomer(d.ctx, uuid.UUID(id).String())
if err != nil {
return nil, err
}
return &entities.Customer{
Id: id,
NKodePolicy: entities.NKodePolicy{
MaxNkodeLen: int(customer.MaxNkodeLen),
MinNkodeLen: int(customer.MinNkodeLen),
DistinctSets: int(customer.DistinctSets),
DistinctAttributes: int(customer.DistinctAttributes),
LockOut: int(customer.LockOut),
Expiration: int(customer.Expiration),
},
Attributes: entities.NewCustomerAttributesFromBytes(customer.AttributeValues, customer.SetValues),
}, nil
}
func (d *SqliteRepository) GetUser(email entities.UserEmail, customerId entities.CustomerId) (*entities.User, error) {
userRow, err := d.Queue.Queries.GetUser(d.ctx, sqlc.GetUserParams{
Email: string(email),
CustomerID: uuid.UUID(customerId).String(),
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to get user: %w", err)
}
kp := entities.KeypadDimension{
AttrsPerKey: int(userRow.AttributesPerKey),
NumbOfKeys: int(userRow.NumberOfKeys),
}
renew := false
if userRow.Renew == 1 {
renew = true
}
user := entities.User{
Id: entities.UserIdFromString(userRow.ID),
CustomerId: customerId,
Email: email,
EncipheredPasscode: entities.EncipheredNKode{
Code: userRow.Code,
Mask: userRow.Mask,
},
Kp: kp,
CipherKeys: entities.UserCipherKeys{
AlphaKey: security.ByteArrToUint64Arr(userRow.AlphaKey),
SetKey: security.ByteArrToUint64Arr(userRow.SetKey),
PassKey: security.ByteArrToUint64Arr(userRow.PassKey),
MaskKey: security.ByteArrToUint64Arr(userRow.MaskKey),
Salt: userRow.Salt,
MaxNKodeLen: int(userRow.MaxNkodeLen),
Kp: &kp,
},
Interface: entities.UserInterface{
IdxInterface: security.ByteArrToIntArr(userRow.IdxInterface),
SvgId: security.ByteArrToIntArr(userRow.SvgIDInterface),
Kp: &kp,
},
Renew: renew,
RefreshToken: userRow.RefreshToken.String,
}
return &user, nil
}
func (d *SqliteRepository) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) {
ids, err := d.getRandomIds(kp.TotalAttrs())
if err != nil {
return nil, err
}
return d.getSvgsById(ids)
}
func (d *SqliteRepository) RandomSvgIdxInterface(kp entities.KeypadDimension) (entities.SvgIdInterface, error) {
return d.getRandomIds(kp.TotalAttrs())
}
func (d *SqliteRepository) GetSvgStringInterface(idxs entities.SvgIdInterface) ([]string, error) {
return d.getSvgsById(idxs)
}
func (d *SqliteRepository) getSvgsById(ids []int) ([]string, error) {
svgs := make([]string, len(ids))
for idx, id := range ids {
svg, err := d.Queue.Queries.GetSvgId(d.ctx, int64(id))
if err != nil {
return nil, err
}
svgs[idx] = svg
}
return svgs, nil
}
func (d *SqliteRepository) getRandomIds(count int) ([]int, error) {
tx, err := d.Queue.Db.Begin()
if err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
rows, err := tx.Query("SELECT COUNT(*) as count FROM svg_icon;")
if err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
var tableLen int
if !rows.Next() {
return nil, config.ErrEmptySvgTable
}
if err = rows.Scan(&tableLen); err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
perm, err := security.RandomPermutation(tableLen)
if err != nil {
return nil, err
}
for idx := range perm {
perm[idx] += 1
}
if err = tx.Commit(); err != nil {
log.Print(err)
return nil, config.ErrSqliteTx
}
return perm[:count], nil
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"context"
"git.infra.nkode.tech/dkelly/nkode-core/entities"
"git.infra.nkode.tech/dkelly/nkode-core/sqlc"
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestNewSqliteDB(t *testing.T) {
dbPath := os.Getenv("TEST_DB")
// sql_driver.MakeTables(dbFile)
ctx := context.Background()
sqliteDb, err := sqlc.OpenSqliteDb(dbPath)
assert.NoError(t, err)
queue, err := sqlc.NewQueue(sqliteDb, ctx)
assert.NoError(t, err)
queue.Start()
defer queue.Stop()
db := NewSqliteRepository(queue, ctx)
assert.NoError(t, err)
testSignupLoginRenew(t, &db)
testSqliteDBRandomSvgInterface(t, &db)
}
func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) {
nkodePolicy := entities.NewDefaultNKodePolicy()
customerOrig, err := entities.NewCustomer(nkodePolicy)
assert.NoError(t, err)
err = db.CreateCustomer(*customerOrig)
assert.NoError(t, err)
customer, err := db.GetCustomer(customerOrig.Id)
assert.NoError(t, err)
assert.Equal(t, customerOrig, customer)
username := "test_user@example.com"
kp := entities.KeypadDefault
passcodeIdx := []int{0, 1, 2, 3}
mockSvgInterface := make(entities.SvgIdInterface, kp.TotalAttrs())
ui, err := entities.NewUserInterface(&kp, mockSvgInterface)
assert.NoError(t, err)
userOrig, err := entities.NewUser(*customer, username, passcodeIdx, *ui, kp)
assert.NoError(t, err)
err = db.WriteNewUser(*userOrig)
assert.NoError(t, err)
user, err := db.GetUser(entities.UserEmail(username), customer.Id)
assert.NoError(t, err)
assert.Equal(t, userOrig, user)
err = db.Renew(customer.Id)
assert.NoError(t, err)
}
func testSqliteDBRandomSvgInterface(t *testing.T, db CustomerUserRepository) {
kp := entities.KeypadMax
svgs, err := db.RandomSvgInterface(kp)
assert.NoError(t, err)
assert.Len(t, svgs, kp.TotalAttrs())
}