9 Commits

Author SHA1 Message Date
44bede14e4 refactor sqlite repository 2025-01-25 14:39:02 -06:00
0a1b8f9457 refactor var names 2025-01-25 14:30:06 -06:00
7f6c828ea5 don't use session 2025-01-23 14:54:05 -06:00
665341961d add delete session 2025-01-23 14:43:59 -06:00
d23671ff2b add delete session 2025-01-23 14:43:51 -06:00
c1ed5dafe3 add user login sessions 2025-01-23 14:42:34 -06:00
2eb72988e5 printf to log 2025-01-23 14:23:28 -06:00
536894c7dd fix make table order 2025-01-23 06:41:40 -06:00
3c0b3c04b7 implement cli 2025-01-23 06:33:29 -06:00
11 changed files with 146 additions and 285 deletions

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@
icons/ icons/
json_icon/ json_icon/
flaticon_colored_svgs/ flaticon_colored_svgs/
nkode

View File

@@ -20,16 +20,16 @@ const (
) )
type NKodeAPI struct { type NKodeAPI struct {
Db repository.CustomerUserRepository repo repository.CustomerUserRepository
SignupSessionCache *cache.Cache signupSessionCache *cache.Cache
EmailQueue *email.Queue emailQueue *email.Queue
} }
func NewNKodeAPI(db repository.CustomerUserRepository, queue *email.Queue) NKodeAPI { func NewNKodeAPI(repo repository.CustomerUserRepository, queue *email.Queue) NKodeAPI {
return NKodeAPI{ return NKodeAPI{
Db: db, repo: repo,
EmailQueue: queue, emailQueue: queue,
SignupSessionCache: cache.New(sessionExpiration, sessionCleanupInterval), signupSessionCache: cache.New(sessionExpiration, sessionCleanupInterval),
} }
} }
@@ -41,7 +41,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy entities.NKodePolicy, id *entit
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = n.Db.CreateCustomer(*newCustomer) err = n.repo.CreateCustomer(*newCustomer)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -50,7 +50,7 @@ func (n *NKodeAPI) CreateNewCustomer(nkodePolicy entities.NKodePolicy, id *entit
} }
func (n *NKodeAPI) GenerateSignupResetInterface(userEmail entities.UserEmail, customerId entities.CustomerId, kp entities.KeypadDimension, reset bool) (*entities.SignupResetInterface, error) { func (n *NKodeAPI) GenerateSignupResetInterface(userEmail entities.UserEmail, customerId entities.CustomerId, kp entities.KeypadDimension, reset bool) (*entities.SignupResetInterface, error) {
user, err := n.Db.GetUser(userEmail, customerId) user, err := n.repo.GetUser(userEmail, customerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +58,7 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail entities.UserEmail, cu
log.Printf("user %s already exists", string(userEmail)) log.Printf("user %s already exists", string(userEmail))
return nil, config.ErrUserAlreadyExists return nil, config.ErrUserAlreadyExists
} }
svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) svgIdxInterface, err := n.repo.RandomSvgIdxInterface(kp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -66,10 +66,10 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail entities.UserEmail, cu
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := n.SignupSessionCache.Add(signupSession.Id.String(), *signupSession, sessionExpiration); err != nil { if err := n.signupSessionCache.Add(signupSession.Id.String(), *signupSession, sessionExpiration); err != nil {
return nil, err return nil, err
} }
svgInterface, err := n.Db.GetSvgStringInterface(signupSession.LoginUserInterface.SvgId) svgInterface, err := n.repo.GetSvgStringInterface(signupSession.LoginUserInterface.SvgId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -84,12 +84,12 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail entities.UserEmail, cu
} }
func (n *NKodeAPI) SetNKode(customerId entities.CustomerId, sessionId entities.SessionId, keySelection entities.KeySelection) (entities.IdxInterface, error) { func (n *NKodeAPI) SetNKode(customerId entities.CustomerId, sessionId entities.SessionId, keySelection entities.KeySelection) (entities.IdxInterface, error) {
_, err := n.Db.GetCustomer(customerId) _, err := n.repo.GetCustomer(customerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
session, exists := n.SignupSessionCache.Get(sessionId.String()) session, exists := n.signupSessionCache.Get(sessionId.String())
if !exists { if !exists {
log.Printf("session id does not exist %s", sessionId) log.Printf("session id does not exist %s", sessionId)
return nil, config.ErrSignupSessionDNE return nil, config.ErrSignupSessionDNE
@@ -103,12 +103,12 @@ func (n *NKodeAPI) SetNKode(customerId entities.CustomerId, sessionId entities.S
if err != nil { if err != nil {
return nil, err return nil, err
} }
n.SignupSessionCache.Set(sessionId.String(), userSession, sessionExpiration) n.signupSessionCache.Set(sessionId.String(), userSession, sessionExpiration)
return confirmInterface, nil return confirmInterface, nil
} }
func (n *NKodeAPI) ConfirmNKode(customerId entities.CustomerId, sessionId entities.SessionId, keySelection entities.KeySelection) error { func (n *NKodeAPI) ConfirmNKode(customerId entities.CustomerId, sessionId entities.SessionId, keySelection entities.KeySelection) error {
session, exists := n.SignupSessionCache.Get(sessionId.String()) session, exists := n.signupSessionCache.Get(sessionId.String())
if !exists { if !exists {
log.Printf("session id does not exist %s", sessionId) log.Printf("session id does not exist %s", sessionId)
return config.ErrSignupSessionDNE return config.ErrSignupSessionDNE
@@ -118,7 +118,7 @@ func (n *NKodeAPI) ConfirmNKode(customerId entities.CustomerId, sessionId entiti
// handle the case where the type assertion fails // handle the case where the type assertion fails
return config.ErrSignupSessionDNE return config.ErrSignupSessionDNE
} }
customer, err := n.Db.GetCustomer(customerId) customer, err := n.repo.GetCustomer(customerId)
if err != nil { if err != nil {
return err return err
} }
@@ -134,16 +134,16 @@ func (n *NKodeAPI) ConfirmNKode(customerId entities.CustomerId, sessionId entiti
return err return err
} }
if userSession.Reset { if userSession.Reset {
err = n.Db.UpdateUserNKode(*user) err = n.repo.UpdateUserNKode(*user)
} else { } else {
err = n.Db.WriteNewUser(*user) err = n.repo.WriteNewUser(*user)
} }
n.SignupSessionCache.Delete(userSession.Id.String()) n.signupSessionCache.Delete(userSession.Id.String())
return err return err
} }
func (n *NKodeAPI) GetLoginInterface(userEmail entities.UserEmail, customerId entities.CustomerId) (*entities.LoginInterface, error) { func (n *NKodeAPI) GetLoginInterface(userEmail entities.UserEmail, customerId entities.CustomerId) (*entities.LoginInterface, error) {
user, err := n.Db.GetUser(userEmail, customerId) user, err := n.repo.GetUser(userEmail, customerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -151,7 +151,7 @@ func (n *NKodeAPI) GetLoginInterface(userEmail entities.UserEmail, customerId en
log.Printf("user %s for customer %s dne", userEmail, customerId) log.Printf("user %s for customer %s dne", userEmail, customerId)
return nil, config.ErrUserForCustomerDNE return nil, config.ErrUserForCustomerDNE
} }
svgInterface, err := n.Db.GetSvgStringInterface(user.Interface.SvgId) svgInterface, err := n.repo.GetSvgStringInterface(user.Interface.SvgId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -166,11 +166,11 @@ func (n *NKodeAPI) GetLoginInterface(userEmail entities.UserEmail, customerId en
} }
func (n *NKodeAPI) Login(customerId entities.CustomerId, userEmail entities.UserEmail, keySelection entities.KeySelection) (*security.AuthenticationTokens, error) { func (n *NKodeAPI) Login(customerId entities.CustomerId, userEmail entities.UserEmail, keySelection entities.KeySelection) (*security.AuthenticationTokens, error) {
customer, err := n.Db.GetCustomer(customerId) customer, err := n.repo.GetCustomer(customerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := n.Db.GetUser(userEmail, customerId) user, err := n.repo.GetUser(userEmail, customerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -184,37 +184,38 @@ func (n *NKodeAPI) Login(customerId entities.CustomerId, userEmail entities.User
} }
if user.Renew { if user.Renew {
err = n.Db.RefreshUserPasscode(*user, passcode, customer.Attributes) err = n.repo.RefreshUserPasscode(*user, passcode, customer.Attributes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
jwtToken, err := security.NewAuthenticationTokens(string(user.Email), uuid.UUID(customerId)) jwtToken, err := security.NewAuthenticationTokens(string(user.Email), uuid.UUID(customerId))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = n.Db.UpdateUserRefreshToken(user.Id, jwtToken.RefreshToken); err != nil { if err = n.repo.UpdateUserRefreshToken(user.Id, jwtToken.RefreshToken); err != nil {
return nil, err return nil, err
} }
if err = user.Interface.LoginShuffle(); err != nil { if err = user.Interface.LoginShuffle(); err != nil {
return nil, err return nil, err
} }
if err = n.Db.UpdateUserInterface(user.Id, user.Interface); err != nil { if err = n.repo.UpdateUserInterface(user.Id, user.Interface); err != nil {
return nil, err return nil, err
} }
return &jwtToken, nil return &jwtToken, nil
} }
func (n *NKodeAPI) RenewAttributes(customerId entities.CustomerId) error { func (n *NKodeAPI) RenewAttributes(customerId entities.CustomerId) error {
return n.Db.Renew(customerId) return n.repo.Renew(customerId)
} }
func (n *NKodeAPI) RandomSvgInterface() ([]string, error) { func (n *NKodeAPI) RandomSvgInterface() ([]string, error) {
return n.Db.RandomSvgInterface(entities.KeypadMax) return n.repo.RandomSvgInterface(entities.KeypadMax)
} }
func (n *NKodeAPI) RefreshToken(userEmail entities.UserEmail, customerId entities.CustomerId, refreshToken string) (string, error) { func (n *NKodeAPI) RefreshToken(userEmail entities.UserEmail, customerId entities.CustomerId, refreshToken string) (string, error) {
user, err := n.Db.GetUser(userEmail, customerId) user, err := n.repo.GetUser(userEmail, customerId)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -237,7 +238,7 @@ func (n *NKodeAPI) RefreshToken(userEmail entities.UserEmail, customerId entitie
} }
func (n *NKodeAPI) ResetNKode(userEmail entities.UserEmail, customerId entities.CustomerId) error { func (n *NKodeAPI) ResetNKode(userEmail entities.UserEmail, customerId entities.CustomerId) error {
user, err := n.Db.GetUser(userEmail, customerId) user, err := n.repo.GetUser(userEmail, customerId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user in rest nkode %v", err) return fmt.Errorf("error getting user in rest nkode %v", err)
} }
@@ -261,6 +262,6 @@ func (n *NKodeAPI) ResetNKode(userEmail entities.UserEmail, customerId entities.
Subject: "nKode Reset", Subject: "nKode Reset",
Content: htmlBody, Content: htmlBody,
} }
n.EmailQueue.AddEmail(email) n.emailQueue.AddEmail(email)
return nil return nil
} }

View File

@@ -6,7 +6,6 @@ import (
"git.infra.nkode.tech/dkelly/nkode-core/entities" "git.infra.nkode.tech/dkelly/nkode-core/entities"
"git.infra.nkode.tech/dkelly/nkode-core/repository" "git.infra.nkode.tech/dkelly/nkode-core/repository"
"git.infra.nkode.tech/dkelly/nkode-core/security" "git.infra.nkode.tech/dkelly/nkode-core/security"
"git.infra.nkode.tech/dkelly/nkode-core/sqlc"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"log" "log"
"os" "os"
@@ -19,26 +18,18 @@ func TestNKodeAPI(t *testing.T) {
dbPath := os.Getenv("TEST_DB") dbPath := os.Getenv("TEST_DB")
ctx := context.Background() ctx := context.Background()
sqliteDb, err := sqlc.OpenSqliteDb(dbPath) sqlitedb, err := repository.NewSqliteRepository(ctx, dbPath)
assert.NoError(t, err) if err != nil {
queue, err := sqlc.NewQueue(sqliteDb, ctx)
assert.NoError(t, err)
queue.Start()
defer func(queue *sqlc.Queue) {
if err := queue.Stop(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
}(queue) sqlitedb.Start()
sqlitedb := repository.NewSqliteRepository(queue, ctx) defer func(sqldb *repository.SqliteRepository) {
testNKodeAPI(t, &sqlitedb) if err := sqldb.Stop(); err != nil {
log.Fatal(err)
}
}(sqlitedb)
testNKodeAPI(t, sqlitedb)
//if _, err := os.Stat(dbPath); err == nil {
// err = os.Remove(dbPath)
// assert.NoError(t, err)
//} else {
// assert.NoError(t, err)
//}
} }
func testNKodeAPI(t *testing.T, db repository.CustomerUserRepository) { func testNKodeAPI(t *testing.T, db repository.CustomerUserRepository) {

View File

@@ -1,258 +1,86 @@
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" _ "embed"
"flag"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" // Import the SQLite3 driver "git.infra.nkode.tech/dkelly/nkode-core/repository"
"io/ioutil" "git.infra.nkode.tech/dkelly/nkode-core/sqlite"
_ "github.com/mattn/go-sqlite3"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"strings"
) )
type Icon struct {
Body string `json:"body"`
Width *int `json:"width,omitempty"`
}
// Root Define the Root struct to represent the entire JSON structure
type Root struct {
Prefix string `json:"prefix"`
Icons map[string]Icon `json:"icons"`
Height int `json:"height"`
}
func main() { func main() {
testDbPath := os.Getenv("TEST_DB_PATH") sqliteSchema, err := sqlite.FS.ReadFile("schema.sql")
dbPath := os.Getenv("DB_PATH")
dbPaths := []string{testDbPath, dbPath}
flaticonSvgDir := os.Getenv("SVG_DIR")
//dbPath := "/Users/donov/Desktop/nkode.db"
//dbPaths := []string{dbPath}
//outputStr := MakeSvgFiles()
for _, path := range dbPaths {
MakeTables(path)
FlaticonToSqlite(path, flaticonSvgDir)
//SvgToSqlite(path, outputStr)
}
}
func FlaticonToSqlite(dbPath string, svgDir string) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer db.Close() commandLine := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
dbPath := commandLine.String("db-path", "", "Path to the database")
svgPath := commandLine.String("svg-path", "", "Path to the SVG directory")
if err = commandLine.Parse(os.Args[1:]); err != nil {
log.Fatalf("Failed to parse flags: %v", err)
}
if err = MakeTables(*dbPath, string(sqliteSchema)); err != nil {
log.Fatal(err)
}
ctx := context.Background()
sqliteRepo, err := repository.NewSqliteRepository(ctx, *dbPath)
sqliteRepo.Start()
defer func(sqliteRepo *repository.SqliteRepository) {
if err := sqliteRepo.Stop(); err != nil {
log.Fatal(err)
}
}(sqliteRepo)
// Open the directory FlaticonToSqlite(sqliteRepo, *svgPath)
log.Println(fmt.Sprintf("Successfully added all SVGs in %s to the database at %s\n", *svgPath, *dbPath))
}
func FlaticonToSqlite(repo *repository.SqliteRepository, svgDir string) {
files, err := os.ReadDir(svgDir) files, err := os.ReadDir(svgDir)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
for _, file := range files { for _, file := range files {
// Check if it is a regular file (not a directory) and has a .svg extension
if file.IsDir() || filepath.Ext(file.Name()) != ".svg" { if file.IsDir() || filepath.Ext(file.Name()) != ".svg" {
continue continue
} }
filePath := filepath.Join(svgDir, file.Name()) filePath := filepath.Join(svgDir, file.Name())
// Read the file contents
content, err := os.ReadFile(filePath) content, err := os.ReadFile(filePath)
if err != nil { if err != nil {
log.Println("Error reading file:", filePath, err) log.Println("Error reading file:", filePath, err)
continue continue
} }
// Print the file name and first few bytes of the file content if err = repo.AddSvg(string(content)); err != nil {
insertSql := `
INSERT INTO svg_icon (svg)
VALUES (?)
`
_, err = db.Exec(insertSql, string(content))
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
} }
func SvgToSqlite(dbPath string, outputStr string) { func MakeTables(dbPath string, schema string) error {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err = os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return err
}
if _, err = os.Create(dbPath); err != nil {
return err
}
}
db, err := sql.Open("sqlite3", dbPath) db, err := sql.Open("sqlite3", dbPath)
if err != nil { if err != nil {
log.Fatal(err) return err
} }
defer db.Close() if _, err = db.Exec(schema); err != nil {
return err
lines := strings.Split(outputStr, "\n")
insertSql := `
INSERT INTO svg_icon (svg)
VALUES (?)
`
for _, line := range lines {
if line == "" {
continue
}
_, err := db.Exec(insertSql, line)
if err != nil {
log.Fatal(err)
}
}
}
func MakeSvgFiles() string {
jsonFiles, err := GetAllFiles("./core/sqlite-init/json")
if err != nil {
log.Fatalf("Error getting JSON files: %v", err)
}
if len(jsonFiles) == 0 {
log.Fatal("No JSON files found in ./json folder")
}
var outputStr string
strSet := make(map[string]struct{})
for _, filename := range jsonFiles {
fileData, err := LoadJson(filename)
if err != nil {
log.Print("Error loading JSON file: ", err)
continue
}
height := fileData.Height
for name, icon := range fileData.Icons {
width := height
parts := strings.Split(name, "-")
if len(parts) <= 0 {
log.Print(name, " has no parts")
continue
}
part0 := parts[0]
_, exists := strSet[part0]
if exists {
continue
}
if icon.Width != nil {
width = *icon.Width
}
strSet[part0] = struct{}{}
if icon.Body == "" {
continue
}
outputStr = fmt.Sprintf("%s<svg viewBox=\"0 0 %d %d\" xmlns=\"http://www.w3.org/2000/svg\">%s</svg>\n", outputStr, width, height, icon.Body)
}
}
return outputStr
}
func GetAllFiles(dir string) ([]string, error) {
// Use ioutil.ReadDir to list all files in the directory
files, err := ioutil.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("unable to read directory: %v", err)
}
// Create a slice to hold the JSON filenames
var jsonFiles []string
// Loop through the files and filter out JSON files
for _, file := range files {
if !file.IsDir() && filepath.Ext(file.Name()) == ".json" {
jsonFiles = append(jsonFiles, filepath.Join(dir, file.Name()))
}
}
return jsonFiles, nil
}
func LoadJson(filename string) (*Root, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %v", filename, err)
}
var root Root
err = json.Unmarshal(data, &root)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON: %v", err)
}
return &root, nil
}
func MakeTables(dbPath string) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
log.Fatal(err)
}
defer db.Close()
createTable := `
PRAGMA journal_mode=WAL;
--PRAGMA busy_timeout = 5000; -- Wait up to 5 seconds
--PRAGMA synchronous = NORMAL; -- Reduce sync frequency for less locking
--PRAGMA cache_size = -16000; -- Increase cache size (16MB)PRAGMA
CREATE TABLE IF NOT EXISTS customer (
id TEXT NOT NULL PRIMARY KEY
,max_nkode_len INTEGER NOT NULL
,min_nkode_len INTEGER NOT NULL
,distinct_sets INTEGER NOT NULL
,distinct_attributes INTEGER NOT NULL
,lock_out INTEGER NOT NULL
,expiration INTEGER NOT NULL
,attribute_values BLOB NOT NULL
,set_values BLOB NOT NULL
,last_renew TEXT NOT NULL
,created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS user (
id TEXT NOT NULL PRIMARY KEY
,email TEXT NOT NULL
-- first_name TEXT NOT NULL
-- last_name TEXT NOT NULL
,renew INT NOT NULL
,refresh_token TEXT
,customer_id TEXT NOT NULL
-- Enciphered Passcode
,code TEXT NOT NULL
,mask TEXT NOT NULL
-- Keypad Dimensions
,attributes_per_key INT NOT NULL
,number_of_keys INT NOT NULL
-- User Keys
,alpha_key BLOB NOT NULL
,set_key BLOB NOT NULL
,pass_key BLOB NOT NULL
,mask_key BLOB NOT NULL
,salt BLOB NOT NULL
,max_nkode_len INT NOT NULL
-- User Interface
,idx_interface BLOB NOT NULL
,svg_id_interface BLOB NOT NULL
,last_login TEXT NULL
,created_at TEXT
,FOREIGN KEY (customer_id) REFERENCES customer(id)
,UNIQUE(customer_id, email)
);
CREATE TABLE IF NOT EXISTS svg_icon (
id INTEGER PRIMARY KEY AUTOINCREMENT
,svg TEXT NOT NULL
);
`
_, err = db.Exec(createTable)
if err != nil {
log.Fatal(err)
} }
return db.Close()
} }

View File

@@ -20,11 +20,27 @@ type SqliteRepository struct {
ctx context.Context ctx context.Context
} }
func NewSqliteRepository(queue *sqlc.Queue, ctx context.Context) SqliteRepository { func NewSqliteRepository(ctx context.Context, dbPath string) (*SqliteRepository, error) {
return SqliteRepository{ sqliteDb, err := sqlc.OpenSqliteDb(dbPath)
if err != nil {
return nil, err
}
queue, err := sqlc.NewQueue(sqliteDb, ctx)
if err != nil {
return nil, err
}
return &SqliteRepository{
Queue: queue, Queue: queue,
ctx: ctx, ctx: ctx,
} }, nil
}
func (d *SqliteRepository) Start() {
d.Queue.Start()
}
func (d *SqliteRepository) Stop() error {
return d.Queue.Stop()
} }
func (d *SqliteRepository) CreateCustomer(c entities.Customer) error { func (d *SqliteRepository) CreateCustomer(c entities.Customer) error {
@@ -264,6 +280,17 @@ func (d *SqliteRepository) RefreshUserPasscode(user entities.User, passcodeIdx [
return d.Queue.EnqueueWriteTx(queryFunc, params) return d.Queue.EnqueueWriteTx(queryFunc, params)
} }
func (d *SqliteRepository) AddSvg(svg string) error {
queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error {
params, ok := args.(string)
if !ok {
return fmt.Errorf("invalid argument type: expected AddSvg")
}
return q.AddSvg(ctx, params)
}
return d.Queue.EnqueueWriteTx(queryFunc, svg)
}
func (d *SqliteRepository) GetCustomer(id entities.CustomerId) (*entities.Customer, error) { func (d *SqliteRepository) GetCustomer(id entities.CustomerId) (*entities.Customer, error) {
customer, err := d.Queue.Queries.GetCustomer(d.ctx, uuid.UUID(id).String()) customer, err := d.Queue.Queries.GetCustomer(d.ctx, uuid.UUID(id).String())
if err != nil { if err != nil {

View File

@@ -3,7 +3,6 @@ package repository
import ( import (
"context" "context"
"git.infra.nkode.tech/dkelly/nkode-core/entities" "git.infra.nkode.tech/dkelly/nkode-core/entities"
"git.infra.nkode.tech/dkelly/nkode-core/sqlc"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"os" "os"
"testing" "testing"
@@ -11,20 +10,16 @@ import (
func TestNewSqliteDB(t *testing.T) { func TestNewSqliteDB(t *testing.T) {
dbPath := os.Getenv("TEST_DB") dbPath := os.Getenv("TEST_DB")
// sql_driver.MakeTables(dbFile)
ctx := context.Background() ctx := context.Background()
sqliteDb, err := sqlc.OpenSqliteDb(dbPath) sqliteDb, err := NewSqliteRepository(ctx, dbPath)
assert.NoError(t, err) assert.NoError(t, err)
sqliteDb.Start()
queue, err := sqlc.NewQueue(sqliteDb, ctx) defer func(t *testing.T, sqliteDb *SqliteRepository) {
err := sqliteDb.Stop()
assert.NoError(t, err) assert.NoError(t, err)
}(t, sqliteDb)
queue.Start() testSignupLoginRenew(t, sqliteDb)
defer queue.Stop() testSqliteDBRandomSvgInterface(t, sqliteDb)
db := NewSqliteRepository(queue, ctx)
assert.NoError(t, err)
testSignupLoginRenew(t, &db)
testSqliteDBRandomSvgInterface(t, &db)
} }
func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) {

View File

@@ -6,4 +6,4 @@ sql:
gen: gen:
go: go:
package: "sqlc" package: "sqlc"
out: "./pkg/nkode-core/sqlc" out: "./sqlc"

View File

@@ -10,6 +10,15 @@ import (
"database/sql" "database/sql"
) )
const addSvg = `-- name: AddSvg :exec
INSERT INTO svg_icon (svg) VALUES (?)
`
func (q *Queries) AddSvg(ctx context.Context, svg string) error {
_, err := q.db.ExecContext(ctx, addSvg, svg)
return err
}
const createCustomer = `-- name: CreateCustomer :exec const createCustomer = `-- name: CreateCustomer :exec
INSERT INTO customer ( INSERT INTO customer (
id id

View File

@@ -10,11 +10,11 @@ import (
const writeBufferSize = 100 const writeBufferSize = 100
type SqlcGeneric func(*Queries, context.Context, any) error type GenericQuery func(*Queries, context.Context, any) error
type WriteTx struct { type WriteTx struct {
ErrChan chan error ErrChan chan error
Query SqlcGeneric Query GenericQuery
Args interface{} Args interface{}
} }
@@ -63,7 +63,7 @@ func (d *Queue) Stop() error {
return d.Db.Close() return d.Db.Close()
} }
func (d *Queue) EnqueueWriteTx(queryFunc SqlcGeneric, args any) error { func (d *Queue) EnqueueWriteTx(queryFunc GenericQuery, args any) error {
select { select {
case <-d.ctx.Done(): case <-d.ctx.Done():
return errors.New("database is shutting down") return errors.New("database is shutting down")

6
sqlite/embed.go Normal file
View File

@@ -0,0 +1,6 @@
package sqlite
import "embed"
//go:embed schema.sql
var FS embed.FS

View File

@@ -37,6 +37,9 @@ INSERT INTO user (
) )
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
-- name: AddSvg :exec
INSERT INTO svg_icon (svg) VALUES (?);
-- name: UpdateUser :exec -- name: UpdateUser :exec
UPDATE user UPDATE user
SET renew = ? SET renew = ?