158 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			158 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package database
 | |
| 
 | |
| import (
 | |
| 	"crypto/tls"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/url"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
// ErrInvalidDSNAddr is a sentinel error with the message "invalid dsn addr".
// NOTE(review): ParseDSN in this file never returns it; confirm whether other
// code in the package does before documenting callers' expectations.
var ErrInvalidDSNAddr = errors.New("invalid dsn addr")

// ErrInvalidDSNUnescaped is a sentinel error with the message
// "dsn must be escaped".
// NOTE(review): ParseDSN in this file never returns it either; presumably
// used elsewhere in the package — verify.
var ErrInvalidDSNUnescaped = errors.New("dsn must be escaped")

// ErrInvalidDSNNoSlash is returned by ParseDSN for a non-empty DSN that has
// no '/' separating the address part from the database name. Compare with
// errors.Is.
var ErrInvalidDSNNoSlash = errors.New("dsn must contains slash")
 | |
| 
 | |
// Config holds the components of a DSN laid out as
//
//	[scheme://][username[:password]@][host[:port]]/dbname[?key1=value1&...&keyN=valueN]
//
// ParseDSN fills a Config from such a string and FormatDSN renders one back.
type Config struct {
	// TLSConfig is neither read nor written by FormatDSN.
	// NOTE(review): presumably consumed when dialing; confirm with callers.
	TLSConfig *tls.Config

	// Username is written verbatim (not percent-escaped) by FormatDSN.
	Username string

	// Password is percent-escaped with url.PathEscape by FormatDSN and is
	// only written when Username is non-empty.
	Password string

	// Scheme, when non-empty, is written as "Scheme://".
	Scheme string

	// Host is written verbatim when non-empty.
	Host string

	// Port is written as ":Port", and only when Host is non-empty.
	Port string

	// Database is percent-escaped with url.PathEscape by FormatDSN.
	Database string

	// Params holds query parameters as flattened key/value pairs:
	// [key1, value1, key2, value2, ...]. Keys and values are written verbatim.
	Params []string
}

// FormatDSN renders cfg as a DSN string in the layout documented on Config.
//
// Empty components are omitted, except that the '/' preceding the database
// name is always written. If Params has an odd length, the trailing unpaired
// key is written with an empty value ("key=") rather than panicking with an
// out-of-range index.
func (cfg *Config) FormatDSN() string {
	var s strings.Builder

	if cfg.Scheme != "" {
		s.WriteString(cfg.Scheme)
		s.WriteString("://")
	}

	// [username[:password]@]
	if cfg.Username != "" {
		s.WriteString(cfg.Username)
		if cfg.Password != "" {
			s.WriteByte(':')
			// PathEscape encodes '/' and '?', which would otherwise be
			// mistaken for the dbname or query separators.
			s.WriteString(url.PathEscape(cfg.Password))
		}
		s.WriteByte('@')
	}

	// [host[:port]]
	if cfg.Host != "" {
		s.WriteString(cfg.Host)
		if cfg.Port != "" {
			s.WriteByte(':')
			s.WriteString(cfg.Port)
		}
	}

	// /dbname
	s.WriteByte('/')
	s.WriteString(url.PathEscape(cfg.Database))

	// [?key1=value1&...&keyN=valueN]
	for i := 0; i < len(cfg.Params); i += 2 {
		if i == 0 {
			s.WriteByte('?')
		} else {
			s.WriteByte('&')
		}
		s.WriteString(cfg.Params[i])
		s.WriteByte('=')
		// Guard against an odd-length Params slice: a dangling key gets an
		// empty value instead of indexing past the end.
		if i+1 < len(cfg.Params) {
			s.WriteString(cfg.Params[i+1])
		}
	}

	return s.String()
}
 | |
| 
 | |
| func ParseDSN(dsn string) (*Config, error) {
 | |
| 	cfg := &Config{}
 | |
| 
 | |
| 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
 | |
| 	// Find last '/' that goes before dbname
 | |
| 	foundSlash := false
 | |
| 	for i := len(dsn) - 1; i >= 0; i-- {
 | |
| 		if dsn[i] == '/' {
 | |
| 			foundSlash = true
 | |
| 			var j, k int
 | |
| 
 | |
| 			// left part is empty if i <= 0
 | |
| 			if i > 0 {
 | |
| 				// Find the first ':' in dsn
 | |
| 				for j = i; j >= 0; j-- {
 | |
| 					if dsn[j] == ':' {
 | |
| 						cfg.Scheme = dsn[0:j]
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				// [username[:password]@][host]
 | |
| 				// Find the last '@' in dsn[:i]
 | |
| 				for j = i; j >= 0; j-- {
 | |
| 					if dsn[j] == '@' {
 | |
| 						// username[:password]
 | |
| 						// Find the second ':' in dsn[:j]
 | |
| 						for k = 0; k < j; k++ {
 | |
| 							if dsn[k] == ':' {
 | |
| 								if cfg.Scheme == dsn[:k] {
 | |
| 									continue
 | |
| 								}
 | |
| 								var err error
 | |
| 								cfg.Password, err = url.PathUnescape(dsn[k+1 : j])
 | |
| 								if err != nil {
 | |
| 									return nil, err
 | |
| 								}
 | |
| 								break
 | |
| 							}
 | |
| 						}
 | |
| 						cfg.Username = dsn[len(cfg.Scheme)+3 : k]
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				for k = j + 1; k < i; k++ {
 | |
| 					if dsn[k] == ':' {
 | |
| 						cfg.Host = dsn[j+1 : k]
 | |
| 						cfg.Port = dsn[k+1 : i]
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 			}
 | |
| 
 | |
| 			// dbname[?param1=value1&...¶mN=valueN]
 | |
| 			// Find the first '?' in dsn[i+1:]
 | |
| 			for j = i + 1; j < len(dsn); j++ {
 | |
| 				if dsn[j] == '?' {
 | |
| 					parts := strings.Split(dsn[j+1:], "&")
 | |
| 					cfg.Params = make([]string, 0, len(parts)*2)
 | |
| 					for _, p := range parts {
 | |
| 						k, v, found := strings.Cut(p, "=")
 | |
| 						if !found {
 | |
| 							continue
 | |
| 						}
 | |
| 						cfg.Params = append(cfg.Params, k, v)
 | |
| 					}
 | |
| 
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 			var err error
 | |
| 			dbname := dsn[i+1 : j]
 | |
| 			if cfg.Database, err = url.PathUnescape(dbname); err != nil {
 | |
| 				return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
 | |
| 			}
 | |
| 
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if !foundSlash && len(dsn) > 0 {
 | |
| 		return nil, ErrInvalidDSNNoSlash
 | |
| 	}
 | |
| 
 | |
| 	return cfg, nil
 | |
| }
 |