more refactoring
This commit is contained in:
21
internal/db/customer_user_repository.go
Normal file
21
internal/db/customer_user_repository.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"go-nkode/internal/entities"
|
||||
"go-nkode/internal/models"
|
||||
)
|
||||
|
||||
type CustomerUserRepository interface {
|
||||
GetCustomer(models.CustomerId) (*entities.Customer, error)
|
||||
GetUser(models.UserEmail, models.CustomerId) (*entities.User, error)
|
||||
WriteNewCustomer(entities.Customer) error
|
||||
WriteNewUser(entities.User) error
|
||||
UpdateUserNKode(entities.User) error
|
||||
UpdateUserInterface(models.UserId, entities.UserInterface) error
|
||||
UpdateUserRefreshToken(models.UserId, string) error
|
||||
Renew(models.CustomerId) error
|
||||
RefreshUserPasscode(entities.User, []int, entities.CustomerAttributes) error
|
||||
RandomSvgInterface(entities.KeypadDimension) ([]string, error)
|
||||
RandomSvgIdxInterface(entities.KeypadDimension) (models.SvgIdInterface, error)
|
||||
GetSvgStringInterface(models.SvgIdInterface) ([]string, error)
|
||||
}
|
||||
@@ -3,24 +3,25 @@ package db
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go-nkode/internal/entities"
|
||||
"go-nkode/internal/models"
|
||||
)
|
||||
|
||||
type InMemoryDb struct {
|
||||
Customers map[models.CustomerId]models.Customer
|
||||
Users map[models.UserId]models.User
|
||||
Customers map[models.CustomerId]entities.Customer
|
||||
Users map[models.UserId]entities.User
|
||||
userIdMap map[string]models.UserId
|
||||
}
|
||||
|
||||
func NewInMemoryDb() InMemoryDb {
|
||||
return InMemoryDb{
|
||||
Customers: make(map[models.CustomerId]models.Customer),
|
||||
Users: make(map[models.UserId]models.User),
|
||||
Customers: make(map[models.CustomerId]entities.Customer),
|
||||
Users: make(map[models.UserId]entities.User),
|
||||
userIdMap: make(map[string]models.UserId),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*models.Customer, error) {
|
||||
func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
|
||||
customer, exists := db.Customers[id]
|
||||
if !exists {
|
||||
return nil, errors.New(fmt.Sprintf("customer %s dne", customer.Id))
|
||||
@@ -28,7 +29,7 @@ func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*models.Customer, error
|
||||
return &customer, nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerId) (*models.User, error) {
|
||||
func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
|
||||
key := userIdKey(customerId, username)
|
||||
userId, exists := db.userIdMap[key]
|
||||
if !exists {
|
||||
@@ -41,7 +42,7 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) WriteNewCustomer(customer models.Customer) error {
|
||||
func (db *InMemoryDb) WriteNewCustomer(customer entities.Customer) error {
|
||||
_, exists := db.Customers[customer.Id]
|
||||
|
||||
if exists {
|
||||
@@ -51,7 +52,7 @@ func (db *InMemoryDb) WriteNewCustomer(customer models.Customer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) WriteNewUser(user models.User) error {
|
||||
func (db *InMemoryDb) WriteNewUser(user entities.User) error {
|
||||
_, exists := db.Customers[user.CustomerId]
|
||||
if !exists {
|
||||
return errors.New(fmt.Sprintf("can't add user %s to customer %s: customer dne", user.Email, user.CustomerId))
|
||||
@@ -67,11 +68,11 @@ func (db *InMemoryDb) WriteNewUser(user models.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) UpdateUserNKode(user models.User) error {
|
||||
func (db *InMemoryDb) UpdateUserNKode(user entities.User) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) UpdateUserInterface(userId models.UserId, ui models.UserInterface) error {
|
||||
func (db *InMemoryDb) UpdateUserInterface(userId models.UserId, ui entities.UserInterface) error {
|
||||
user, exists := db.Users[userId]
|
||||
if !exists {
|
||||
return errors.New(fmt.Sprintf("can't update user %s, dne", user.Id))
|
||||
@@ -107,7 +108,7 @@ func (db *InMemoryDb) Renew(id models.CustomerId) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) RefreshUserPasscode(user models.User, passocode []int, customerAttr models.CustomerAttributes) error {
|
||||
func (db *InMemoryDb) RefreshUserPasscode(user entities.User, passocode []int, customerAttr entities.CustomerAttributes) error {
|
||||
err := user.RefreshPasscode(passocode, customerAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -116,11 +117,11 @@ func (db *InMemoryDb) RefreshUserPasscode(user models.User, passocode []int, cus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) {
|
||||
func (db *InMemoryDb) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) {
|
||||
return make([]string, kp.TotalAttrs()), nil
|
||||
}
|
||||
|
||||
func (db *InMemoryDb) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) {
|
||||
func (db *InMemoryDb) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) {
|
||||
svgs := make(models.SvgIdInterface, kp.TotalAttrs())
|
||||
for idx := range svgs {
|
||||
svgs[idx] = idx
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver
|
||||
"go-nkode/config"
|
||||
"go-nkode/internal/entities"
|
||||
"go-nkode/internal/models"
|
||||
"go-nkode/internal/security"
|
||||
"log"
|
||||
@@ -60,7 +61,7 @@ func (d *SqliteDB) CloseDb() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewCustomer(c models.Customer) error {
|
||||
func (d *SqliteDB) WriteNewCustomer(c entities.Customer) error {
|
||||
query := `
|
||||
INSERT INTO customer (
|
||||
id
|
||||
@@ -85,7 +86,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?)
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) WriteNewUser(u models.User) error {
|
||||
func (d *SqliteDB) WriteNewUser(u entities.User) error {
|
||||
query := `
|
||||
INSERT INTO user (
|
||||
id
|
||||
@@ -128,7 +129,7 @@ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserNKode(u models.User) error {
|
||||
func (d *SqliteDB) UpdateUserNKode(u entities.User) error {
|
||||
query := `
|
||||
UPDATE user
|
||||
SET renew = ?
|
||||
@@ -158,7 +159,7 @@ WHERE email = ? AND customer_id = ?
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui models.UserInterface) error {
|
||||
func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error {
|
||||
query := `
|
||||
UPDATE user SET idx_interface = ?, last_login = ? WHERE id = ?
|
||||
`
|
||||
@@ -219,20 +220,20 @@ WHERE customer_id = ?
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := models.User{
|
||||
user := entities.User{
|
||||
Id: models.UserId{},
|
||||
CustomerId: models.CustomerId{},
|
||||
Email: "",
|
||||
EncipheredPasscode: models.EncipheredNKode{},
|
||||
Kp: models.KeypadDimension{
|
||||
Kp: entities.KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: models.UserCipherKeys{
|
||||
CipherKeys: entities.UserCipherKeys{
|
||||
AlphaKey: security.ByteArrToUint64Arr(alphaBytes),
|
||||
SetKey: security.ByteArrToUint64Arr(setBytes),
|
||||
},
|
||||
Interface: models.UserInterface{},
|
||||
Interface: entities.UserInterface{},
|
||||
Renew: false,
|
||||
}
|
||||
err = user.RenewKeys(setXor, attrXor)
|
||||
@@ -255,7 +256,7 @@ WHERE id = ?;
|
||||
return d.addWriteTx(renewQuery, renewArgs)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RefreshUserPasscode(user models.User, passcodeIdx []int, customerAttr models.CustomerAttributes) error {
|
||||
func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error {
|
||||
err := user.RefreshPasscode(passcodeIdx, customerAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -276,7 +277,7 @@ WHERE id = ?;
|
||||
args := []any{user.RefreshToken, 0, user.EncipheredPasscode.Code, user.EncipheredPasscode.Mask, security.Uint64ArrToByteArr(user.CipherKeys.AlphaKey), security.Uint64ArrToByteArr(user.CipherKeys.SetKey), security.Uint64ArrToByteArr(user.CipherKeys.PassKey), security.Uint64ArrToByteArr(user.CipherKeys.MaskKey), user.CipherKeys.Salt, uuid.UUID(user.Id).String()}
|
||||
return d.addWriteTx(query, args)
|
||||
}
|
||||
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*models.Customer, error) {
|
||||
func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -324,7 +325,7 @@ WHERE id = ?
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customer := models.Customer{
|
||||
customer := entities.Customer{
|
||||
Id: id,
|
||||
NKodePolicy: models.NKodePolicy{
|
||||
MaxNkodeLen: maxNKodeLen,
|
||||
@@ -334,7 +335,7 @@ WHERE id = ?
|
||||
LockOut: lockOut,
|
||||
Expiration: expiration,
|
||||
},
|
||||
Attributes: models.NewCustomerAttributesFromBytes(attributeValues, setValues),
|
||||
Attributes: entities.NewCustomerAttributesFromBytes(attributeValues, setValues),
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
@@ -342,7 +343,7 @@ WHERE id = ?
|
||||
return &customer, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*models.User, error) {
|
||||
func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -401,7 +402,7 @@ WHERE user.email = ? AND user.customer_id = ?
|
||||
renew = true
|
||||
}
|
||||
|
||||
user := models.User{
|
||||
user := entities.User{
|
||||
Id: models.UserId(userId),
|
||||
CustomerId: customerId,
|
||||
Email: email,
|
||||
@@ -409,11 +410,11 @@ WHERE user.email = ? AND user.customer_id = ?
|
||||
Code: code,
|
||||
Mask: mask,
|
||||
},
|
||||
Kp: models.KeypadDimension{
|
||||
Kp: entities.KeypadDimension{
|
||||
AttrsPerKey: attrsPerKey,
|
||||
NumbOfKeys: numbOfKeys,
|
||||
},
|
||||
CipherKeys: models.UserCipherKeys{
|
||||
CipherKeys: entities.UserCipherKeys{
|
||||
AlphaKey: security.ByteArrToUint64Arr(alphaKey),
|
||||
SetKey: security.ByteArrToUint64Arr(setKey),
|
||||
PassKey: security.ByteArrToUint64Arr(passKey),
|
||||
@@ -422,7 +423,7 @@ WHERE user.email = ? AND user.customer_id = ?
|
||||
MaxNKodeLen: maxNKodeLen,
|
||||
Kp: nil,
|
||||
},
|
||||
Interface: models.UserInterface{
|
||||
Interface: entities.UserInterface{
|
||||
IdxInterface: security.ByteArrToIntArr(idxInterface),
|
||||
SvgId: security.ByteArrToIntArr(svgIdInterface),
|
||||
Kp: nil,
|
||||
@@ -438,7 +439,7 @@ WHERE user.email = ? AND user.customer_id = ?
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, error) {
|
||||
func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) {
|
||||
ids, err := d.getRandomIds(kp.TotalAttrs())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -446,7 +447,7 @@ func (d *SqliteDB) RandomSvgInterface(kp models.KeypadDimension) ([]string, erro
|
||||
return d.getSvgsById(ids)
|
||||
}
|
||||
|
||||
func (d *SqliteDB) RandomSvgIdxInterface(kp models.KeypadDimension) (models.SvgIdInterface, error) {
|
||||
func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) {
|
||||
return d.getRandomIds(kp.TotalAttrs())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package db
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go-nkode/internal/api"
|
||||
"go-nkode/internal/entities"
|
||||
"go-nkode/internal/models"
|
||||
"os"
|
||||
"testing"
|
||||
@@ -24,9 +24,9 @@ func TestNewSqliteDB(t *testing.T) {
|
||||
// }
|
||||
}
|
||||
|
||||
func testSignupLoginRenew(t *testing.T, db api.DbAccessor) {
|
||||
func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) {
|
||||
nkodePolicy := models.NewDefaultNKodePolicy()
|
||||
customerOrig, err := models.NewCustomer(nkodePolicy)
|
||||
customerOrig, err := entities.NewCustomer(nkodePolicy)
|
||||
assert.NoError(t, err)
|
||||
err = db.WriteNewCustomer(*customerOrig)
|
||||
assert.NoError(t, err)
|
||||
@@ -34,12 +34,12 @@ func testSignupLoginRenew(t *testing.T, db api.DbAccessor) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, customerOrig, customer)
|
||||
username := "test_user@example.com"
|
||||
kp := models.KeypadDefault
|
||||
kp := entities.KeypadDefault
|
||||
passcodeIdx := []int{0, 1, 2, 3}
|
||||
mockSvgInterface := make(models.SvgIdInterface, kp.TotalAttrs())
|
||||
ui, err := models.NewUserInterface(&kp, mockSvgInterface)
|
||||
ui, err := entities.NewUserInterface(&kp, mockSvgInterface)
|
||||
assert.NoError(t, err)
|
||||
userOrig, err := models.NewUser(*customer, username, passcodeIdx, *ui, kp)
|
||||
userOrig, err := entities.NewUser(*customer, username, passcodeIdx, *ui, kp)
|
||||
assert.NoError(t, err)
|
||||
err = db.WriteNewUser(*userOrig)
|
||||
assert.NoError(t, err)
|
||||
@@ -52,8 +52,8 @@ func testSignupLoginRenew(t *testing.T, db api.DbAccessor) {
|
||||
|
||||
}
|
||||
|
||||
func testSqliteDBRandomSvgInterface(t *testing.T, db api.DbAccessor) {
|
||||
kp := models.KeypadMax
|
||||
func testSqliteDBRandomSvgInterface(t *testing.T, db CustomerUserRepository) {
|
||||
kp := entities.KeypadMax
|
||||
svgs, err := db.RandomSvgInterface(kp)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, svgs, kp.TotalAttrs())
|
||||
|
||||
Reference in New Issue
Block a user