From 4073a8eacaed75a18de6ad0007eddf3e6bb83755 Mon Sep 17 00:00:00 2001 From: Adora Laura Kalb Date: Thu, 9 May 2024 15:28:04 +0200 Subject: [PATCH] rework DB initialization for coming unit tests --- cmd/go-urlsh/go-urlsh.go | 10 +++------- internal/db/initialize.go | 42 ++++++++++++++++++++------------------- internal/db/migrations.go | 6 +++--- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/cmd/go-urlsh/go-urlsh.go b/cmd/go-urlsh/go-urlsh.go index b03f6ef..326e883 100644 --- a/cmd/go-urlsh/go-urlsh.go +++ b/cmd/go-urlsh/go-urlsh.go @@ -3,23 +3,19 @@ package app import ( "fmt" - "log" - "code.lila.network/adoralaura/go-urlsh/internal/app" "code.lila.network/adoralaura/go-urlsh/internal/db" + "code.lila.network/adoralaura/go-urlsh/models" ) func Run() error { - err := db.InitializeDB() - if err != nil { - log.Fatalln(err) - } + models.DB = db.InitializeDB() go app.CleanupLogins() go app.CleanupLoginsCronJob() - err = app.SetupFiber() + err := app.SetupFiber() if err != nil { return fmt.Errorf("couldn't start webserver: %v", err.Error()) } diff --git a/internal/db/initialize.go b/internal/db/initialize.go index b039c17..9b1fd69 100644 --- a/internal/db/initialize.go +++ b/internal/db/initialize.go @@ -3,57 +3,59 @@ package db import ( "context" "database/sql" - "fmt" "os" + "log" + "code.lila.network/adoralaura/go-urlsh/models" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/driver/pgdriver" ) -func InitializeDB() error { +func InitializeDB() *bun.DB { sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(os.Getenv("DATABASE_URL")))) - models.DB = bun.NewDB(sqldb, pgdialect.New()) + db := bun.NewDB(sqldb, pgdialect.New()) - _, err := models.DB.NewCreateTable().IfNotExists().Model((*models.Link)(nil)).Exec(context.Background()) + _, err := db.NewCreateTable().IfNotExists().Model((*models.Link)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.User)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.User)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.Session)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.Session)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.ApiKey)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.ApiKey)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.MFALoginTransaction)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.MFALoginTransaction)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.MFAConfig)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.MFAConfig)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - _, err = models.DB.NewCreateTable().IfNotExists().Model((*models.MFAScratchCode)(nil)).Exec(context.Background()) + _, err = db.NewCreateTable().IfNotExists().Model((*models.MFAScratchCode)(nil)).Exec(context.Background()) if err != nil { - return fmt.Errorf("[DB] couldn't create database: [%w]", err) + log.Panicf("[DB] couldn't create database: [%w]", err) } - err = doMigrations() + err = doMigrations(db) if err != nil { - return fmt.Errorf("[DB] Error during Migrations: [%w]", err) + log.Panicf("[DB] Error during Migrations: [%w]", err) } - return nil + + return db } diff --git a/internal/db/migrations.go b/internal/db/migrations.go index 284e534..560c58e 100644 --- a/internal/db/migrations.go +++ b/internal/db/migrations.go @@ -5,13 +5,13 @@ import ( "log" "code.lila.network/adoralaura/go-urlsh/migrations" - "code.lila.network/adoralaura/go-urlsh/models" + "github.com/uptrace/bun" "github.com/uptrace/bun/migrate" ) -func doMigrations() error { +func doMigrations(db *bun.DB) error { ctx := context.Background() - migrator := migrate.NewMigrator(models.DB, migrations.Migrations) + migrator := migrate.NewMigrator(db, migrations.Migrations) migrator.Init(ctx)