From f1121711451aead547e65cb3d81321798c403cc6 Mon Sep 17 00:00:00 2001 From: Donovan Date: Thu, 2 Jan 2025 17:32:33 -0600 Subject: [PATCH] refactor sqlite queue --- cmd/main.go | 35 +++- internal/api/nkode_api.go | 6 +- internal/api/nkode_api_test.go | 30 +++- internal/email/queue.go | 2 +- internal/email/queue_test.go | 2 +- .../customer_user_repository.go | 2 +- internal/{db => repository}/in_memory_db.go | 2 +- .../sqlite-init/json/academicons.json | 0 .../sqlite-init/json/akar-icons.json | 0 .../sqlite-init/json/ant-design.json | 0 .../sqlite-init/json/arcticons.json | 0 .../sqlite-init/json/basil.json | 0 .../sqlite-init/json/bitcoin-icons.json | 0 .../sqlite-init/sqlite_init.go | 0 .../sqlite_repository.go} | 158 +++++------------- .../sqlite_repository_test.go} | 27 +-- internal/sqlc/sqlite_queue.go | 93 +++++++++++ 17 files changed, 204 insertions(+), 153 deletions(-) rename internal/{db => repository}/customer_user_repository.go (97%) rename internal/{db => repository}/in_memory_db.go (99%) rename internal/{db => repository}/sqlite-init/json/academicons.json (100%) rename internal/{db => repository}/sqlite-init/json/akar-icons.json (100%) rename internal/{db => repository}/sqlite-init/json/ant-design.json (100%) rename internal/{db => repository}/sqlite-init/json/arcticons.json (100%) rename internal/{db => repository}/sqlite-init/json/basil.json (100%) rename internal/{db => repository}/sqlite-init/json/bitcoin-icons.json (100%) rename internal/{db => repository}/sqlite-init/sqlite_init.go (100%) rename internal/{db/sqlite_db.go => repository/sqlite_repository.go} (71%) rename internal/{db/sqlite_db_test.go => repository/sqlite_repository_test.go} (76%) create mode 100644 internal/sqlc/sqlite_queue.go diff --git a/cmd/main.go b/cmd/main.go index c119843..2f2f19c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,14 +1,17 @@ package main import ( + "context" + "database/sql" "fmt" "github.com/google/uuid" httpSwagger "github.com/swaggo/http-swagger" _ "go-nkode/docs" "go-nkode/internal/api" - "go-nkode/internal/db" "go-nkode/internal/email" "go-nkode/internal/models" + "go-nkode/internal/repository" + sqliteQueue "go-nkode/internal/sqlc" "log" "net/http" "os" @@ -37,24 +40,42 @@ const ( // @securityDefinitions.apiKey ApiKeyAuth // @in header // @name Authorization - func main() { dbPath := os.Getenv("SQLITE_DB") if dbPath == "" { - log.Fatalf("SQLITE_DB=/path/to/nkode.db not set") + log.Fatal("SQLITE_DB=/path/to/nkode.db not set") } - sqlitedb, err := db.NewSqliteDB(dbPath) + + sqliteDb, err := sql.Open("sqlite3", dbPath) if err != nil { - log.Fatalf("%v", err) + log.Fatalf("failed to open database: %v", err) } - defer sqlitedb.Close() + + if err := sqliteDb.Ping(); err != nil { + log.Fatalf("failed to connect to database: %v", err) + } + + ctx := context.Background() + queue, err := sqliteQueue.NewQueue(sqliteDb, ctx) + if err != nil { + log.Fatal(err) + } + queue.Start() + + defer func(queue *sqliteQueue.Queue) { + if err := queue.Stop(); err != nil { + log.Fatal(err) + } + }(queue) sesClient := email.NewSESClient() emailQueue := email.NewEmailQueue(emailQueueBufferSize, maxEmailsPerSecond, &sesClient) emailQueue.Start() defer emailQueue.Stop() - nkodeApi := api.NewNKodeAPI(sqlitedb, emailQueue) + sqlitedb := repository.NewSqliteRepository(queue, ctx) + nkodeApi := api.NewNKodeAPI(&sqlitedb, emailQueue) + AddDefaultCustomer(nkodeApi) handler := api.NKodeHandler{Api: nkodeApi} diff --git a/internal/api/nkode_api.go b/internal/api/nkode_api.go index fb280a4..814b8c8 100644 --- a/internal/api/nkode_api.go +++ b/internal/api/nkode_api.go @@ -5,10 +5,10 @@ import ( "github.com/google/uuid" "github.com/patrickmn/go-cache" "go-nkode/config" - "go-nkode/internal/db" "go-nkode/internal/email" "go-nkode/internal/entities" "go-nkode/internal/models" + "go-nkode/internal/repository" "go-nkode/internal/security" "log" "os" @@ -21,12 +21,12 @@ const ( ) type NKodeAPI struct { - Db db.CustomerUserRepository + Db repository.CustomerUserRepository SignupSessionCache *cache.Cache EmailQueue *email.Queue } -func NewNKodeAPI(db db.CustomerUserRepository, queue *email.Queue) NKodeAPI { +func NewNKodeAPI(db repository.CustomerUserRepository, queue *email.Queue) NKodeAPI { return NKodeAPI{ Db: db, EmailQueue: queue, diff --git a/internal/api/nkode_api_test.go b/internal/api/nkode_api_test.go index 8200d1c..4e28e08 100644 --- a/internal/api/nkode_api_test.go +++ b/internal/api/nkode_api_test.go @@ -1,12 +1,15 @@ package api import ( + "context" "github.com/stretchr/testify/assert" - "go-nkode/internal/db" "go-nkode/internal/email" "go-nkode/internal/entities" "go-nkode/internal/models" + "go-nkode/internal/repository" "go-nkode/internal/security" + sqlite_queue "go-nkode/internal/sqlc" + "log" "os" "testing" ) @@ -15,22 +18,31 @@ func TestNKodeAPI(t *testing.T) { //db1 := NewInMemoryDb() //testNKodeAPI(t, &db1) - dbFile := os.Getenv("TEST_DB") - - db2, err := db.NewSqliteDB(dbFile) + dbPath := os.Getenv("TEST_DB") + ctx := context.Background() + sqliteDb, err := sqlite_queue.OpenSqliteDb(dbPath) assert.NoError(t, err) - defer db2.Close() - testNKodeAPI(t, db2) - //if _, err := os.Stat(dbFile); err == nil { - // err = os.Remove(dbFile) + queue, err := sqlite_queue.NewQueue(sqliteDb, ctx) + assert.NoError(t, err) + queue.Start() + defer func(queue *sqlite_queue.Queue) { + if err := queue.Stop(); err != nil { + log.Fatal(err) + } + }(queue) + sqlitedb := repository.NewSqliteRepository(queue, ctx) + testNKodeAPI(t, &sqlitedb) + + //if _, err := os.Stat(dbPath); err == nil { + // err = os.Remove(dbPath) // assert.NoError(t, err) //} else { // assert.NoError(t, err) //} } -func testNKodeAPI(t *testing.T, db db.CustomerUserRepository) { +func testNKodeAPI(t *testing.T, db repository.CustomerUserRepository) { bufferSize := 100 emailsPerSec := 14 testClient := email.TestEmailClient{} diff --git a/internal/email/queue.go b/internal/email/queue.go index 597694b..4bd2d43 100644 --- a/internal/email/queue.go +++ b/internal/email/queue.go @@ -163,6 +163,6 @@ func (q *Queue) Stop() { q.stop = true // Wait for all emails to be processed q.wg.Wait() - // Close the email queue + // Stop the email queue close(q.emailQueue) } diff --git a/internal/email/queue_test.go b/internal/email/queue_test.go index a0733d0..9baee84 100644 --- a/internal/email/queue_test.go +++ b/internal/email/queue_test.go @@ -22,7 +22,7 @@ func TestEmailQueue(t *testing.T) { } queue.AddEmail(email) } - // Close the queue after all emails are processed + // Stop the queue after all emails are processed queue.Stop() assert.Equal(t, queue.FailedSendCount, 0) diff --git a/internal/db/customer_user_repository.go b/internal/repository/customer_user_repository.go similarity index 97% rename from internal/db/customer_user_repository.go rename to internal/repository/customer_user_repository.go index 3b59d51..3cdb620 100644 --- a/internal/db/customer_user_repository.go +++ b/internal/repository/customer_user_repository.go @@ -1,4 +1,4 @@ -package db +package repository import ( "go-nkode/internal/entities" diff --git a/internal/db/in_memory_db.go b/internal/repository/in_memory_db.go similarity index 99% rename from internal/db/in_memory_db.go rename to internal/repository/in_memory_db.go index a73e413..9869234 100644 --- a/internal/db/in_memory_db.go +++ b/internal/repository/in_memory_db.go @@ -1,4 +1,4 @@ -package db +package repository import ( "errors" diff --git a/internal/db/sqlite-init/json/academicons.json b/internal/repository/sqlite-init/json/academicons.json similarity index 100% rename from internal/db/sqlite-init/json/academicons.json rename to internal/repository/sqlite-init/json/academicons.json diff --git a/internal/db/sqlite-init/json/akar-icons.json b/internal/repository/sqlite-init/json/akar-icons.json similarity index 100% rename from internal/db/sqlite-init/json/akar-icons.json rename to internal/repository/sqlite-init/json/akar-icons.json diff --git a/internal/db/sqlite-init/json/ant-design.json b/internal/repository/sqlite-init/json/ant-design.json similarity index 100% rename from internal/db/sqlite-init/json/ant-design.json rename to internal/repository/sqlite-init/json/ant-design.json diff --git a/internal/db/sqlite-init/json/arcticons.json b/internal/repository/sqlite-init/json/arcticons.json similarity index 100% rename from internal/db/sqlite-init/json/arcticons.json rename to internal/repository/sqlite-init/json/arcticons.json diff --git a/internal/db/sqlite-init/json/basil.json b/internal/repository/sqlite-init/json/basil.json similarity index 100% rename from internal/db/sqlite-init/json/basil.json rename to internal/repository/sqlite-init/json/basil.json diff --git a/internal/db/sqlite-init/json/bitcoin-icons.json b/internal/repository/sqlite-init/json/bitcoin-icons.json similarity index 100% rename from internal/db/sqlite-init/json/bitcoin-icons.json rename to internal/repository/sqlite-init/json/bitcoin-icons.json diff --git a/internal/db/sqlite-init/sqlite_init.go b/internal/repository/sqlite-init/sqlite_init.go similarity index 100% rename from internal/db/sqlite-init/sqlite_init.go rename to internal/repository/sqlite-init/sqlite_init.go diff --git a/internal/db/sqlite_db.go b/internal/repository/sqlite_repository.go similarity index 71% rename from internal/db/sqlite_db.go rename to internal/repository/sqlite_repository.go index 16b48b4..88e2a7c 100644 --- a/internal/db/sqlite_db.go +++ b/internal/repository/sqlite_repository.go @@ -1,4 +1,4 @@ -package db +package repository import ( "context" @@ -14,82 +14,21 @@ import ( "go-nkode/internal/sqlc" "go-nkode/internal/utils" "log" - "sync" ) -const writeBufferSize = 100 - -type sqlcGeneric func(*sqlc.Queries, context.Context, any) error - -// WriteTx represents a write transaction -type WriteTx struct { - ErrChan chan error - Query sqlcGeneric - Args interface{} +type SqliteRepository struct { + Queue *sqlc.Queue + ctx context.Context } -// SqliteDB represents the SQLite database connection and write queue -type SqliteDB struct { - queries *sqlc.Queries - db *sql.DB - writeQueue chan WriteTx - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc -} - -// NewSqliteDB initializes a new SqliteDB instance -func NewSqliteDB(path string) (*SqliteDB, error) { - if path == "" { - return nil, errors.New("database path is required") - } - - db, err := sql.Open("sqlite3", path) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - sqldb := &SqliteDB{ - queries: sqlc.New(db), - db: db, - writeQueue: make(chan WriteTx, writeBufferSize), - ctx: ctx, - cancel: cancel, - } - - sqldb.wg.Add(1) - go sqldb.processWriteQueue() - - return sqldb, nil -} - -// processWriteQueue handles write transactions from the queue -func (d *SqliteDB) processWriteQueue() { - defer d.wg.Done() - for { - select { - case <-d.ctx.Done(): - return - case writeTx := <-d.writeQueue: - err := writeTx.Query(d.queries, d.ctx, writeTx.Args) - writeTx.ErrChan <- err - } +func NewSqliteRepository(queue *sqlc.Queue, ctx context.Context) SqliteRepository { + return SqliteRepository{ + Queue: queue, + ctx: ctx, } } -func (d *SqliteDB) Close() error { - d.cancel() - d.wg.Wait() - close(d.writeQueue) - return d.db.Close() -} - -func (d *SqliteDB) CreateCustomer(c entities.Customer) error { +func (d *SqliteRepository) CreateCustomer(c entities.Customer) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.CreateCustomerParams) if !ok { @@ -98,10 +37,10 @@ func (d *SqliteDB) CreateCustomer(c entities.Customer) error { return q.CreateCustomer(ctx, params) } - return d.enqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) + return d.Queue.EnqueueWriteTx(queryFunc, c.ToSqlcCreateCustomerParams()) } -func (d *SqliteDB) WriteNewUser(u entities.User) error { +func (d *SqliteRepository) WriteNewUser(u entities.User) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.CreateUserParams) if !ok { @@ -109,7 +48,7 @@ func (d *SqliteDB) WriteNewUser(u entities.User) error { } return q.CreateUser(ctx, params) } - // Use the wrapped function in enqueueWriteTx + // Use the wrapped function in EnqueueWriteTx renew := 0 if u.Renew { @@ -136,10 +75,10 @@ func (d *SqliteDB) WriteNewUser(u entities.User) error { SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), CreatedAt: sql.NullString{String: utils.TimeStamp(), Valid: true}, } - return d.enqueueWriteTx(queryFunc, params) + return d.Queue.EnqueueWriteTx(queryFunc, params) } -func (d *SqliteDB) UpdateUserNKode(u entities.User) error { +func (d *SqliteRepository) UpdateUserNKode(u entities.User) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserParams) if !ok { @@ -147,7 +86,7 @@ func (d *SqliteDB) UpdateUserNKode(u entities.User) error { } return q.UpdateUser(ctx, params) } - // Use the wrapped function in enqueueWriteTx + // Use the wrapped function in EnqueueWriteTx renew := 0 if u.Renew { renew = 1 @@ -170,10 +109,10 @@ func (d *SqliteDB) UpdateUserNKode(u entities.User) error { IdxInterface: security.IntArrToByteArr(u.Interface.IdxInterface), SvgIDInterface: security.IntArrToByteArr(u.Interface.SvgId), } - return d.enqueueWriteTx(queryFunc, params) + return d.Queue.EnqueueWriteTx(queryFunc, params) } -func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { +func (d *SqliteRepository) UpdateUserInterface(id models.UserId, ui entities.UserInterface) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserInterfaceParams) if !ok { @@ -187,10 +126,10 @@ func (d *SqliteDB) UpdateUserInterface(id models.UserId, ui entities.UserInterfa ID: uuid.UUID(id).String(), } - return d.enqueueWriteTx(queryFunc, params) + return d.Queue.EnqueueWriteTx(queryFunc, params) } -func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { +func (d *SqliteRepository) UpdateUserRefreshToken(id models.UserId, refreshToken string) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.UpdateUserRefreshTokenParams) if !ok { @@ -205,10 +144,10 @@ func (d *SqliteDB) UpdateUserRefreshToken(id models.UserId, refreshToken string) }, ID: uuid.UUID(id).String(), } - return d.enqueueWriteTx(queryFunc, params) + return d.Queue.EnqueueWriteTx(queryFunc, params) } -func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { +func (d *SqliteRepository) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { queryFunc := func(q *sqlc.Queries, ctx context.Context, args any) error { params, ok := args.(sqlc.RenewCustomerParams) if !ok { @@ -216,16 +155,16 @@ func (d *SqliteDB) RenewCustomer(renewParams sqlc.RenewCustomerParams) error { } return q.RenewCustomer(ctx, params) } - return d.enqueueWriteTx(queryFunc, renewParams) + return d.Queue.EnqueueWriteTx(queryFunc, renewParams) } -func (d *SqliteDB) Renew(id models.CustomerId) error { +func (d *SqliteRepository) Renew(id models.CustomerId) error { setXor, attrXor, err := d.renewCustomer(id) if err != nil { return err } customerId := models.CustomerIdToString(id) - userRenewRows, err := d.queries.GetUserRenew(d.ctx, customerId) + userRenewRows, err := d.Queue.Queries.GetUserRenew(d.ctx, customerId) if err != nil { return err } @@ -265,14 +204,14 @@ func (d *SqliteDB) Renew(id models.CustomerId) error { Renew: 1, ID: uuid.UUID(user.Id).String(), } - if err = d.enqueueWriteTx(queryFunc, params); err != nil { + if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { return err } } return nil } -func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) { +func (d *SqliteRepository) renewCustomer(id models.CustomerId) ([]uint64, []uint64, error) { customer, err := d.GetCustomer(id) if err != nil { return nil, nil, err @@ -295,13 +234,13 @@ func (d *SqliteDB) renewCustomer(id models.CustomerId) ([]uint64, []uint64, erro ID: uuid.UUID(customer.Id).String(), } - if err = d.enqueueWriteTx(queryFunc, params); err != nil { + if err = d.Queue.EnqueueWriteTx(queryFunc, params); err != nil { return nil, nil, err } return setXor, attrXor, nil } -func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { +func (d *SqliteRepository) RefreshUserPasscode(user entities.User, passcodeIdx []int, customerAttr entities.CustomerAttributes) error { if err := user.RefreshPasscode(passcodeIdx, customerAttr); err != nil { return err } @@ -323,11 +262,11 @@ func (d *SqliteDB) RefreshUserPasscode(user entities.User, passcodeIdx []int, cu Salt: user.CipherKeys.Salt, ID: uuid.UUID(user.Id).String(), } - return d.enqueueWriteTx(queryFunc, params) + return d.Queue.EnqueueWriteTx(queryFunc, params) } -func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) { - customer, err := d.queries.GetCustomer(d.ctx, uuid.UUID(id).String()) +func (d *SqliteRepository) GetCustomer(id models.CustomerId) (*entities.Customer, error) { + customer, err := d.Queue.Queries.GetCustomer(d.ctx, uuid.UUID(id).String()) if err != nil { return nil, err } @@ -346,8 +285,8 @@ func (d *SqliteDB) GetCustomer(id models.CustomerId) (*entities.Customer, error) }, nil } -func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { - userRow, err := d.queries.GetUser(d.ctx, sqlc.GetUserParams{ +func (d *SqliteRepository) GetUser(email models.UserEmail, customerId models.CustomerId) (*entities.User, error) { + userRow, err := d.Queue.Queries.GetUser(d.ctx, sqlc.GetUserParams{ Email: string(email), CustomerID: uuid.UUID(customerId).String(), }) @@ -396,7 +335,7 @@ func (d *SqliteDB) GetUser(email models.UserEmail, customerId models.CustomerId) return &user, nil } -func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { +func (d *SqliteRepository) RandomSvgInterface(kp entities.KeypadDimension) ([]string, error) { ids, err := d.getRandomIds(kp.TotalAttrs()) if err != nil { return nil, err @@ -404,18 +343,18 @@ func (d *SqliteDB) RandomSvgInterface(kp entities.KeypadDimension) ([]string, er return d.getSvgsById(ids) } -func (d *SqliteDB) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { +func (d *SqliteRepository) RandomSvgIdxInterface(kp entities.KeypadDimension) (models.SvgIdInterface, error) { return d.getRandomIds(kp.TotalAttrs()) } -func (d *SqliteDB) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) { +func (d *SqliteRepository) GetSvgStringInterface(idxs models.SvgIdInterface) ([]string, error) { return d.getSvgsById(idxs) } -func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { +func (d *SqliteRepository) getSvgsById(ids []int) ([]string, error) { svgs := make([]string, len(ids)) for idx, id := range ids { - svg, err := d.queries.GetSvgId(d.ctx, int64(id)) + svg, err := d.Queue.Queries.GetSvgId(d.ctx, int64(id)) if err != nil { return nil, err } @@ -424,25 +363,8 @@ func (d *SqliteDB) getSvgsById(ids []int) ([]string, error) { return svgs, nil } -func (d *SqliteDB) enqueueWriteTx(queryFunc sqlcGeneric, args any) error { - select { - case <-d.ctx.Done(): - return errors.New("database is shutting down") - default: - } - - errChan := make(chan error, 1) - writeTx := WriteTx{ - Query: queryFunc, - Args: args, - ErrChan: errChan, - } - d.writeQueue <- writeTx - return <-errChan -} - -func (d *SqliteDB) getRandomIds(count int) ([]int, error) { - tx, err := d.db.Begin() +func (d *SqliteRepository) getRandomIds(count int) ([]int, error) { + tx, err := d.Queue.Db.Begin() if err != nil { log.Print(err) return nil, config.ErrSqliteTx diff --git a/internal/db/sqlite_db_test.go b/internal/repository/sqlite_repository_test.go similarity index 76% rename from internal/db/sqlite_db_test.go rename to internal/repository/sqlite_repository_test.go index 0c6d11e..6db6099 100644 --- a/internal/db/sqlite_db_test.go +++ b/internal/repository/sqlite_repository_test.go @@ -1,28 +1,31 @@ -package db +package repository import ( + "context" "github.com/stretchr/testify/assert" "go-nkode/internal/entities" "go-nkode/internal/models" + sqlite_queue "go-nkode/internal/sqlc" "os" "testing" ) func TestNewSqliteDB(t *testing.T) { - dbFile := os.Getenv("TEST_DB") + dbPath := os.Getenv("TEST_DB") // sql_driver.MakeTables(dbFile) - db, err := NewSqliteDB(dbFile) + ctx := context.Background() + sqliteDb, err := sqlite_queue.OpenSqliteDb(dbPath) assert.NoError(t, err) - defer db.Close() - testSignupLoginRenew(t, db) - testSqliteDBRandomSvgInterface(t, db) - // if _, err := os.Stat(dbFile); err == nil { - // err = os.Remove(dbFile) - // assert.NoError(t, err) - // } else { - // assert.NoError(t, err) - // } + queue, err := sqlite_queue.NewQueue(sqliteDb, ctx) + assert.NoError(t, err) + + queue.Start() + defer queue.Stop() + db := NewSqliteRepository(queue, ctx) + assert.NoError(t, err) + testSignupLoginRenew(t, &db) + testSqliteDBRandomSvgInterface(t, &db) } func testSignupLoginRenew(t *testing.T, db CustomerUserRepository) { diff --git a/internal/sqlc/sqlite_queue.go b/internal/sqlc/sqlite_queue.go new file mode 100644 index 0000000..20c17e6 --- /dev/null +++ b/internal/sqlc/sqlite_queue.go @@ -0,0 +1,93 @@ +package sqlc + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" +) + +const writeBufferSize = 100 + +type SqlcGeneric func(*Queries, context.Context, any) error + +type WriteTx struct { + ErrChan chan error + Query SqlcGeneric + Args interface{} +} + +type Queue struct { + Queries *Queries + Db *sql.DB + WriteQueue chan WriteTx + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc +} + +func NewQueue(sqlDb *sql.DB, ctx context.Context) (*Queue, error) { + ctx, cancel := context.WithCancel(context.Background()) + sqldb := &Queue{ + Queries: New(sqlDb), + Db: sqlDb, + WriteQueue: make(chan WriteTx, writeBufferSize), + ctx: ctx, + cancel: cancel, + } + + return sqldb, nil +} + +func (d *Queue) Start() { + d.wg.Add(1) + defer d.wg.Done() + go func() { + for { + select { + case <-d.ctx.Done(): + return + case writeTx := <-d.WriteQueue: + err := writeTx.Query(d.Queries, d.ctx, writeTx.Args) + writeTx.ErrChan <- err + } + } + }() +} + +func (d *Queue) Stop() error { + d.cancel() + d.wg.Wait() + close(d.WriteQueue) + return d.Db.Close() +} + +func (d *Queue) EnqueueWriteTx(queryFunc SqlcGeneric, args any) error { + select { + case <-d.ctx.Done(): + return errors.New("database is shutting down") + default: + } + + errChan := make(chan error, 1) + writeTx := WriteTx{ + Query: queryFunc, + Args: args, + ErrChan: errChan, + } + d.WriteQueue <- writeTx + return <-errChan +} + +func OpenSqliteDb(dbPath string) (*sql.DB, error) { + sqliteDb, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := sqliteDb.Ping(); err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + return sqliteDb, nil +}