158 lines
3.1 KiB
Go
158 lines
3.1 KiB
Go
package database
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
// Sentinel errors describing malformed DSNs. Compare with errors.Is.
var (
	// ErrInvalidDSNAddr signals an invalid address component in a DSN.
	// NOTE(review): nothing in this file returns it; presumably it is used
	// by callers or kept for compatibility — confirm before removing.
	ErrInvalidDSNAddr = errors.New("invalid dsn addr")

	// ErrInvalidDSNUnescaped signals a DSN that should have been escaped.
	// NOTE(review): nothing in this file returns it; presumably it is used
	// by callers or kept for compatibility — confirm before removing.
	ErrInvalidDSNUnescaped = errors.New("dsn must be escaped")

	// ErrInvalidDSNNoSlash is returned by ParseDSN when a non-empty DSN
	// contains no '/'.
	ErrInvalidDSNNoSlash = errors.New("dsn must contains slash")
)
|
|
|
|
// Config holds the individual components of a database DSN of the form
//
//	[scheme://][username[:password]@][host[:port]]/dbname[?key=value&...]
//
// FormatDSN renders a Config as such a string and ParseDSN builds one from it.
type Config struct {
	// TLSConfig is carried alongside the DSN components. FormatDSN does not
	// write it and ParseDSN does not set it.
	TLSConfig *tls.Config

	// Username is written to the DSN verbatim (it is not escaped).
	Username string
	// Password is stored unescaped; FormatDSN path-escapes it and ParseDSN
	// path-unescapes it.
	Password string

	// Scheme is the part before "://", without the separator.
	Scheme string
	// Host is written to the DSN verbatim.
	Host string
	// Port is kept as a string and written after Host, separated by ':'.
	Port string

	// Database is stored unescaped; FormatDSN path-escapes it and ParseDSN
	// path-unescapes it.
	Database string

	// Params is a flat list of alternating keys and values:
	// key1, value1, key2, value2, ... Entries are written and read verbatim.
	Params []string
}
|
|
|
|
func (cfg *Config) FormatDSN() string {
|
|
var s strings.Builder
|
|
|
|
if len(cfg.Scheme) > 0 {
|
|
s.WriteString(cfg.Scheme + "://")
|
|
}
|
|
// [username[:password]@]
|
|
if len(cfg.Username) > 0 {
|
|
s.WriteString(cfg.Username)
|
|
if len(cfg.Password) > 0 {
|
|
s.WriteByte(':')
|
|
s.WriteString(url.PathEscape(cfg.Password))
|
|
}
|
|
s.WriteByte('@')
|
|
}
|
|
|
|
// [host:port]
|
|
if len(cfg.Host) > 0 {
|
|
s.WriteString(cfg.Host)
|
|
if len(cfg.Port) > 0 {
|
|
s.WriteByte(':')
|
|
s.WriteString(cfg.Port)
|
|
}
|
|
}
|
|
|
|
// /dbname
|
|
s.WriteByte('/')
|
|
s.WriteString(url.PathEscape(cfg.Database))
|
|
|
|
for i := 0; i < len(cfg.Params); i += 2 {
|
|
if i == 0 {
|
|
s.WriteString("?")
|
|
} else {
|
|
s.WriteString("&")
|
|
}
|
|
s.WriteString(cfg.Params[i])
|
|
s.WriteString("=")
|
|
s.WriteString(cfg.Params[i+1])
|
|
}
|
|
|
|
return s.String()
|
|
}
|
|
|
|
func ParseDSN(dsn string) (*Config, error) {
|
|
cfg := &Config{}
|
|
|
|
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
|
// Find last '/' that goes before dbname
|
|
foundSlash := false
|
|
for i := len(dsn) - 1; i >= 0; i-- {
|
|
if dsn[i] == '/' {
|
|
foundSlash = true
|
|
var j, k int
|
|
|
|
// left part is empty if i <= 0
|
|
if i > 0 {
|
|
// Find the first ':' in dsn
|
|
for j = i; j >= 0; j-- {
|
|
if dsn[j] == ':' {
|
|
cfg.Scheme = dsn[0:j]
|
|
}
|
|
}
|
|
|
|
// [username[:password]@][host]
|
|
// Find the last '@' in dsn[:i]
|
|
for j = i; j >= 0; j-- {
|
|
if dsn[j] == '@' {
|
|
// username[:password]
|
|
// Find the second ':' in dsn[:j]
|
|
for k = 0; k < j; k++ {
|
|
if dsn[k] == ':' {
|
|
if cfg.Scheme == dsn[:k] {
|
|
continue
|
|
}
|
|
var err error
|
|
cfg.Password, err = url.PathUnescape(dsn[k+1 : j])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
break
|
|
}
|
|
}
|
|
cfg.Username = dsn[len(cfg.Scheme)+3 : k]
|
|
break
|
|
}
|
|
}
|
|
|
|
for k = j + 1; k < i; k++ {
|
|
if dsn[k] == ':' {
|
|
cfg.Host = dsn[j+1 : k]
|
|
cfg.Port = dsn[k+1 : i]
|
|
break
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// dbname[?param1=value1&...¶mN=valueN]
|
|
// Find the first '?' in dsn[i+1:]
|
|
for j = i + 1; j < len(dsn); j++ {
|
|
if dsn[j] == '?' {
|
|
parts := strings.Split(dsn[j+1:], "&")
|
|
cfg.Params = make([]string, 0, len(parts)*2)
|
|
for _, p := range parts {
|
|
k, v, found := strings.Cut(p, "=")
|
|
if !found {
|
|
continue
|
|
}
|
|
cfg.Params = append(cfg.Params, k, v)
|
|
}
|
|
|
|
break
|
|
}
|
|
}
|
|
var err error
|
|
dbname := dsn[i+1 : j]
|
|
if cfg.Database, err = url.PathUnescape(dbname); err != nil {
|
|
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
|
|
}
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
if !foundSlash && len(dsn) > 0 {
|
|
return nil, ErrInvalidDSNNoSlash
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|