// File: pkgdash/internal/database/database.go (254 lines, 5.4 KiB, Go)

package database
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
mpgx "github.com/golang-migrate/migrate/v4/database/pgx"
msqlite "github.com/golang-migrate/migrate/v4/database/sqlite"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
"go.unistack.org/micro/v3/logger"
appconfig "go.unistack.org/pkgdash/internal/config"
_ "modernc.org/sqlite"
)
// ParseDSN parses cfg.DSN as a URL and fills in the derived fields of cfg.
//
// The following query parameters are consumed (removed from the resulting
// connection string) when they carry a non-empty value:
//
//   - conn_max:         integer; sets cfg.MaxOpenConns and defaults
//     cfg.MaxIdleConns to half of it
//   - conn_maxidle:     integer; sets cfg.MaxIdleConns
//   - conn_lifetime:    time.ParseDuration string; sets cfg.ConnMaxLifetime
//   - conn_maxidletime: time.ParseDuration string; sets cfg.ConnMaxIdleTime
//   - migrate:          copied verbatim into cfg.Migrate
//
// The URL scheme is normalized to "postgres" or "sqlite" and stored in
// cfg.Type; the URL with the remaining query parameters is stored in
// cfg.ConnStr. An error is returned when the DSN is not a valid URL, when a
// numeric or duration parameter cannot be parsed, or when the scheme is not
// recognized. Fields assigned before the failing step keep their new values.
func ParseDSN(cfg *appconfig.DatabaseConfig) error {
	parsed, err := url.Parse(cfg.DSN)
	if err != nil {
		return err
	}

	query := parsed.Query()

	// take returns the value stored under key and, only when that value is
	// non-empty, strips the key from the query so it is not forwarded to the
	// database driver.
	take := func(key string) (string, bool) {
		raw := query.Get(key)
		if raw == "" {
			return "", false
		}
		query.Del(key)
		return raw, true
	}

	if raw, ok := take("conn_max"); ok {
		n, err := strconv.Atoi(raw)
		if err != nil {
			return err
		}
		cfg.MaxOpenConns = n
		// Default the idle pool to half the open limit; conn_maxidle below
		// may override it.
		cfg.MaxIdleConns = n / 2
	}

	if raw, ok := take("conn_maxidle"); ok {
		n, err := strconv.Atoi(raw)
		if err != nil {
			return err
		}
		cfg.MaxIdleConns = n
	}

	if raw, ok := take("conn_lifetime"); ok {
		d, err := time.ParseDuration(raw)
		if err != nil {
			return err
		}
		cfg.ConnMaxLifetime = d
	}

	if raw, ok := take("conn_maxidletime"); ok {
		d, err := time.ParseDuration(raw)
		if err != nil {
			return err
		}
		cfg.ConnMaxIdleTime = d
	}

	if mode, ok := take("migrate"); ok {
		cfg.Migrate = mode
	}

	// Collapse scheme aliases to one canonical name per backend.
	switch parsed.Scheme {
	case "postgres", "pgsql", "postgresql":
		parsed.Scheme = "postgres"
	case "sqlite", "sqlite3":
		parsed.Scheme = "sqlite"
	default:
		return fmt.Errorf("unknown database %s", parsed.Scheme)
	}

	cfg.Type = parsed.Scheme
	parsed.RawQuery = query.Encode()
	cfg.ConnStr = parsed.String()

	return nil
}
// connect opens a database handle for cfg.ConnStr using the backend selected
// by cfg.Type and rewrites cfg.Type to its canonical name ("postgres" or
// "sqlite"). The canonical name is stored even when opening the connection
// fails. An unrecognized cfg.Type yields an error and leaves cfg untouched.
func connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
	log.Info(ctx, "connect to %s", cfg.Type)

	var (
		open      func(context.Context, string) (*sqlx.DB, error)
		canonical string
	)
	switch cfg.Type {
	case "postgres", "pgsql", "postgresql":
		open, canonical = connectPostgres, "postgres"
	case "sqlite", "sqlite3":
		open, canonical = connectSqlite, "sqlite"
	default:
		return nil, fmt.Errorf("unknown database type %s", cfg.Type)
	}

	conn, err := open(ctx, cfg.ConnStr)
	cfg.Type = canonical
	if err != nil {
		// Never hand back a handle together with an error.
		return nil, err
	}

	return conn, nil
}
// Connect opens the database described by cfg, performs the migration action
// requested in cfg.Migrate, and returns a newly opened handle with the pool
// limits from cfg applied.
//
// cfg.Migrate selects the action:
//
//   - "":     no migration step is executed
//   - "up":   m.Up()
//   - "down": m.Down()
//   - "seed": m.Drop() followed by m.Up()
//   - other:  the value is parsed as an unsigned decimal version and passed
//     to m.Migrate()
//
// migrate.ErrNoChange is treated as success. Any other failure (opening the
// connection, preparing the migrator, an unparsable version, the migration
// itself, or closing the migrator) is returned with a nil handle.
func Connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
	db, err := connect(ctx, cfg, log)
	if err != nil {
		return nil, err
	}

	m, err := migratePrepare(ctx, db, log, cfg.Type)
	if err != nil {
		// No migrator exists to release the handle, so close it here; the
		// preparation error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}

	switch cfg.Migrate {
	case "":
		// No migration requested.
	case "up":
		log.Info(ctx, "migrate up")
		err = m.Up()
	case "down":
		log.Info(ctx, "migrate down")
		err = m.Down()
	case "seed":
		log.Info(ctx, "migrate seed")
		if err = m.Drop(); err == nil {
			err = m.Up()
		}
	default:
		log.Info(ctx, "migrate version")
		// The target version comes from cfg.Migrate; strconv.IntSize keeps
		// the parsed value within the range of uint on every platform.
		v, verr := strconv.ParseUint(cfg.Migrate, 10, strconv.IntSize)
		if verr != nil {
			err = fmt.Errorf("invalid migrate value %q: %w", cfg.Migrate, verr)
		} else {
			err = m.Migrate(uint(v))
		}
	}

	// Always close the migrator so its source and database resources are
	// released on the failure path as well.
	// NOTE(review): a fresh connection is opened below, presumably because
	// closing the migrator also closes the handle it was built on - confirm
	// against the migrate database drivers in use.
	srcErr, dbErr := m.Close()
	if err != nil && err != migrate.ErrNoChange {
		return nil, err
	}
	if srcErr != nil {
		return nil, srcErr
	}
	if dbErr != nil {
		return nil, dbErr
	}

	db, err = connect(ctx, cfg, log)
	if err != nil {
		return nil, err
	}

	db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
	db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
	db.SetMaxIdleConns(cfg.MaxIdleConns)
	db.SetMaxOpenConns(cfg.MaxOpenConns)

	return db, nil
}
// connectSqlite opens a SQLite database through the "sqlite" database/sql
// driver.
//
// connstr must start with "sqlite://" (the form ParseDSN produces) or
// "sqlite3://"; the scheme prefix is removed before the remainder is handed
// to the driver. When connstr mentions ":memory:" the remainder is used as
// is, otherwise it is prefixed with "file:". A connstr without a recognized
// prefix is rejected with an error instead of being sliced blindly, which
// previously panicked on short input and corrupted the path on other input.
func connectSqlite(ctx context.Context, connstr string) (*sqlx.DB, error) {
	var (
		target string
		found  bool
	)
	for _, prefix := range []string{"sqlite://", "sqlite3://"} {
		if strings.HasPrefix(connstr, prefix) {
			target, found = connstr[len(prefix):], true
			break
		}
	}
	if !found {
		return nil, fmt.Errorf("sqlite connection string must start with %q", "sqlite://")
	}

	if !strings.Contains(connstr, ":memory:") {
		target = "file:" + target
	}

	return sqlx.ConnectContext(ctx, "sqlite", target)
}
// connectPostgres parses connstr with pgx, forces a fixed set of runtime
// parameters, registers the resulting config with the pgx stdlib adapter and
// opens it through sqlx using the "pgx" driver name. On any failure a nil
// handle is returned together with the error.
func connectPostgres(ctx context.Context, connstr string) (*sqlx.DB, error) {
	pgConf, err := pgx.ParseConfig(connstr)
	if err != nil {
		return nil, err
	}

	// Needed for pgbouncer. The map replaces (does not merge with) any
	// runtime parameters that were parsed out of connstr.
	// NOTE(review): the application name "authn" may be a leftover from
	// another service - confirm it is intended here.
	runtimeParams := make(map[string]string, 2)
	runtimeParams["standard_conforming_strings"] = "on"
	runtimeParams["application_name"] = "authn"
	pgConf.RuntimeParams = runtimeParams

	// May be needed for pgbouncer, needs to be checked:
	// pgConf.PreferSimpleProtocol = true

	// Register the config with the pgx stdlib adapter and open it by the
	// name the adapter hands back.
	registered := stdlib.RegisterConnConfig(pgConf)

	handle, err := sqlx.ConnectContext(ctx, "pgx", registered)
	if err != nil {
		return nil, err
	}

	return handle, nil
}
// migratePrepare builds a migrate.Migrate instance that runs the embedded
// migrations found under "migrations/<dbtype>" against db.
//
// dbtype must be the canonical backend name, "postgres" or "sqlite"; any
// other value is rejected with an error rather than continuing with a nil
// database driver. Migration state is tracked in the "schema_migrations"
// table. The returned migrator logs through log using ctx.
func migratePrepare(ctx context.Context, db *sqlx.DB, log logger.Logger, dbtype string) (*migrate.Migrate, error) {
	var driver database.Driver
	var err error

	switch dbtype {
	case "postgres":
		driver, err = mpgx.WithInstance(db.DB, &mpgx.Config{
			DatabaseName:    "pkgdash",
			MigrationsTable: "schema_migrations",
		})
	case "sqlite":
		driver, err = msqlite.WithInstance(db.DB, &msqlite.Config{
			DatabaseName:    "pkgdash",
			MigrationsTable: "schema_migrations",
		})
	default:
		return nil, fmt.Errorf("unknown database type %s", dbtype)
	}
	if err != nil {
		return nil, err
	}

	// Each backend has its own directory of migration files.
	source, err := iofs.New(assets, "migrations/"+dbtype)
	if err != nil {
		return nil, err
	}

	// NOTE(review): the database name "apigw" differs from the "pkgdash"
	// name used in the driver configs above - confirm which one is intended.
	m, err := migrate.NewWithInstance("fs", source, "apigw", driver)
	if err != nil {
		return nil, err
	}

	m.Log = &mLog{ctx: ctx, l: log}

	return m, nil
}
// mLog adapts a logger.Logger so it can be assigned to the Log field of a
// migrate.Migrate instance (see migratePrepare). It carries the context that
// is passed along with every forwarded log call.
type mLog struct {
	ctx context.Context // context handed to each log call
	l   logger.Logger   // logger that receives the migrator's output
}

// Verbose reports whether the wrapped logger is enabled at debug level.
func (ml *mLog) Verbose() bool {
	debugEnabled := ml.l.V(logger.DebugLevel)
	return debugEnabled
}

// Printf forwards format and its arguments to the wrapped logger at info
// level, using the stored context.
func (ml *mLog) Printf(format string, v ...interface{}) {
	ml.l.Info(ml.ctx, format, v...)
}