Account for missing options database/table in cockroach store
This commit is contained in:
parent
7a2dea6cc2
commit
08a2de1ef5
@ -23,6 +23,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
re = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||||
|
|
||||||
statements = map[string]string{
|
statements = map[string]string{
|
||||||
"list": "SELECT key, value, expiry FROM %s.%s;",
|
"list": "SELECT key, value, expiry FROM %s.%s;",
|
||||||
"read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;",
|
"read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;",
|
||||||
@ -42,14 +44,33 @@ type sqlStore struct {
|
|||||||
databases map[string]bool
|
databases map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlStore) createDB(database, table string) {
|
func (s *sqlStore) getDB(database, table string) (string, string) {
|
||||||
if len(database) == 0 {
|
if len(database) == 0 {
|
||||||
database = s.options.Database
|
if len(s.options.Database) > 0 {
|
||||||
|
database = s.options.Database
|
||||||
|
} else {
|
||||||
|
database = DefaultDatabase
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(table) == 0 {
|
if len(table) == 0 {
|
||||||
table = s.options.Table
|
if len(s.options.Table) > 0 {
|
||||||
|
table = s.options.Table
|
||||||
|
} else {
|
||||||
|
database = DefaultTable
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// store.namespace must only contain letters, numbers and underscores
|
||||||
|
database = re.ReplaceAllString(database, "_")
|
||||||
|
table = re.ReplaceAllString(table, "_")
|
||||||
|
|
||||||
|
return database, table
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sqlStore) createDB(database, table string) {
|
||||||
|
database, table = s.getDB(database, table)
|
||||||
|
|
||||||
s.Lock()
|
s.Lock()
|
||||||
_, ok := s.databases[database+":"+table]
|
_, ok := s.databases[database+":"+table]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -97,28 +118,10 @@ func (s *sqlStore) configure() error {
|
|||||||
s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
|
s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
|
||||||
}
|
}
|
||||||
|
|
||||||
database := s.options.Database
|
|
||||||
if len(database) == 0 {
|
|
||||||
s.options.Database = DefaultDatabase
|
|
||||||
}
|
|
||||||
|
|
||||||
table := s.options.Table
|
|
||||||
if len(table) == 0 {
|
|
||||||
s.options.Table = DefaultTable
|
|
||||||
}
|
|
||||||
|
|
||||||
// store.namespace must only contain letters, numbers and underscores
|
|
||||||
reg, err := regexp.Compile("[^a-zA-Z0-9]+")
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("error compiling regex for namespace")
|
|
||||||
}
|
|
||||||
database = reg.ReplaceAllString(database, "_")
|
|
||||||
table = reg.ReplaceAllString(table, "_")
|
|
||||||
|
|
||||||
source := s.options.Nodes[0]
|
source := s.options.Nodes[0]
|
||||||
// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
|
// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
|
||||||
// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
|
// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
|
||||||
_, err = url.Parse(source)
|
_, err := url.Parse(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.Contains(source, " ") {
|
if !strings.Contains(source, " ") {
|
||||||
source = fmt.Sprintf("host=%s", source)
|
source = fmt.Sprintf("host=%s", source)
|
||||||
@ -142,8 +145,11 @@ func (s *sqlStore) configure() error {
|
|||||||
// save the values
|
// save the values
|
||||||
s.db = db
|
s.db = db
|
||||||
|
|
||||||
|
// get DB
|
||||||
|
database, table := s.getDB(s.options.Database, s.options.Table)
|
||||||
|
|
||||||
// initialise the database
|
// initialise the database
|
||||||
return s.initDB(s.options.Database, s.options.Table)
|
return s.initDB(database, table)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
|
func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
|
||||||
@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("unsupported statement")
|
return nil, errors.New("unsupported statement")
|
||||||
}
|
}
|
||||||
if len(database) == 0 {
|
|
||||||
database = s.options.Database
|
// get DB
|
||||||
}
|
database, table = s.getDB(database, table)
|
||||||
if len(table) == 0 {
|
|
||||||
table = s.options.Table
|
|
||||||
}
|
|
||||||
|
|
||||||
q := fmt.Sprintf(st, database, table)
|
q := fmt.Sprintf(st, database, table)
|
||||||
stmt, err := s.db.Prepare(q)
|
stmt, err := s.db.Prepare(q)
|
||||||
@ -425,7 +428,7 @@ func (s *sqlStore) String() string {
|
|||||||
func NewStore(opts ...store.Option) store.Store {
|
func NewStore(opts ...store.Option) store.Store {
|
||||||
options := store.Options{
|
options := store.Options{
|
||||||
Database: DefaultDatabase,
|
Database: DefaultDatabase,
|
||||||
Table: DefaultTable,
|
Table: DefaultTable,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
|
Loading…
Reference in New Issue
Block a user