package database

import (
	"crypto/tls"
	"errors"
	"fmt"
	"net/url"
	"strings"
)

var (
	// ErrInvalidDSNAddr reports an invalid address component in a DSN.
	ErrInvalidDSNAddr = errors.New("invalid dsn addr")
	// ErrInvalidDSNUnescaped reports a DSN component that was expected to be
	// percent-escaped but was not.
	ErrInvalidDSNUnescaped = errors.New("dsn must be escaped")
	// ErrInvalidDSNNoSlash is returned by ParseDSN when a non-empty DSN has no
	// '/' separating the address part from the database name.
	ErrInvalidDSNNoSlash = errors.New("dsn must contains slash")
)

// Config holds the components of a data source name (DSN) of the form
//
//	[scheme://][username[:password]@][host[:port]]/dbname[?key1=value1&...&keyN=valueN]
//
// FormatDSN renders a Config into that form and ParseDSN splits such a string
// back into a Config. Password and Database are stored unescaped and are
// percent-escaped in the DSN text. Params is a flat list laid out as
// key1, value1, key2, value2, ... whose entries are written and read verbatim.
// TLSConfig is carried alongside the DSN: FormatDSN does not encode it and
// ParseDSN does not populate it.
type Config struct {
	TLSConfig *tls.Config
	Username  string
	Password  string
	Scheme    string
	Host      string
	Port      string
	Database  string
	Params    []string
}

// FormatDSN renders cfg as a DSN string in the form documented on Config.
//
// Optional components are emitted only when non-empty. The password is
// emitted only together with a non-empty username. The '/' before the
// database name is always written. Password and Database are escaped with
// url.PathEscape, and all other components are written verbatim. Params is
// consumed in key/value pairs. A trailing key without a value is ignored
// instead of causing an index-out-of-range panic.
func (cfg *Config) FormatDSN() string {
	var s strings.Builder

	if cfg.Scheme != "" {
		s.WriteString(cfg.Scheme)
		s.WriteString("://")
	}

	// [username[:password]@]
	if cfg.Username != "" {
		s.WriteString(cfg.Username)
		if cfg.Password != "" {
			s.WriteByte(':')
			s.WriteString(url.PathEscape(cfg.Password))
		}
		s.WriteByte('@')
	}

	// [host[:port]]
	if cfg.Host != "" {
		s.WriteString(cfg.Host)
		if cfg.Port != "" {
			s.WriteByte(':')
			s.WriteString(cfg.Port)
		}
	}

	// /dbname
	s.WriteByte('/')
	s.WriteString(url.PathEscape(cfg.Database))

	// [?key1=value1&...&keyN=valueN]
	for i := 0; i+1 < len(cfg.Params); i += 2 {
		if i == 0 {
			s.WriteByte('?')
		} else {
			s.WriteByte('&')
		}
		s.WriteString(cfg.Params[i])
		s.WriteByte('=')
		s.WriteString(cfg.Params[i+1])
	}

	return s.String()
}

// ParseDSN splits dsn into its components. It is the inverse of
// Config.FormatDSN for the DSNs that FormatDSN produces.
//
// Parsing rules:
//   - A scheme is recognized only when "://" occurs before any ':', '/', '?'
//     or '@'.
//   - The query starts at the first '?'. The database name follows the last
//     '/' before it, so query values may themselves contain '/'.
//   - The userinfo ends at the last '@' of the address part, and the password
//     starts after the first ':' of the userinfo.
//   - Host and port are split at the first ':' of the host part. For a
//     bracketed IPv6 literal such as "[::1]:5432", the split is at the first
//     ':' after the closing ']'.
//
// Password and Database are unescaped with url.PathUnescape, and a malformed
// escape is returned as a wrapped error. Query entries without '=' are
// skipped. An empty dsn yields an empty Config and a nil error. Any other dsn
// whose remainder (after an optional scheme) has no '/' yields
// ErrInvalidDSNNoSlash.
func ParseDSN(dsn string) (*Config, error) {
	cfg := &Config{}
	if dsn == "" {
		return cfg, nil
	}

	rest := dsn
	// A scheme contains none of the delimiters used below. This keeps a "://"
	// inside a query value from being mistaken for the scheme separator.
	if n := strings.Index(rest, "://"); n >= 0 && !strings.ContainsAny(rest[:n], ":/?@") {
		cfg.Scheme, rest = rest[:n], rest[n+3:]
	}

	// FormatDSN path-escapes the password and database name, so neither
	// contains a literal '?' or '/'. Therefore the first '?' opens the query,
	// and the last '/' before it precedes the database name.
	path, query, hasQuery := strings.Cut(rest, "?")
	slash := strings.LastIndexByte(path, '/')
	if slash < 0 {
		return nil, ErrInvalidDSNNoSlash
	}
	addr, dbname := path[:slash], path[slash+1:]

	// [username[:password]@]
	if at := strings.LastIndexByte(addr, '@'); at >= 0 {
		user, pass, hasPass := strings.Cut(addr[:at], ":")
		cfg.Username = user
		if hasPass {
			p, err := url.PathUnescape(pass)
			if err != nil {
				return nil, fmt.Errorf("invalid password: %w", err)
			}
			cfg.Password = p
		}
		addr = addr[at+1:]
	}

	// [host[:port]]
	cfg.Host, cfg.Port = splitDSNHostPort(addr)

	// dbname
	var err error
	if cfg.Database, err = url.PathUnescape(dbname); err != nil {
		return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
	}

	// [?key1=value1&...&keyN=valueN]
	if hasQuery {
		parts := strings.Split(query, "&")
		cfg.Params = make([]string, 0, len(parts)*2)
		for _, p := range parts {
			k, v, found := strings.Cut(p, "=")
			if !found {
				continue
			}
			cfg.Params = append(cfg.Params, k, v)
		}
	}

	return cfg, nil
}

// splitDSNHostPort splits hostport at the first ':' that is not inside a
// leading bracketed IPv6 literal such as "[::1]". When there is no such ':',
// the whole string is the host and the port is empty.
func splitDSNHostPort(hostport string) (host, port string) {
	start := 0
	if strings.HasPrefix(hostport, "[") {
		if end := strings.IndexByte(hostport, ']'); end >= 0 {
			start = end + 1
		}
	}
	if c := strings.IndexByte(hostport[start:], ':'); c >= 0 {
		return hostport[:start+c], hostport[start+c+1:]
	}
	return hostport, ""
}