From 8d4c8f71b0bf8812e6cf3dc8d4a922785a7fc168 Mon Sep 17 00:00:00 2001 From: Donovan Date: Fri, 1 Aug 2025 10:49:46 -0500 Subject: [PATCH] rigged shuffle --- .air.toml | 6 +- Taskfile.yaml | 9 + cmd/cli/main.go | 273 ++++++++++++++++++ cmd/{ => restapi}/main.go | 16 +- cmd/{ => restapi}/main_test.go | 0 internal/api/handler.go | 19 +- internal/api/nkode_api.go | 77 +++-- internal/api/nkode_api_test.go | 4 +- internal/entities/customer.go | 6 +- internal/entities/user.go | 4 +- internal/entities/user_signup_session.go | 29 +- internal/models/models.go | 12 +- .../repository/customer_user_repository.go | 7 +- internal/repository/in_memory_db.go | 20 +- internal/repository/sqlite_repository.go | 124 ++++---- internal/repository/sqlite_repository_test.go | 10 +- internal/sqlc/query.sql.go | 13 + internal/sqlc/sqlite_queue.go | 36 ++- scripts/bash/rebuild_db.sh | 48 +++ sqlite/embed.go | 6 + sqlite/query.sql | 5 + 21 files changed, 578 insertions(+), 146 deletions(-) create mode 100644 cmd/cli/main.go rename cmd/{ => restapi}/main.go (92%) rename cmd/{ => restapi}/main_test.go (100%) create mode 100755 scripts/bash/rebuild_db.sh create mode 100644 sqlite/embed.go diff --git a/.air.toml b/.air.toml index 58fff2a..f81c906 100644 --- a/.air.toml +++ b/.air.toml @@ -1,11 +1,11 @@ root = "." testdata_dir = "testdata" -tmp_dir = "tmp" +tmp_dir = "bin" [build] args_bin = [] - bin = "./tmp/main" - cmd = "go build -o ./tmp/main ." + bin = "./bin/restapi" + cmd = "go build -o ./bin/restapi ./cmd/restapi" delay = 1000 exclude_dir = ["assets", "tmp", "vendor", "testdata"] exclude_file = [] diff --git a/Taskfile.yaml b/Taskfile.yaml index fef6d45..6f31d69 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -4,8 +4,17 @@ vars: compose_file: "./compose/local-compose.yaml" cache_bust: sh: "date +%s" + test_db: "~/databases/demo.db" + schema_db: "./sqlite/schema.sql" + svg_path: "~/svgs/flaticon_colored_svgs" + session_secret: "c3fca773c8889eb3352745c4fe503df0" + frontend: "http://localhost:8080" tasks: + demo_run: + cmds: + - sh -c "SQLITE_DB={{.test_db}} JWT_SECRET={{.session_secret}} FRONTEND_HOST={{.frontend}} air" + build: cmds: - docker compose -f {{.compose_file}} build --no-cache diff --git a/cmd/cli/main.go b/cmd/cli/main.go new file mode 100644 index 0000000..56cddf4 --- /dev/null +++ b/cmd/cli/main.go @@ -0,0 +1,273 @@ +package main + +import ( + "context" + "database/sql" + _ "embed" + "flag" + "fmt" + "go-nkode/internal/entities" + "go-nkode/internal/models" + "go-nkode/internal/repository" + sqlite_queue "go-nkode/internal/sqlc" + "go-nkode/sqlite" + "log" + "os" + "path/filepath" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +func main() { + if len(os.Args) < 2 { + log.Fatal("Please provide a command: build-db") + } + switch os.Args[1] { + case "build-db": + BuildDB() + case "create-customer": + CreateCustomer() + case "add-user": + AddUser() + default: + log.Fatalf("Unknown command: %s", os.Args[1]) + } +} + +func CreateCustomer() { + cliCmd := flag.NewFlagSet("create-customer", flag.ExitOnError) + customerIDStr := cliCmd.String("customer-id", "", "Customer UUID") + dbPath := cliCmd.String("db-path", "", "Path to sqlite database") + if err := cliCmd.Parse(os.Args[2:]); err != nil { + log.Fatalf("Failed to parse flags: %v", err) + } + customerID, err := models.CustomerIDFromString(*customerIDStr) + if err != nil { + log.Fatalf("Failed to parse flags: %v", err) + } + ctx := context.Background() + sqliteDb, err := sqlite_queue.OpenSqliteDb(*dbPath) + queue, err := sqlite_queue.NewQueue(ctx, sqliteDb) + queue.Start() + defer queue.Stop() + sqliteRepo := repository.NewSqliteRepository(ctx, queue) + if err != nil { + log.Fatal("error starting sqlite repo: ", err) + } + nkodePolicy := models.NewDefaultNKodePolicy() + customer, err := entities.NewCustomer(nkodePolicy) + customer.ID = customerID + if err != nil { + log.Fatal(err) + } + if err = sqliteRepo.CreateCustomer(*customer); err != nil { + log.Fatal(err) + } +} + +func AddUser() { + cliCmd := flag.NewFlagSet("add-user", flag.ExitOnError) + imgPath := cliCmd.String("img-path", "", "Path to directory with image files to add to database. The total must number must equal attrs-per-key X numb-of-keys") + imgType := cliCmd.String("img-type", "webp", "Image types webp, svg, png, jpeg, default webp") + customerIDStr := cliCmd.String("customer-id", "", "Customer ID") + dbPath := cliCmd.String("db-path", "", "Path to the database") + userEmailStr := cliCmd.String("user-email", "", "User email") + attrsPerKey := cliCmd.Int("attrs-per-key", -1, "Attributes per key") + numbOfKeys := cliCmd.Int("numb-of-keys", -1, "Number of keys") + nkodeIcons := cliCmd.String("nkode-icons", "", "common separated file names of the users nKode icons with no space. filename order sets the nkode passcode order") + if err := cliCmd.Parse(os.Args[2:]); err != nil { + log.Fatalf("Failed to parse flags: %v", err) + } + fmt.Println("os args: ", os.Args) + ctx := context.Background() + sqliteDb, err := sqlite_queue.OpenSqliteDb(*dbPath) + queue, err := sqlite_queue.NewQueue(ctx, sqliteDb) + queue.Start() + defer queue.Stop() + sqliteRepo := repository.NewSqliteRepository(ctx, queue) + if err != nil { + log.Fatal("error starting sqlite repo: ", err) + } + customer, err := validCustomerID(*customerIDStr, &sqliteRepo) + if err != nil { + log.Println("db path: ", *dbPath) + log.Fatal("invalid customer id: ", err) + } + validateUserEmail(*userEmailStr, customer.ID, &sqliteRepo) + kp := entities.KeypadDimension{ + AttrsPerKey: *attrsPerKey, + NumbOfKeys: *numbOfKeys, + } + if *attrsPerKey < entities.KeypadMin.AttrsPerKey || entities.KeypadMax.AttrsPerKey < *attrsPerKey { + log.Fatalf("invalid attributes per key valid range is %d-%d", entities.KeypadMin.AttrsPerKey, entities.KeypadMax.AttrsPerKey) + } + if *numbOfKeys < entities.KeypadMin.NumbOfKeys || entities.KeypadMax.NumbOfKeys < *numbOfKeys { + log.Fatalf("invalid number of keys. valid range is %d-%d", entities.KeypadMin.NumbOfKeys, entities.KeypadMax.NumbOfKeys) + } + if kp.IsDispersable() { + log.Fatal("Keypad can't be dispersable") + } + imgs := getImgs(*imgPath, *imgType) + if len(imgs) != kp.TotalAttrs() { + log.Fatal("svgs in directory not equal to keypad size") + } + imgIDs := make([]int, len(imgs)) + for idx, img := range imgs { + id, err := sqliteRepo.AddSVGIcon(img) + if err != nil { + log.Fatal(err) + } + imgIDs[idx] = int(id) + } + iconNames := strings.Split(*nkodeIcons, ",") + passcodeIdxs := getPasscodeSvgIdx(iconNames, *imgPath) + if err = customer.IsValidNKode(kp, passcodeIdxs); err != nil { + log.Fatal("invalid nkode: ", err) + } + userInterface, err := entities.NewUserInterface(&kp, models.SvgIdInterface(imgIDs)) + if err != nil { + log.Fatal("error creating user interface: ", err) + } + user, err := entities.NewUser(customer, *userEmailStr, passcodeIdxs, *userInterface, kp) + if err != nil { + log.Fatal("error creating user: ", err) + } + if err = sqliteRepo.WriteNewUser(*user); err != nil { + log.Fatal("error storing user: ", err) + } +} + +func getPasscodeSvgIdx(nkodeSvgFileNames []string, svgDir string) []int { + files, err := os.ReadDir(svgDir) + if err != nil { + log.Fatal(err) + } + fileNames := make([]string, 0) + for _, file := range files { + if file.IsDir() || filepath.Ext(file.Name()) != ".svg" { + continue + } + fileNames = append(fileNames, file.Name()) + } + passcode := make([]int, 0) + for _, fileName := range nkodeSvgFileNames { + idx := indexOf(fileNames, fileName) + if idx == -1 { + log.Fatal("file does not exist in svg dir: ", fileName) + } + passcode = append(passcode, idx) + } + return passcode +} + +func indexOf(slice []string, value string) int { + for i, v := range slice { + if v == value { + return i + } + } + return -1 // not found +} + +func getImgs(imgDir, imgType string) []string { + files, err := os.ReadDir(imgDir) + if err != nil { + log.Fatalf("error opening dir: %s with err: %v", imgDir, err) + } + imgs := make([]string, 0) + for _, file := range files { + if file.IsDir() || filepath.Ext(file.Name()) != "."+imgType { + continue + } + filePath := filepath.Join(imgDir, file.Name()) + content, err := os.ReadFile(filePath) + if err != nil { + log.Println("Error reading file:", filePath, err) + continue + } + imgs = append(imgs, string(content)) + } + return imgs +} + +func validCustomerID(id string, repo *repository.SqliteRepository) (entities.Customer, error) { + cID, err := models.CustomerIDFromString(id) + if err != nil { + return entities.Customer{}, err + } + customer, err := repo.GetCustomer(cID) + if err != nil { + return entities.Customer{}, err + } + return *customer, nil +} + +func validateUserEmail(email string, customerID models.CustomerID, repo *repository.SqliteRepository) { + userEmail, err := models.ParseEmail(email) + if err != nil { + log.Fatal("user email error: ", err) + } + _, err = repo.GetUser(userEmail, customerID) + if err == nil { + log.Fatal("user already exists") + } + fmt.Println("user email is valid:", userEmail) +} + +func BuildDB() { + sqliteSchema, err := sqlite.FS.ReadFile("schema.sql") + if err != nil { + log.Fatal(err) + } + cliCmd := flag.NewFlagSet("build-db", flag.ExitOnError) + dbPath := cliCmd.String("db-path", "", "Path to the database") + imgPath := cliCmd.String("img-path", "", "Path to the directory with images") + imgType := cliCmd.String("img-type", "svg", "Image types webp, svg, png, jpeg, default webp") + if err = cliCmd.Parse(os.Args[2:]); err != nil { + log.Fatalf("Failed to parse flags: %v", err) + } + if err = MakeTables(*dbPath, string(sqliteSchema)); err != nil { + log.Fatal(err) + } + ctx := context.Background() + sqliteDb, err := sqlite_queue.OpenSqliteDb(*dbPath) + queue, err := sqlite_queue.NewQueue(ctx, sqliteDb) + queue.Start() + defer queue.Stop() + sqliteRepo := repository.NewSqliteRepository(ctx, queue) + if err != nil { + log.Fatal("error starting sqlite repo: ", err) + } + numbImgs := IconImgToSqlite(&sqliteRepo, *imgPath, *imgType) + log.Printf("Successfully added %d Images in %s to the database at %s\n", numbImgs, *imgPath, *dbPath) +} + +func IconImgToSqlite(repo *repository.SqliteRepository, imgDir, imgType string) int { + imgs := getImgs(imgDir, imgType) + for _, img := range imgs { + if _, err := repo.AddSVGIcon(img); err != nil { + log.Fatal(err) + } + } + return len(imgs) +} + +func MakeTables(dbPath string, schema string) error { + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + if err = os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { + return err + } + if _, err = os.Create(dbPath); err != nil { + return err + } + } + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return err + } + if _, err = db.Exec(schema); err != nil { + return err + } + return db.Close() +} diff --git a/cmd/main.go b/cmd/restapi/main.go similarity index 92% rename from cmd/main.go rename to cmd/restapi/main.go index 2f2f19c..d06bdac 100644 --- a/cmd/main.go +++ b/cmd/restapi/main.go @@ -56,7 +56,7 @@ func main() { } ctx := context.Background() - queue, err := sqliteQueue.NewQueue(sqliteDb, ctx) + queue, err := sqliteQueue.NewQueue(ctx, sqliteDb) if err != nil { log.Fatal(err) } @@ -67,18 +67,14 @@ func main() { log.Fatal(err) } }(queue) - sesClient := email.NewSESClient() emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) emailQueue.Start() defer emailQueue.Stop() - - sqlitedb := repository.NewSqliteRepository(queue, ctx) + sqlitedb := repository.NewSqliteRepository(ctx, queue) nkodeApi := api.NewNKodeAPI(&sqlitedb, emailQueue) - AddDefaultCustomer(nkodeApi) handler := api.NKodeHandler{Api: nkodeApi} - mux := http.NewServeMux() mux.Handle(api.CreateNewCustomer, &handler) mux.Handle(api.GenerateSignupResetInterface, &handler) @@ -90,12 +86,10 @@ func main() { mux.Handle(api.RandomSvgInterface, &handler) mux.Handle(api.RefreshToken, &handler) mux.Handle(api.ResetNKode, &handler) - // Serve Swagger UI mux.Handle("/swagger/", httpSwagger.WrapHandler) - - fmt.Println("Running on localhost:8080...") - log.Fatal(http.ListenAndServe(":8080", corsMiddleware(mux))) + fmt.Println("Running on localhost:8090...") + log.Fatal(http.ListenAndServe(":8090", corsMiddleware(mux))) } func corsMiddleware(next http.Handler) http.Handler { @@ -121,7 +115,7 @@ func AddDefaultCustomer(nkodeApi api.NKodeAPI) { if err != nil { log.Fatal(err) } - customerId := models.CustomerId(newId) + customerId := models.CustomerID(newId) nkodePolicy := models.NewDefaultNKodePolicy() _, err = nkodeApi.CreateNewCustomer(nkodePolicy, &customerId) if err != nil { diff --git a/cmd/main_test.go b/cmd/restapi/main_test.go similarity index 100% rename from cmd/main_test.go rename to cmd/restapi/main_test.go diff --git a/internal/api/handler.go b/internal/api/handler.go index f7dac38..3b07ed1 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -103,12 +103,10 @@ func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter methodNotAllowed(w) return } - var signupResetPost models.GenerateSignupRestInterfacePost if err := decodeJson(w, r, &signupResetPost); err != nil { return } - kp := entities.KeypadDimension{ AttrsPerKey: signupResetPost.AttrsPerKey, NumbOfKeys: signupResetPost.NumbOfKeys, @@ -127,12 +125,11 @@ func (h *NKodeHandler) GenerateSignupResetInterfaceHandler(w http.ResponseWriter badRequest(w, malformedUserEmail) return } - resp, err := h.Api.GenerateSignupResetInterface(userEmail, models.CustomerId(customerId), kp, signupResetPost.Reset) + resp, err := h.Api.GenerateSignupResetInterface(userEmail, models.CustomerID(customerId), kp, signupResetPost.Reset) if err != nil { handleError(w, err) return } - marshalAndWriteBytes(w, resp) } @@ -156,7 +153,7 @@ func (h *NKodeHandler) SetNKodeHandler(w http.ResponseWriter, r *http.Request) { badRequest(w, malformedSessionId) return } - confirmInterface, err := h.Api.SetNKode(models.CustomerId(customerId), models.SessionId(sessionId), setNKodePost.KeySelection) + confirmInterface, err := h.Api.SetNKode(models.CustomerID(customerId), models.SessionId(sessionId), setNKodePost.KeySelection) if err != nil { handleError(w, err) return @@ -186,7 +183,7 @@ func (h *NKodeHandler) ConfirmNKodeHandler(w http.ResponseWriter, r *http.Reques badRequest(w, malformedSessionId) return } - if err = h.Api.ConfirmNKode(models.CustomerId(customerId), models.SessionId(sessionId), confirmNKodePost.KeySelection); err != nil { + if err = h.Api.ConfirmNKode(models.CustomerID(customerId), models.SessionId(sessionId), confirmNKodePost.KeySelection); err != nil { handleError(w, err) return } @@ -212,7 +209,7 @@ func (h *NKodeHandler) GetLoginInterfaceHandler(w http.ResponseWriter, r *http.R if err != nil { badRequest(w, malformedUserEmail) } - loginInterface, err := h.Api.GetLoginInterface(userEmail, models.CustomerId(customerId)) + loginInterface, err := h.Api.GetLoginInterface(userEmail, models.CustomerID(customerId)) if err != nil { handleError(w, err) return @@ -241,7 +238,7 @@ func (h *NKodeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { badRequest(w, malformedUserEmail) return } - jwtTokens, err := h.Api.Login(models.CustomerId(customerId), userEmail, loginPost.KeySelection) + jwtTokens, err := h.Api.Login(models.CustomerID(customerId), userEmail, loginPost.KeySelection) if err != nil { handleError(w, err) return @@ -265,7 +262,7 @@ func (h *NKodeHandler) RenewAttributesHandler(w http.ResponseWriter, r *http.Req badRequest(w, malformedCustomerId) return } - if err = h.Api.RenewAttributes(models.CustomerId(customerId)); err != nil { + if err = h.Api.RenewAttributes(models.CustomerID(customerId)); err != nil { handleError(w, err) return } @@ -314,7 +311,7 @@ func (h *NKodeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Reques log.Println(err) return } - accessToken, err := h.Api.RefreshToken(userEmail, models.CustomerId(customerId), refreshToken) + accessToken, err := h.Api.RefreshToken(userEmail, models.CustomerID(customerId), refreshToken) if err != nil { handleError(w, err) @@ -346,7 +343,7 @@ func (h *NKodeHandler) ResetNKode(w http.ResponseWriter, r *http.Request) { return } - if err = h.Api.ResetNKode(userEmail, models.CustomerId(customerId)); err != nil { + if err = h.Api.ResetNKode(userEmail, models.CustomerID(customerId)); err != nil { internalServerError(w) log.Println(err) return diff --git a/internal/api/nkode_api.go b/internal/api/nkode_api.go index 814b8c8..2daad14 100644 --- a/internal/api/nkode_api.go +++ b/internal/api/nkode_api.go @@ -34,23 +34,22 @@ func NewNKodeAPI(db repository.CustomerUserRepository, queue *email.Queue) NKode } } -func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models.CustomerId) (*models.CustomerId, error) { +func (n *NKodeAPI) CreateNewCustomer(nkodePolicy models.NKodePolicy, id *models.CustomerID) (*models.CustomerID, error) { newCustomer, err := entities.NewCustomer(nkodePolicy) if id != nil { - newCustomer.Id = *id + newCustomer.ID = *id } if err != nil { return nil, err } err = n.Db.CreateCustomer(*newCustomer) - if err != nil { return nil, err } - return &newCustomer.Id, nil + return &newCustomer.ID, nil } -func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, customerId models.CustomerId, kp entities.KeypadDimension, reset bool) (*models.GenerateSignupResetInterfaceResp, error) { +func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, customerId models.CustomerID, kp entities.KeypadDimension, reset bool) (*models.GenerateSignupResetInterfaceResp, error) { user, err := n.Db.GetUser(userEmail, customerId) if err != nil { return nil, err @@ -59,20 +58,23 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, cust log.Printf("user %s already exists", string(userEmail)) return nil, config.ErrUserAlreadyExists } - svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) - if err != nil { - return nil, err + //svgIdxInterface, err := n.Db.RandomSvgIdxInterface(kp) + //if err != nil { + // return nil, err + //} + svgIdxInterface := make(models.SvgIdInterface, 54) + for idx := range 54 { + svgIdxInterface[idx] = idx + 1 } signupSession, err := entities.NewSignupResetSession(userEmail, kp, customerId, svgIdxInterface, reset) if err != nil { return nil, err } - //n.SignupSessions[signupSession.Id] = *signupSession + //n.SignupSessions[signupSession.ID] = *signupSession if err := n.SignupSessionCache.Add(signupSession.Id.String(), *signupSession, sessionExpiration); err != nil { return nil, err } svgInterface, err := n.Db.GetSvgStringInterface(signupSession.LoginUserInterface.SvgId) - if err != nil { return nil, err } @@ -85,9 +87,41 @@ func (n *NKodeAPI) GenerateSignupResetInterface(userEmail models.UserEmail, cust return &resp, nil } -func (n *NKodeAPI) SetNKode(customerId models.CustomerId, sessionId models.SessionId, keySelection models.KeySelection) (models.IdxInterface, error) { - _, err := n.Db.GetCustomer(customerId) +func (n *NKodeAPI) GenerateSignupResetInterfaceRigged(userEmail models.UserEmail, customerId models.CustomerID, kp entities.KeypadDimension, reset bool) (*models.GenerateSignupResetInterfaceResp, error) { + user, err := n.Db.GetUser(userEmail, customerId) + if err != nil { + return nil, err + } + if user != nil && !reset { + log.Printf("user %s already exists", string(userEmail)) + return nil, config.ErrUserAlreadyExists + } + svgIdxInterface := make(models.SvgIdInterface, kp.TotalAttrs()) + for idx := range kp.TotalAttrs() { + svgIdxInterface[idx] = idx + } + signupSession, err := entities.NewSignupResetSessionRigged(userEmail, kp, customerId, svgIdxInterface, reset) + if err != nil { + return nil, err + } + if err := n.SignupSessionCache.Add(signupSession.Id.String(), *signupSession, sessionExpiration); err != nil { + return nil, err + } + svgInterface, err := n.Db.GetSvgStringInterface(signupSession.LoginUserInterface.SvgId) + if err != nil { + return nil, err + } + resp := models.GenerateSignupResetInterfaceResp{ + UserIdxInterface: signupSession.SetIdxInterface, + SvgInterface: svgInterface, + SessionId: uuid.UUID(signupSession.Id).String(), + Colors: signupSession.Colors, + } + return &resp, nil +} +func (n *NKodeAPI) SetNKode(customerId models.CustomerID, sessionId models.SessionId, keySelection models.KeySelection) (models.IdxInterface, error) { + _, err := n.Db.GetCustomer(customerId) if err != nil { return nil, err } @@ -109,7 +143,7 @@ func (n *NKodeAPI) SetNKode(customerId models.CustomerId, sessionId models.Sessi return confirmInterface, nil } -func (n *NKodeAPI) ConfirmNKode(customerId models.CustomerId, sessionId models.SessionId, keySelection models.KeySelection) error { +func (n *NKodeAPI) ConfirmNKode(customerId models.CustomerID, sessionId models.SessionId, keySelection models.KeySelection) error { session, exists := n.SignupSessionCache.Get(sessionId.String()) if !exists { log.Printf("session id does not exist %s", sessionId) @@ -144,7 +178,7 @@ func (n *NKodeAPI) ConfirmNKode(customerId models.CustomerId, sessionId models.S return err } -func (n *NKodeAPI) GetLoginInterface(userEmail models.UserEmail, customerId models.CustomerId) (*models.GetLoginInterfaceResp, error) { +func (n *NKodeAPI) GetLoginInterface(userEmail models.UserEmail, customerId models.CustomerID) (*models.GetLoginInterfaceResp, error) { user, err := n.Db.GetUser(userEmail, customerId) if err != nil { return nil, err @@ -167,7 +201,7 @@ func (n *NKodeAPI) GetLoginInterface(userEmail models.UserEmail, customerId mode return &resp, nil } -func (n *NKodeAPI) Login(customerId models.CustomerId, userEmail models.UserEmail, keySelection models.KeySelection) (*security.AuthenticationTokens, error) { +func (n *NKodeAPI) Login(customerId models.CustomerID, userEmail models.UserEmail, keySelection models.KeySelection) (*security.AuthenticationTokens, error) { customer, err := n.Db.GetCustomer(customerId) if err != nil { return nil, err @@ -184,7 +218,6 @@ func (n *NKodeAPI) Login(customerId models.CustomerId, userEmail models.UserEmai if err != nil { return nil, err } - if user.Renew { err = n.Db.RefreshUserPasscode(*user, passcode, customer.Attributes) if err != nil { @@ -207,7 +240,7 @@ func (n *NKodeAPI) Login(customerId models.CustomerId, userEmail models.UserEmai return &jwtToken, nil } -func (n *NKodeAPI) RenewAttributes(customerId models.CustomerId) error { +func (n *NKodeAPI) RenewAttributes(customerId models.CustomerID) error { return n.Db.Renew(customerId) } @@ -215,7 +248,7 @@ func (n *NKodeAPI) RandomSvgInterface() ([]string, error) { return n.Db.RandomSvgInterface(entities.KeypadMax) } -func (n *NKodeAPI) RefreshToken(userEmail models.UserEmail, customerId models.CustomerId, refreshToken string) (string, error) { +func (n *NKodeAPI) RefreshToken(userEmail models.UserEmail, customerId models.CustomerID, refreshToken string) (string, error) { user, err := n.Db.GetUser(userEmail, customerId) if err != nil { return "", err @@ -238,16 +271,14 @@ func (n *NKodeAPI) RefreshToken(userEmail models.UserEmail, customerId models.Cu return security.EncodeAndSignClaims(newAccessClaims) } -func (n *NKodeAPI) ResetNKode(userEmail models.UserEmail, customerId models.CustomerId) error { +func (n *NKodeAPI) ResetNKode(userEmail models.UserEmail, customerId models.CustomerID) error { user, err := n.Db.GetUser(userEmail, customerId) if err != nil { return fmt.Errorf("error getting user in rest nkode %v", err) } - if user == nil { return nil } - nkodeResetJwt, err := security.ResetNKodeToken(string(userEmail), uuid.UUID(customerId)) if err != nil { return err @@ -257,12 +288,12 @@ func (n *NKodeAPI) ResetNKode(userEmail models.UserEmail, customerId models.Cust frontendHost = config.FrontendHost } htmlBody := fmt.Sprintf("

Hello!

Click the link to reset your nKode.

Reset nKode", frontendHost, nkodeResetJwt) - email := email.Email{ + emailData := email.Email{ Sender: "no-reply@nkode.tech", Recipient: string(userEmail), Subject: "nKode Reset", Content: htmlBody, } - n.EmailQueue.AddEmail(email) + n.EmailQueue.AddEmail(emailData) return nil } diff --git a/internal/api/nkode_api_test.go b/internal/api/nkode_api_test.go index 4e28e08..a3e05d2 100644 --- a/internal/api/nkode_api_test.go +++ b/internal/api/nkode_api_test.go @@ -23,7 +23,7 @@ func TestNKodeAPI(t *testing.T) { sqliteDb, err := sqlite_queue.OpenSqliteDb(dbPath) assert.NoError(t, err) - queue, err := sqlite_queue.NewQueue(sqliteDb, ctx) + queue, err := sqlite_queue.NewQueue(ctx, sqliteDb) assert.NoError(t, err) queue.Start() defer func(queue *sqlite_queue.Queue) { @@ -31,7 +31,7 @@ func TestNKodeAPI(t *testing.T) { log.Fatal(err) } }(queue) - sqlitedb := repository.NewSqliteRepository(queue, ctx) + sqlitedb := repository.NewSqliteRepository(ctx, queue) testNKodeAPI(t, &sqlitedb) //if _, err := os.Stat(dbPath); err == nil { diff --git a/internal/entities/customer.go b/internal/entities/customer.go index fc6b87c..a4c1a40 100644 --- a/internal/entities/customer.go +++ b/internal/entities/customer.go @@ -11,7 +11,7 @@ import ( ) type Customer struct { - Id models.CustomerId + ID models.CustomerID NKodePolicy models.NKodePolicy Attributes CustomerAttributes } @@ -22,7 +22,7 @@ func NewCustomer(nkodePolicy models.NKodePolicy) (*Customer, error) { return nil, err } customer := Customer{ - Id: models.CustomerId(uuid.New()), + ID: models.CustomerID(uuid.New()), NKodePolicy: nkodePolicy, Attributes: *customerAttrs, } @@ -88,7 +88,7 @@ func (c *Customer) RenewKeys() ([]uint64, []uint64, error) { func (c *Customer) ToSqlcCreateCustomerParams() sqlc.CreateCustomerParams { return sqlc.CreateCustomerParams{ - ID: uuid.UUID(c.Id).String(), + ID: uuid.UUID(c.ID).String(), MaxNkodeLen: int64(c.NKodePolicy.MaxNkodeLen), MinNkodeLen: int64(c.NKodePolicy.MinNkodeLen), DistinctSets: int64(c.NKodePolicy.DistinctSets), diff --git a/internal/entities/user.go b/internal/entities/user.go index 0754f05..d79fe48 100644 --- a/internal/entities/user.go +++ b/internal/entities/user.go @@ -10,7 +10,7 @@ import ( type User struct { Id models.UserId - CustomerId models.CustomerId + CustomerId models.CustomerID Email models.UserEmail EncipheredPasscode models.EncipheredNKode Kp KeypadDimension @@ -137,7 +137,7 @@ func NewUser(customer Customer, userEmail string, passcodeIdx []int, ui UserInte CipherKeys: *newKeys, Interface: ui, Kp: kp, - CustomerId: customer.Id, + CustomerId: customer.ID, } return &newUser, nil } diff --git a/internal/entities/user_signup_session.go b/internal/entities/user_signup_session.go index 4d9f28f..99e0e56 100644 --- a/internal/entities/user_signup_session.go +++ b/internal/entities/user_signup_session.go @@ -13,7 +13,7 @@ import ( type UserSignSession struct { Id models.SessionId - CustomerId models.CustomerId + CustomerId models.CustomerID LoginUserInterface UserInterface Kp KeypadDimension SetIdxInterface models.IdxInterface @@ -25,7 +25,31 @@ type UserSignSession struct { Colors []models.RGBColor } -func NewSignupResetSession(userEmail models.UserEmail, kp KeypadDimension, customerId models.CustomerId, svgInterface models.SvgIdInterface, reset bool) (*UserSignSession, error) { +func NewSignupResetSessionRigged(userEmail models.UserEmail, kp KeypadDimension, customerId models.CustomerID, svgInterface models.SvgIdInterface, reset bool) (*UserSignSession, error) { + loginInterface, err := NewUserInterface(&kp, svgInterface) + if err != nil { + return nil, err + } + setIdxInterface := make(models.IdxInterface, 36) + for idx := range 36 { + setIdxInterface[idx] = idx + } + session := UserSignSession{ + Id: models.SessionId(uuid.New()), + CustomerId: customerId, + LoginUserInterface: *loginInterface, + SetIdxInterface: setIdxInterface, + ConfirmIdxInterface: nil, + SetKeySelection: nil, + UserEmail: userEmail, + Kp: kp, + Reset: reset, + Colors: []models.RGBColor{}, + } + return &session, nil +} + +func NewSignupResetSession(userEmail models.UserEmail, kp KeypadDimension, customerId models.CustomerID, svgInterface models.SvgIdInterface, reset bool) (*UserSignSession, error) { loginInterface, err := NewUserInterface(&kp, svgInterface) if err != nil { return nil, err @@ -46,7 +70,6 @@ func NewSignupResetSession(userEmail models.UserEmail, kp KeypadDimension, custo Reset: reset, Colors: colors, } - return &session, nil } diff --git a/internal/models/models.go b/internal/models/models.go index 49863da..9f7f757 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -90,13 +90,21 @@ type GetLoginInterfaceResp struct { type KeySelection []int -type CustomerId uuid.UUID +type CustomerID uuid.UUID -func CustomerIdToString(customerId CustomerId) string { +func CustomerIdToString(customerId CustomerID) string { customerUuid := uuid.UUID(customerId) return customerUuid.String() } +func CustomerIDFromString(customerID string) (CustomerID, error) { + id, err := uuid.Parse(customerID) + if err != nil { + return CustomerID{}, err + } + return CustomerID(id), nil +} + type SessionId uuid.UUID type UserId uuid.UUID diff --git a/internal/repository/customer_user_repository.go b/internal/repository/customer_user_repository.go index 3cdb620..2dd3b56 100644 --- a/internal/repository/customer_user_repository.go +++ b/internal/repository/customer_user_repository.go @@ -6,16 +6,17 @@ import ( ) type CustomerUserRepository interface { - GetCustomer(models.CustomerId) (*entities.Customer, error) - GetUser(models.UserEmail, models.CustomerId) (*entities.User, error) + GetCustomer(models.CustomerID) (*entities.Customer, error) + GetUser(models.UserEmail, models.CustomerID) (*entities.User, error) CreateCustomer(entities.Customer) error WriteNewUser(entities.User) error UpdateUserNKode(entities.User) error UpdateUserInterface(models.UserId, entities.UserInterface) error UpdateUserRefreshToken(models.UserId, string) error - Renew(models.CustomerId) error + Renew(models.CustomerID) error RefreshUserPasscode(entities.User, []int, entities.CustomerAttributes) error RandomSvgInterface(entities.KeypadDimension) ([]string, error) RandomSvgIdxInterface(entities.KeypadDimension) (models.SvgIdInterface, error) GetSvgStringInterface(models.SvgIdInterface) ([]string, error) + AddSVGIcon(svgStr string) (int64, error) } diff --git a/internal/repository/in_memory_db.go b/internal/repository/in_memory_db.go index 9869234..cbcecfe 100644 --- a/internal/repository/in_memory_db.go +++ b/internal/repository/in_memory_db.go @@ -8,28 +8,28 @@ import ( ) type InMemoryDb struct { - Customers map[models.CustomerId]entities.Customer + Customers map[models.CustomerID]entities.Customer Users map[models.UserId]entities.User userIdMap map[string]models.UserId } func NewInMemoryDb() InMemoryDb { return InMemoryDb{ - Customers: make(map[models.CustomerId]entities.Customer), + Customers: make(map[models.CustomerID]entities.Customer), Users: make(map[models.UserId]entities.User), userIdMap: make(map[string]models.UserId), } } -func (db *InMemoryDb) GetCustomer(id models.CustomerId) (*entities.Customer, error) { +func (db *InMemoryDb) GetCustomer(id models.CustomerID) (*entities.Customer, error) { customer, exists := db.Customers[id] if !exists { - return nil, errors.New(fmt.Sprintf("customer %s dne", customer.Id)) + return nil, errors.New(fmt.Sprintf("customer %s dne", customer.ID)) } return &customer, nil } -func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerId) (*entities.User, error) { +func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.CustomerID) (*entities.User, error) { key := userIdKey(customerId, username) userId, exists := db.userIdMap[key] if !exists { @@ -43,12 +43,12 @@ func (db *InMemoryDb) GetUser(username models.UserEmail, customerId models.Custo } func (db *InMemoryDb) CreateCustomer(customer entities.Customer) error { - _, exists := db.Customers[customer.Id] + _, exists := db.Customers[customer.ID] if exists { - return errors.New(fmt.Sprintf("can write customer %s; already exists", customer.Id)) + return errors.New(fmt.Sprintf("can write customer %s; already exists", customer.ID)) } - db.Customers[customer.Id] = customer + db.Customers[customer.ID] = customer return nil } @@ -86,7 +86,7 @@ func (db *InMemoryDb) UpdateUserRefreshToken(userId models.UserId, refreshToken return nil } -func (db *InMemoryDb) Renew(id models.CustomerId) error { +func (db *InMemoryDb) Renew(id models.CustomerID) error { customer, exists := db.Customers[id] if !exists { return errors.New(fmt.Sprintf("customer %s does not exist", id)) @@ -133,7 +133,7 @@ func (db *InMemoryDb) GetSvgStringInterface(idxs models.SvgIdInterface) ([]strin return make([]string, len(idxs)), nil } -func userIdKey(customerId models.CustomerId, username models.UserEmail) string { +func userIdKey(customerId models.CustomerID, username models.UserEmail) string { key := fmt.Sprintf("%s:%s", customerId, username) return key } diff --git a/internal/repository/sqlite_repository.go b/internal/repository/sqlite_repository.go index 88e2a7c..9740e19 100644 --- a/internal/repository/sqlite_repository.go +++ b/internal/repository/sqlite_repository.go @@ -21,7 +21,7 @@ type SqliteRepository struct { ctx context.Context } -func NewSqliteRepository(queue *sqlc.Queue, ctx context.Context) SqliteRepository { +func NewSqliteRepository(ctx context.Context, queue *sqlc.Queue) SqliteRepository { return SqliteRepository{ Queue: queue, ctx: ctx, @@ -29,27 +29,27 @@ func NewSqliteRepository(queue *sqlc.Queue, ctx context.Context) SqliteRepositor } func (d *SqliteRepository) CreateCustomer(c entities.Customer) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.CreateCustomerParams) if !ok { - return fmt.Errorf("invalid argument type: expected CreateCustomerParams") + return nil, fmt.Errorf("invalid argument type: expected CreateCustomerParams") } - return q.CreateCustomer(ctx, params) + return nil, q.CreateCustomer(ctx, params) } - return d.Queue.EnqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) + _, err := d.Queue.EnqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) + return err } func (d *SqliteRepository) WriteNewUser(u entities.User) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.CreateUserParams) if !ok { - return fmt.Errorf("invalid argument type: expected CreateUserParams") + return nil, fmt.Errorf("invalid argument type: expected CreateUserParams") } - return q.CreateUser(ctx, params) + return nil, q.CreateUser(ctx, params) } // Use the wrapped function in EnqueueWriteTx - renew := 0 if u.Renew { renew = 1 @@ -75,16 +75,17 @@ func (d *SqliteRepository) WriteNewUser(u entities.User) error { SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true}, } - return d.Queue.EnqueueWriteTx(queryFunc, params) + _, err := d.Queue.EnqueueWriteTx(queryFunc, params) + return err } func (d *SqliteRepository) UpdateUserNKode(u entities.User) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.UpdateUserParams) if !ok { - return fmt.Errorf("invalid argument type: expected UpdateUserParams") + return nil, fmt.Errorf("invalid argument type: expected UpdateUserParams") } - return q.UpdateUser(ctx, params) + return nil, q.UpdateUser(ctx, params) } // Use the wrapped function in EnqueueWriteTx renew := 0 @@ -109,33 +110,34 @@ func (d *SqliteRepository) UpdateUserNKode(u entities.User) error { IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), } - return d.Queue.EnqueueWriteTx(queryFunc, params) + _, err := d.Queue.EnqueueWriteTx(queryFunc, params) + return err } func (d *SqliteRepository) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.UpdateUserInterfaceParams) if !ok { - return fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams") + return nil, fmt.Errorf("invalid argument type: expected UpdateUserInterfaceParams") } - return q.UpdateUserInterface(ctx, params) + return nil, q.UpdateUserInterface(ctx, params) } params := sqlc.UpdateUserInterfaceParams{ IdxInterface: security.IntArrToByteArr(ui.IdxInterface), LastLogin: utils.TimeStamp(), ID: uuid.UUID(id).String(), } - - return d.Queue.EnqueueWriteTx(queryFunc, params) + _, err := d.Queue.EnqueueWriteTx(queryFunc, params) + return err } func (d *SqliteRepository) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.UpdateUserRefreshTokenParams) if !ok { - return fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken") + return nil, fmt.Errorf("invalid argument type: expected UpdateUserRefreshToken") } - return q.UpdateUserRefreshToken(ctx, params) + return nil, q.UpdateUserRefreshToken(ctx, params) } params := sqlc.UpdateUserRefreshTokenParams{ RefreshToken: sql.NullString{ @@ -144,21 +146,23 @@ func (d *SqliteRepository) UpdateUserRefreshToken(id models.UserId, refreshToken }, ID: uuid.UUID(id).String(), } - return d.Queue.EnqueueWriteTx(queryFunc, params) + _, err := d.Queue.EnqueueWriteTx(queryFunc, params) + return err } func (d *SqliteRepository) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.RenewCustomerParams) if !ok { - + return nil, fmt.Errorf("invalid argument type: expected RenewCustomerParams") } - return q.RenewCustomer(ctx, params) + return nil, q.RenewCustomer(ctx, params) } - return d.Queue.EnqueueWriteTx(queryFunc, renewParams) + _, err := d.Queue.EnqueueWriteTx(queryFunc, renewParams) + return err } -func (d *SqliteRepository) Renew(id models.CustomerId) error { +func (d *SqliteRepository) Renew(id models.CustomerID) error { setXor, attrXor, err := d.renewCustomer(id) if err != nil { return err @@ -168,19 +172,17 @@ func (d *SqliteRepository) Renew(id models.CustomerId) error { if err != nil { return err } - - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.RenewUserParams) if !ok { - return fmt.Errorf("invalid argument type: expected RenewUserParams") + return nil, fmt.Errorf("invalid argument type: expected RenewUserParams") } - return q.RenewUser(ctx, params) + return nil, q.RenewUser(ctx, params) } - for _, row := range userRenewRows { user := entities.User{ Id: models.UserIdFromString(row.ID), - CustomerId: models.CustomerId{}, + CustomerId: models.CustomerID{}, Email: "", EncipheredPasscode: models.EncipheredNKode{}, Kp: entities.KeypadDimension{ @@ -194,7 +196,6 @@ func (d *SqliteRepository) Renew(id models.CustomerId) error { Interface: entities.UserInterface{}, Renew: false, } - if err = user.RenewKeys(setXor, attrXor); err != nil { return err } @@ -204,14 +205,14 @@ func (d *SqliteRepository) Renew(id models.CustomerId) error { Renew: 1, ID: uuid.UUID(user.Id).String(), } - if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { + if _, err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { return err } } return nil } -func (d *SqliteRepository) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) { +func (d *SqliteRepository) renewCustomer(id models.CustomerID) ([]uint64, []uint64, error) { customer, err := d.GetCustomer(id) if err != nil { return nil, nil, err @@ -220,21 +221,19 @@ func (d *SqliteRepository) renewCustomer(id models.CustomerId) ([]uint64, []uint if err != nil { return nil, nil, err } - - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.RenewCustomerParams) if !ok { - return fmt.Errorf("invalid argument type: expected RenewCustomerParams") + return nil, fmt.Errorf("invalid argument type: expected RenewCustomerParams") } - return q.RenewCustomer(ctx, params) + return nil, q.RenewCustomer(ctx, params) } params := sqlc.RenewCustomerParams{ AttributeValues: security.Uint64ArrToByteArr(customer.Attributes.AttrVals), SetValues: security.Uint64ArrToByteArr(customer.Attributes.SetVals), - ID: uuid.UUID(customer.Id).String(), + ID: uuid.UUID(customer.ID).String(), } - - if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { + if _, err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { return nil, nil, err } return setXor, attrXor, nil @@ -244,12 +243,12 @@ func (d *SqliteRepository) RefreshUserPasscode(user entities.User, passcodeIdx [ if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil { return err } - queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { params, ok := args.(sqlc.RefreshUserPasscodeParams) if !ok { - return fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams") + return nil, fmt.Errorf("invalid argument type: expected RefreshUserPasscodeParams") } - return q.RefreshUserPasscode(ctx, params) + return nil, q.RefreshUserPasscode(ctx, params) } params := sqlc.RefreshUserPasscodeParams{ Renew: 0, @@ -262,17 +261,17 @@ func (d *SqliteRepository) RefreshUserPasscode(user entities.User, passcodeIdx [ Salt: user.CipherKeys.Salt, ID: uuid.UUID(user.Id).String(), } - return d.Queue.EnqueueWriteTx(queryFunc, params) + _, err := d.Queue.EnqueueWriteTx(queryFunc, params) + return err } -func (d *SqliteRepository) GetCustomer(id models.CustomerId) (*entities.Customer, error) { +func (d *SqliteRepository) GetCustomer(id models.CustomerID) (*entities.Customer, error) { customer, err := d.Queue.Queries.GetCustomer(d.ctx, uuid.UUID(id).String()) if err != nil { return nil, err } - return &entities.Customer{ - Id: id, + ID: id, NKodePolicy: models.NKodePolicy{ MaxNkodeLen: int(customer.MaxNkodeLen), MinNkodeLen: int(customer.MinNkodeLen), @@ -285,7 +284,7 @@ func (d *SqliteRepository) GetCustomer(id models.CustomerId) (*entities.Customer }, nil } -func (d *SqliteRepository) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { +func (d *SqliteRepository) GetUser(email models.UserEmail, customerId models.CustomerID) (*entities.User, error) { userRow, err := d.Queue.Queries.GetUser(d.ctx, sqlc.GetUserParams{ Email: string(email), CustomerID: uuid.UUID(customerId).String(), @@ -296,12 +295,10 @@ func (d *SqliteRepository) GetUser(email models.UserEmail, customerId models.Cus } return nil, fmt.Errorf("failed to get user: %w", err) } - kp := entities.KeypadDimension{ AttrsPerKey: int(userRow.AttributesPerKey), NumbOfKeys: int(userRow.NumberOfKeys), } - renew := false if userRow.Renew == 1 { renew = true @@ -351,6 +348,25 @@ func (d *SqliteRepository) GetSvgStringInterface(idxs models.SvgIdInterface) ([] return d.getSvgsById(idxs) } +func (d *SqliteRepository) AddSVGIcon(svgStr string) (int64, error) { + queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) (any, error) { + params, ok := args.(string) + if !ok { + return nil, fmt.Errorf("invalid argument type: expected string") + } + return q.AddSVGIcon(ctx, params) + } + svgID, err := d.Queue.EnqueueWriteTx(queryFunc, svgStr) + if err != nil { + return -1, err + } + svgIDInt64, ok := svgID.(int64) + if !ok { + return -1, errors.New("svgID in DB isn't int64") + } + return svgIDInt64, nil +} + func (d *SqliteRepository) getSvgsById(ids []int) ([]string, error) { svgs := make([]string, len(ids)) for idx, id := range ids { diff --git a/internal/repository/sqlite_repository_test.go b/internal/repository/sqlite_repository_test.go index 6db6099..3ee74af 100644 --- a/internal/repository/sqlite_repository_test.go +++ b/internal/repository/sqlite_repository_test.go @@ -17,12 +17,12 @@ func TestNewSqliteDB(t *testing.T) { sqliteDb, err := sqlite_queue.OpenSqliteDb(dbPath) assert.NoError(t, err) - queue, err := sqlite_queue.NewQueue(sqliteDb, ctx) + queue, err := sqlite_queue.NewQueue(ctx, sqliteDb) assert.NoError(t, err) queue.Start() defer queue.Stop() - db := NewSqliteRepository(queue, ctx) + db := NewSqliteRepository(ctx, queue) assert.NoError(t, err) testSignupLoginRenew(t, &db) testSqliteDBRandomSvgInterface(t, &db) @@ -34,7 +34,7 @@ func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { assert.NoError(t, err) err = db.CreateCustomer(*customerOrig) assert.NoError(t, err) - customer, err := db.GetCustomer(customerOrig.Id) + customer, err := db.GetCustomer(customerOrig.ID) assert.NoError(t, err) assert.Equal(t, customerOrig, customer) username := "test_user@example.com" @@ -47,11 +47,11 @@ func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { assert.NoError(t, err) err = db.WriteNewUser(*userOrig) assert.NoError(t, err) - user, err := db.GetUser(models.UserEmail(username), customer.Id) + user, err := db.GetUser(models.UserEmail(username), customer.ID) assert.NoError(t, err) assert.Equal(t, userOrig, user) - err = db.Renew(customer.Id) + err = db.Renew(customer.ID) assert.NoError(t, err) } diff --git a/internal/sqlc/query.sql.go b/internal/sqlc/query.sql.go index c11071c..f08dbec 100644 --- a/internal/sqlc/query.sql.go +++ b/internal/sqlc/query.sql.go @@ -10,6 +10,19 @@ import ( "database/sql" ) +const addSVGIcon = `-- name: AddSVGIcon :one +INSERT INTO svg_icon (svg) +VALUES (?) + RETURNING id +` + +func (q *Queries) AddSVGIcon(ctx context.Context, svg string) (int64, error) { + row := q.db.QueryRowContext(ctx, addSVGIcon, svg) + var id int64 + err := row.Scan(&id) + return id, err +} + const createCustomer = `-- name: CreateCustomer :exec INSERT INTO customer ( id diff --git a/internal/sqlc/sqlite_queue.go b/internal/sqlc/sqlite_queue.go index 20c17e6..b4760e1 100644 --- a/internal/sqlc/sqlite_queue.go +++ b/internal/sqlc/sqlite_queue.go @@ -10,12 +10,13 @@ import ( const writeBufferSize = 100 -type SqlcGeneric func(*Queries, context.Context, any) error +type GenericQuery func(*Queries, context.Context, any) (any, error) type WriteTx struct { - ErrChan chan error - Query SqlcGeneric - Args interface{} + ErrChan chan error + ReturnChan chan any + Query GenericQuery + Args any } type Queue struct { @@ -27,7 +28,7 @@ type Queue struct { cancel context.CancelFunc } -func NewQueue(sqlDb *sql.DB, ctx context.Context) (*Queue, error) { +func NewQueue(ctx context.Context, sqlDb *sql.DB) (*Queue, error) { ctx, cancel := context.WithCancel(context.Background()) sqldb := &Queue{ Queries: New(sqlDb), @@ -42,15 +43,19 @@ func NewQueue(sqlDb *sql.DB, ctx context.Context) (*Queue, error) { func (d *Queue) Start() { d.wg.Add(1) - defer d.wg.Done() go func() { + // TODO: I think this might be a naive approach. + defer d.wg.Done() for { select { case <-d.ctx.Done(): return case writeTx := <-d.WriteQueue: - err := writeTx.Query(d.Queries, d.ctx, writeTx.Args) + ret, err := writeTx.Query(d.Queries, d.ctx, writeTx.Args) writeTx.ErrChan <- err + writeTx.ReturnChan <- ret + close(writeTx.ErrChan) + close(writeTx.ReturnChan) } } }() @@ -63,21 +68,24 @@ func (d *Queue) Stop() error { return d.Db.Close() } -func (d *Queue) EnqueueWriteTx(queryFunc SqlcGeneric, args any) error { +func (d *Queue) EnqueueWriteTx(queryFunc GenericQuery, args any) (any, error) { select { case <-d.ctx.Done(): - return errors.New("database is shutting down") + return nil, errors.New("database is shutting down") default: } - errChan := make(chan error, 1) + retChan := make(chan any, 1) writeTx := WriteTx{ - Query: queryFunc, - Args: args, - ErrChan: errChan, + Query: queryFunc, + Args: args, + ErrChan: errChan, + ReturnChan: retChan, } d.WriteQueue <- writeTx - return <-errChan + err := <-errChan + val := <-retChan + return val, err } func OpenSqliteDb(dbPath string) (*sql.DB, error) { diff --git a/scripts/bash/rebuild_db.sh b/scripts/bash/rebuild_db.sh new file mode 100755 index 0000000..b6e5b94 --- /dev/null +++ b/scripts/bash/rebuild_db.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +sqlite_db="$HOME/databases/demo.db" +db_schema="../../sqlite/schema.sql" +# svg_path="$HOME/svgs/flaticon_colored_svgs" +# svg_path="$HOME/svgs/flaticon_colored_pngs" +#svg_path="$HOME/icons" +svg_path="$HOME/svgs/warfighter_icons" + + +# remove existing test database +if [ -f "$sqlite_db" ]; then + echo "Removing existing test database at $sqlite_db" + rm "$sqlite_db" +else + echo "No existing test database found at $sqlite_db" +fi + +# rebuild database +sqlite3 "$sqlite_db" < "$db_schema" + +cli="../../bin/cli" +# build go cli +echo "building cli" +go build -o $cli ../../cmd/cli/main.go + +# build db +echo "building db" +$cli build-db -db-path "$sqlite_db" -img-path "$svg_path" + +# create customer +echo "creating customer" +customer_id="ed9ed6e0-082c-4b57-8d8c-f00ed6493457" +$cli create-customer -customer-id "$customer_id" -db-path "$sqlite_db" + +## create admin user +#user_email="donovan.a.kelly@pm.me" +#keypad_path="$HOME/svgs/my_icons/" +#$nkode_cli add-user \ +# -img-path "$keypad_path" \ +# -img-type "svg" \ +# -customer-id "$customer_id" \ +# -user-email "$user_email" \ +# -attrs-per-key 9 -numb-of-keys 6 \ +# -db-path "$sqlite_db" \ +# -role admin \ +# -nkode-icons ae-86.svg,arkansas.svg,banana-slug.svg,blockchain.svg +# diff --git a/sqlite/embed.go b/sqlite/embed.go new file mode 100644 index 0000000..dec0d88 --- /dev/null +++ b/sqlite/embed.go @@ -0,0 +1,6 @@ +package sqlite + +import "embed" + +//go:embed schema.sql +var FS embed.FS diff --git a/sqlite/query.sql b/sqlite/query.sql index b3c50cf..db49b63 100644 --- a/sqlite/query.sql +++ b/sqlite/query.sql @@ -134,3 +134,8 @@ WHERE id = ?; -- name: GetSvgCount :one SELECT COUNT(*) as count FROM svg_icon; + +-- name: AddSVGIcon :one +INSERT INTO svg_icon (svg) +VALUES (?) + RETURNING id;