diff --git a/database/dsn.go b/database/dsn.go index 0f0a2e96..8e6627a8 100644 --- a/database/dsn.go +++ b/database/dsn.go @@ -15,7 +15,6 @@ var ( ) type Config struct { - Params map[string]string TLSConfig *tls.Config Username string Password string @@ -23,6 +22,50 @@ type Config struct { Host string Port string Database string + Params []string +} + +func (cfg *Config) FormatDSN() string { + var s strings.Builder + + if len(cfg.Scheme) > 0 { + s.WriteString(cfg.Scheme + "://") + } + // [username[:password]@] + if len(cfg.Username) > 0 { + s.WriteString(cfg.Username) + if len(cfg.Password) > 0 { + s.WriteByte(':') + s.WriteString(url.PathEscape(cfg.Password)) + } + s.WriteByte('@') + } + + // [host:port] + if len(cfg.Host) > 0 { + s.WriteString(cfg.Host) + if len(cfg.Port) > 0 { + s.WriteByte(':') + s.WriteString(cfg.Port) + } + } + + // /dbname + s.WriteByte('/') + s.WriteString(url.PathEscape(cfg.Database)) + + for i := 0; i < len(cfg.Params); i += 2 { + if i == 0 { + s.WriteString("?") + } else { + s.WriteString("&") + } + s.WriteString(cfg.Params[i]) + s.WriteString("=") + s.WriteString(cfg.Params[i+1]) + } + + return s.String() } func ParseDSN(dsn string) (*Config, error) { @@ -84,13 +127,13 @@ func ParseDSN(dsn string) (*Config, error) { for j = i + 1; j < len(dsn); j++ { if dsn[j] == '?' { parts := strings.Split(dsn[j+1:], "&") - cfg.Params = make(map[string]string, len(parts)) + cfg.Params = make([]string, 0, len(parts)*2) for _, p := range parts { k, v, found := strings.Cut(p, "=") if !found { continue } - cfg.Params[k] = v + cfg.Params = append(cfg.Params, k, v) } break diff --git a/database/dsn_test.go b/database/dsn_test.go index 04c07781..0812d8a4 100644 --- a/database/dsn_test.go +++ b/database/dsn_test.go @@ -1,6 +1,7 @@ package database import ( + "net/url" "testing" ) @@ -13,3 +14,18 @@ func TestParseDSN(t *testing.T) { t.Fatalf("parsing error") } } + +func TestFormatDSN(t *testing.T) { + src := "postgres://username:p@ssword#@host:12345/dbname?key1=val2&key2=val2" + cfg, err := ParseDSN(src) + if err != nil { + t.Fatal(err) + } + dst, err := url.PathUnescape(cfg.FormatDSN()) + if err != nil { + t.Fatal(err) + } + if src != dst { + t.Fatalf("\n%s\n%s", src, dst) + } +}