Account for missing options database/table in cockroach store
This commit is contained in:
parent
7a2dea6cc2
commit
08a2de1ef5
@ -23,6 +23,8 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
re = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||
|
||||
statements = map[string]string{
|
||||
"list": "SELECT key, value, expiry FROM %s.%s;",
|
||||
"read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;",
|
||||
@ -42,14 +44,33 @@ type sqlStore struct {
|
||||
databases map[string]bool
|
||||
}
|
||||
|
||||
func (s *sqlStore) createDB(database, table string) {
|
||||
func (s *sqlStore) getDB(database, table string) (string, string) {
|
||||
if len(database) == 0 {
|
||||
database = s.options.Database
|
||||
if len(s.options.Database) > 0 {
|
||||
database = s.options.Database
|
||||
} else {
|
||||
database = DefaultDatabase
|
||||
}
|
||||
}
|
||||
|
||||
if len(table) == 0 {
|
||||
table = s.options.Table
|
||||
if len(s.options.Table) > 0 {
|
||||
table = s.options.Table
|
||||
} else {
|
||||
database = DefaultTable
|
||||
}
|
||||
}
|
||||
|
||||
// store.namespace must only contain letters, numbers and underscores
|
||||
database = re.ReplaceAllString(database, "_")
|
||||
table = re.ReplaceAllString(table, "_")
|
||||
|
||||
return database, table
|
||||
}
|
||||
|
||||
func (s *sqlStore) createDB(database, table string) {
|
||||
database, table = s.getDB(database, table)
|
||||
|
||||
s.Lock()
|
||||
_, ok := s.databases[database+":"+table]
|
||||
if !ok {
|
||||
@ -97,28 +118,10 @@ func (s *sqlStore) configure() error {
|
||||
s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
|
||||
}
|
||||
|
||||
database := s.options.Database
|
||||
if len(database) == 0 {
|
||||
s.options.Database = DefaultDatabase
|
||||
}
|
||||
|
||||
table := s.options.Table
|
||||
if len(table) == 0 {
|
||||
s.options.Table = DefaultTable
|
||||
}
|
||||
|
||||
// store.namespace must only contain letters, numbers and underscores
|
||||
reg, err := regexp.Compile("[^a-zA-Z0-9]+")
|
||||
if err != nil {
|
||||
return errors.New("error compiling regex for namespace")
|
||||
}
|
||||
database = reg.ReplaceAllString(database, "_")
|
||||
table = reg.ReplaceAllString(table, "_")
|
||||
|
||||
source := s.options.Nodes[0]
|
||||
// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
|
||||
// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
|
||||
_, err = url.Parse(source)
|
||||
_, err := url.Parse(source)
|
||||
if err != nil {
|
||||
if !strings.Contains(source, " ") {
|
||||
source = fmt.Sprintf("host=%s", source)
|
||||
@ -142,8 +145,11 @@ func (s *sqlStore) configure() error {
|
||||
// save the values
|
||||
s.db = db
|
||||
|
||||
// get DB
|
||||
database, table := s.getDB(s.options.Database, s.options.Table)
|
||||
|
||||
// initialise the database
|
||||
return s.initDB(s.options.Database, s.options.Table)
|
||||
return s.initDB(database, table)
|
||||
}
|
||||
|
||||
func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
|
||||
@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("unsupported statement")
|
||||
}
|
||||
if len(database) == 0 {
|
||||
database = s.options.Database
|
||||
}
|
||||
if len(table) == 0 {
|
||||
table = s.options.Table
|
||||
}
|
||||
|
||||
// get DB
|
||||
database, table = s.getDB(database, table)
|
||||
|
||||
q := fmt.Sprintf(st, database, table)
|
||||
stmt, err := s.db.Prepare(q)
|
||||
@ -425,7 +428,7 @@ func (s *sqlStore) String() string {
|
||||
func NewStore(opts ...store.Option) store.Store {
|
||||
options := store.Options{
|
||||
Database: DefaultDatabase,
|
||||
Table: DefaultTable,
|
||||
Table: DefaultTable,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
|
Loading…
Reference in New Issue
Block a user