package database

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	mpgx "github.com/golang-migrate/migrate/v4/database/pgx"
	msqlite "github.com/golang-migrate/migrate/v4/database/sqlite"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/stdlib"
	"github.com/jmoiron/sqlx"
	"go.unistack.org/micro/v3/logger"
	appconfig "go.unistack.org/pkgdash/internal/config"
	_ "modernc.org/sqlite"
)

// dsnControlKeys are query parameters consumed by ParseDSN. They configure
// this package, not the database driver, so they are always stripped from the
// resulting connection string.
var dsnControlKeys = []string{"conn_max", "conn_maxidle", "conn_lifetime", "conn_maxidletime", "migrate"}

// ParseDSN parses cfg.DSN and fills cfg.Type, cfg.ConnStr and the pool and
// migration settings carried in the DSN query string.
//
// Recognised query parameters (removed from cfg.ConnStr):
//   - conn_max: cfg.MaxOpenConns; also sets cfg.MaxIdleConns to half of it
//     unless conn_maxidle is given
//   - conn_maxidle: cfg.MaxIdleConns
//   - conn_lifetime: cfg.ConnMaxLifetime (time.ParseDuration syntax)
//   - conn_maxidletime: cfg.ConnMaxIdleTime (time.ParseDuration syntax)
//   - migrate: cfg.Migrate ("up", "down", "seed" or a schema version number)
//
// The scheme is normalised to "postgres" (postgres, pgsql, postgresql) or
// "sqlite" (sqlite, sqlite3); any other scheme is an error.
func ParseDSN(cfg *appconfig.DatabaseConfig) error {
	u, err := url.Parse(cfg.DSN)
	if err != nil {
		// *url.Error's message embeds the whole input, which may contain a
		// password; report only the underlying cause.
		var uerr *url.Error
		if errors.As(err, &uerr) {
			err = uerr.Err
		}
		return fmt.Errorf("parse database dsn: %w", err)
	}

	values := u.Query()

	if v := values.Get("conn_max"); v != "" {
		n, err := strconv.Atoi(v)
		if err != nil {
			return fmt.Errorf("conn_max: %w", err)
		}
		cfg.MaxOpenConns = n
		// Default the idle pool to half the open pool; conn_maxidle overrides it.
		cfg.MaxIdleConns = n / 2
	}
	if v := values.Get("conn_maxidle"); v != "" {
		n, err := strconv.Atoi(v)
		if err != nil {
			return fmt.Errorf("conn_maxidle: %w", err)
		}
		cfg.MaxIdleConns = n
	}
	if v := values.Get("conn_lifetime"); v != "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("conn_lifetime: %w", err)
		}
		cfg.ConnMaxLifetime = d
	}
	if v := values.Get("conn_maxidletime"); v != "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("conn_maxidletime: %w", err)
		}
		cfg.ConnMaxIdleTime = d
	}
	if v := values.Get("migrate"); v != "" {
		cfg.Migrate = v
	}

	// Strip control keys even when their value is empty, so they never reach
	// the driver (pgx would forward them to the server as runtime params).
	for _, key := range dsnControlKeys {
		values.Del(key)
	}

	switch u.Scheme {
	case "postgres", "pgsql", "postgresql":
		u.Scheme = "postgres"
	case "sqlite", "sqlite3":
		u.Scheme = "sqlite"
	default:
		return fmt.Errorf("unknown database %s", u.Scheme)
	}

	cfg.Type = u.Scheme
	u.RawQuery = values.Encode()
	cfg.ConnStr = u.String()

	return nil
}

// connect opens a connection for cfg.Type using cfg.ConnStr and normalises
// cfg.Type to "postgres" or "sqlite". Pool settings are not applied here.
func connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
	log.Info(ctx, "connect to %s", cfg.Type)

	var (
		db  *sqlx.DB
		err error
	)
	switch cfg.Type {
	case "postgres", "pgsql", "postgresql":
		cfg.Type = "postgres"
		db, err = connectPostgres(ctx, cfg.ConnStr)
	case "sqlite", "sqlite3":
		cfg.Type = "sqlite"
		db, err = connectSqlite(ctx, cfg.ConnStr)
	default:
		return nil, fmt.Errorf("unknown database type %s", cfg.Type)
	}
	if err != nil {
		return nil, err
	}
	return db, nil
}

// Connect opens the database described by cfg, runs the migration selected by
// cfg.Migrate (see runMigration), and returns a fresh connection with the pool
// limits from cfg applied.
//
// Migrations run on a separate, short-lived connection. That connection is
// closed together with the migrate instance before the returned one is opened.
func Connect(ctx context.Context, cfg *appconfig.DatabaseConfig, log logger.Logger) (*sqlx.DB, error) {
	mdb, err := connect(ctx, cfg, log)
	if err != nil {
		return nil, err
	}

	m, err := migratePrepare(ctx, mdb, log, cfg.Type)
	if err != nil {
		_ = mdb.Close() // best effort; the preparation error is what matters
		return nil, err
	}

	migErr := runMigration(ctx, m, cfg.Migrate, log)

	// Always release the migrate instance, even when the migration failed,
	// so the migration connection is not leaked.
	srcErr, dbErr := m.Close()
	// The migrate driver may already have closed mdb; sql.DB.Close is
	// idempotent, so this only guarantees the handle is released.
	_ = mdb.Close()

	if migErr != nil && !errors.Is(migErr, migrate.ErrNoChange) {
		return nil, migErr
	}
	if srcErr != nil {
		return nil, srcErr
	}
	if dbErr != nil {
		return nil, dbErr
	}

	db, err := connect(ctx, cfg, log)
	if err != nil {
		return nil, err
	}

	db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
	db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
	db.SetMaxIdleConns(cfg.MaxIdleConns)
	db.SetMaxOpenConns(cfg.MaxOpenConns)

	return db, nil
}

