254 lines
5.4 KiB
Go
254 lines
5.4 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/database"
|
|
mpgx "github.com/golang-migrate/migrate/v4/database/pgx"
|
|
msqlite "github.com/golang-migrate/migrate/v4/database/sqlite"
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/stdlib"
|
|
"github.com/jmoiron/sqlx"
|
|
"go.unistack.org/micro/v3/logger"
|
|
appconfig "go.unistack.org/pkgdash/internal/config"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
func ParseDSN(cfg *appconfig.DatabaseConfig) error {
|
|
var err error
|
|
|
|
u, err := url.Parse(cfg.DSN)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
values := u.Query()
|
|
var value string
|
|
|
|
if value = values.Get("conn_max"); value != "" {
|
|
values.Del("conn_max")
|
|
maxOpenConns, err := strconv.Atoi(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.MaxOpenConns = maxOpenConns
|
|
cfg.MaxIdleConns = maxOpenConns / 2
|
|
}
|
|
|
|
if value = values.Get("conn_maxidle"); value != "" {
|
|
values.Del("conn_maxidle")
|
|
maxIdleConns, err := strconv.Atoi(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.MaxIdleConns = maxIdleConns
|
|
}
|
|
|
|
if value = values.Get("conn_lifetime"); value != "" {
|
|
values.Del("conn_lifetime")
|
|
connMaxLifetime, err := time.ParseDuration(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.ConnMaxLifetime = connMaxLifetime
|
|
}
|
|
|
|
if value = values.Get("conn_maxidletime"); value != "" {
|
|
values.Del("conn_maxidletime")
|
|
connMaxIdleTime, err := time.ParseDuration(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.ConnMaxIdleTime = connMaxIdleTime
|
|
}
|
|
|
|
if mtype := values.Get("migrate"); mtype != "" {
|
|
values.Del("migrate")
|
|
cfg.Migrate = mtype
|
|
}
|
|
|
|
switch u.Scheme {
|
|
case "postgres", "pgsql", "postgresql":
|
|
u.Scheme = "postgres"
|
|
case "sqlite", "sqlite3":
|
|
u.Scheme = "sqlite"
|
|
default:
|
|
return fmt.Errorf("unknown database %s", u.Scheme)
|
|
}
|
|
|
|
cfg.Type = u.Scheme
|
|
u.RawQuery = values.Encode()
|
|
|
|
cfg.ConnStr = u.String()
|
|
|
|
return nil
|
|
}
|
|
|
|
func connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
|
|
var db *sqlx.DB
|
|
var err error
|
|
|
|
log.Info(ctx, "connect to %s", cfg.Type)
|
|
switch cfg.Type {
|
|
case "postgres", "pgsql", "postgresql":
|
|
db, err = connectPostgres(ctx, cfg.ConnStr)
|
|
cfg.Type = "postgres"
|
|
case "sqlite", "sqlite3":
|
|
db, err = connectSqlite(ctx, cfg.ConnStr)
|
|
cfg.Type = "sqlite"
|
|
default:
|
|
return nil, fmt.Errorf("unknown database type %s", cfg.Type)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func Connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
|
|
db, err := connect(ctx, cfg, log)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m, err := migratePrepare(ctx, db, log, cfg.Type)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch cfg.Migrate {
|
|
case "":
|
|
break
|
|
case "up":
|
|
log.Info(ctx, "migrate up")
|
|
err = m.Up()
|
|
case "down":
|
|
log.Info(ctx, "migrate down")
|
|
err = m.Down()
|
|
case "seed":
|
|
log.Info(ctx, "migrate seed")
|
|
if err = m.Drop(); err == nil {
|
|
err = m.Up()
|
|
}
|
|
default:
|
|
log.Info(ctx, "migrate version")
|
|
v, verr := strconv.ParseUint(cfg.Type, 10, 64)
|
|
if verr != nil {
|
|
return nil, err
|
|
}
|
|
err = m.Migrate(uint(v))
|
|
}
|
|
|
|
if err == nil || err == migrate.ErrNoChange {
|
|
srcerr, dberr := m.Close()
|
|
if srcerr != nil {
|
|
err = srcerr
|
|
} else if dberr != nil {
|
|
err = dberr
|
|
} else {
|
|
err = nil
|
|
}
|
|
}
|
|
|
|
if err == nil {
|
|
db, err = connect(ctx, cfg, log)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
|
|
db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func connectSqlite(ctx context.Context, connstr string) (*sqlx.DB, error) {
|
|
if !strings.Contains(connstr, ":memory:") {
|
|
return sqlx.ConnectContext(ctx, "sqlite", "file:"+connstr[9:])
|
|
}
|
|
return sqlx.ConnectContext(ctx, "sqlite", connstr[9:])
|
|
}
|
|
|
|
func connectPostgres(ctx context.Context, connstr string) (*sqlx.DB, error) {
|
|
// parse connection string
|
|
dbConf, err := pgx.ParseConfig(connstr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// needed for pgbouncer
|
|
dbConf.RuntimeParams = map[string]string{
|
|
"standard_conforming_strings": "on",
|
|
"application_name": "authn",
|
|
}
|
|
|
|
// may be needed for pbbouncer, needs to check
|
|
// dbConf.PreferSimpleProtocol = true
|
|
// register pgx conn
|
|
connStr := stdlib.RegisterConnConfig(dbConf)
|
|
|
|
db, err := sqlx.ConnectContext(ctx, "pgx", connStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func migratePrepare(ctx context.Context, db *sqlx.DB, log logger.Logger, dbtype string) (*migrate.Migrate, error) {
|
|
var driver database.Driver
|
|
var err error
|
|
|
|
switch dbtype {
|
|
case "postgres":
|
|
driver, err = mpgx.WithInstance(db.DB, &mpgx.Config{
|
|
DatabaseName: "pkgdash",
|
|
MigrationsTable: "schema_migrations",
|
|
})
|
|
case "sqlite":
|
|
driver, err = msqlite.WithInstance(db.DB, &msqlite.Config{
|
|
DatabaseName: "pkgdash",
|
|
MigrationsTable: "schema_migrations",
|
|
})
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
source, err := iofs.New(assets, "migrations/"+dbtype)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m, err := migrate.NewWithInstance("fs", source, "apigw", driver)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m.Log = &mLog{ctx: ctx, l: log}
|
|
return m, nil
|
|
}
|
|
|
|
type mLog struct {
|
|
ctx context.Context
|
|
l logger.Logger
|
|
}
|
|
|
|
func (l *mLog) Verbose() bool {
|
|
return l.l.V(logger.DebugLevel)
|
|
}
|
|
|
|
func (l *mLog) Printf(format string, v ...interface{}) {
|
|
l.l.Info(l.ctx, format, v...)
|
|
}
|