// runMigration applies the action named by mode to m.
//
// Supported modes:
//   - "": no-op
//   - "up" / "down": m.Up / m.Down
//   - "seed": m.Drop followed by m.Up
//   - anything else: parsed as a schema version for m.Migrate
//
// The migrate.ErrNoChange sentinel is returned unchanged; the caller decides
// whether it is fatal.
func runMigration(ctx context.Context, m *migrate.Migrate, mode string, log logger.Logger) error {
	switch mode {
	case "":
		return nil
	case "up":
		log.Info(ctx, "migrate up")
		return m.Up()
	case "down":
		log.Info(ctx, "migrate down")
		return m.Down()
	case "seed":
		log.Info(ctx, "migrate seed")
		if err := m.Drop(); err != nil {
			return err
		}
		return m.Up()
	default:
		log.Info(ctx, "migrate version")
		// IntSize keeps the parsed value within range of the uint conversion
		// below on 32-bit platforms.
		v, err := strconv.ParseUint(mode, 10, strconv.IntSize)
		if err != nil {
			return fmt.Errorf("invalid migrate value %q: %w", mode, err)
		}
		return m.Migrate(uint(v))
	}
}

// sqliteTarget strips the scheme that ParseDSN leaves on an sqlite ConnStr.
// It handles both the hierarchical form "sqlite://..." and the opaque form
// "sqlite:...". Other strings are returned unchanged.
func sqliteTarget(connstr string) string {
	switch {
	case strings.HasPrefix(connstr, "sqlite://"):
		return connstr[len("sqlite://"):]
	case strings.HasPrefix(connstr, "sqlite:"):
		return connstr[len("sqlite:"):]
	default:
		return connstr
	}
}

// connectSqlite opens an sqlite database via the "sqlite" driver. Targets
// other than ":memory:" databases are passed as "file:" URIs.
func connectSqlite(ctx context.Context, connstr string) (*sqlx.DB, error) {
	target := sqliteTarget(connstr)
	if !strings.Contains(target, ":memory:") {
		target = "file:" + target
	}
	return sqlx.ConnectContext(ctx, "sqlite", target)
}

// connectPostgres parses connstr with pgx, adds the runtime parameters this
// service needs, and opens it through the pgx stdlib driver.
func connectPostgres(ctx context.Context, connstr string) (*sqlx.DB, error) {
	dbConf, err := pgx.ParseConfig(connstr)
	if err != nil {
		return nil, err
	}

	// Merge into the params parsed from the DSN instead of replacing them,
	// so DSN-supplied settings such as search_path are preserved.
	if dbConf.RuntimeParams == nil {
		dbConf.RuntimeParams = make(map[string]string)
	}
	// needed for pgbouncer
	dbConf.RuntimeParams["standard_conforming_strings"] = "on"
	if _, ok := dbConf.RuntimeParams["application_name"]; !ok {
		// NOTE(review): "authn" looks copied from another service; confirm
		// the intended application name for pkgdash.
		dbConf.RuntimeParams["application_name"] = "authn"
	}

	// may be needed for pbbouncer, needs to check
	// dbConf.PreferSimpleProtocol = true

	// RegisterConnConfig stores dbConf in pgx's global registry and returns a
	// key usable as a DSN for the "pgx" driver.
	connStr := stdlib.RegisterConnConfig(dbConf)
	db, err := sqlx.ConnectContext(ctx, "pgx", connStr)
	if err != nil {
		// Nothing will ever use this registration; drop it.
		stdlib.UnregisterConnConfig(connStr)
		return nil, err
	}
	return db, nil
}

// migratePrepare builds a migrate instance for db. It reads migrations from
// the embedded "migrations/<dbtype>" directory and records state in the
// "schema_migrations" table. dbtype must be "postgres" or "sqlite".
func migratePrepare(ctx context.Context, db *sqlx.DB, log logger.Logger, dbtype string) (*migrate.Migrate, error) {
	var (
		driver database.Driver
		err    error
	)
	switch dbtype {
	case "postgres":
		driver, err = mpgx.WithInstance(db.DB, &mpgx.Config{
			DatabaseName:    "pkgdash",
			MigrationsTable: "schema_migrations",
		})
	case "sqlite":
		driver, err = msqlite.WithInstance(db.DB, &msqlite.Config{
			DatabaseName:    "pkgdash",
			MigrationsTable: "schema_migrations",
		})
	default:
		return nil, fmt.Errorf("unsupported database type for migrations %s", dbtype)
	}
	if err != nil {
		return nil, err
	}

	source, err := iofs.New(assets, "migrations/"+dbtype)
	if err != nil {
		_ = driver.Close() // best effort; report the source error
		return nil, err
	}

	m, err := migrate.NewWithInstance("fs", source, "apigw", driver)
	if err != nil {
		_ = source.Close()
		_ = driver.Close()
		return nil, err
	}
	m.Log = &mLog{ctx: ctx, l: log}

	return m, nil
}

// mLog adapts a micro logger to golang-migrate's Logger interface.
type mLog struct {
	ctx context.Context
	l   logger.Logger
}

var _ migrate.Logger = (*mLog)(nil)

// Verbose reports whether the wrapped logger has debug level enabled.
func (l *mLog) Verbose() bool {
	return l.l.V(logger.DebugLevel)
}

// Printf forwards migrate's log lines to the wrapped logger at info level.
func (l *mLog) Printf(format string, v ...interface{}) {
	l.l.Info(l.ctx, format, v...)
}