Account for missing options database/table in cockroach store

This commit is contained in:
Asim Aslam 2020-05-01 15:31:55 +01:00
parent 7a2dea6cc2
commit 08a2de1ef5

View File

@ -23,6 +23,8 @@ var (
)
var (
re = regexp.MustCompile("[^a-zA-Z0-9]+")
statements = map[string]string{
"list": "SELECT key, value, expiry FROM %s.%s;",
"read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;",
@ -42,14 +44,33 @@ type sqlStore struct {
databases map[string]bool
}
func (s *sqlStore) createDB(database, table string) {
func (s *sqlStore) getDB(database, table string) (string, string) {
if len(database) == 0 {
database = s.options.Database
if len(s.options.Database) > 0 {
database = s.options.Database
} else {
database = DefaultDatabase
}
}
if len(table) == 0 {
table = s.options.Table
if len(s.options.Table) > 0 {
table = s.options.Table
} else {
database = DefaultTable
}
}
// store.namespace must only contain letters, numbers and underscores
database = re.ReplaceAllString(database, "_")
table = re.ReplaceAllString(table, "_")
return database, table
}
func (s *sqlStore) createDB(database, table string) {
database, table = s.getDB(database, table)
s.Lock()
_, ok := s.databases[database+":"+table]
if !ok {
@ -97,28 +118,10 @@ func (s *sqlStore) configure() error {
s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
}
database := s.options.Database
if len(database) == 0 {
s.options.Database = DefaultDatabase
}
table := s.options.Table
if len(table) == 0 {
s.options.Table = DefaultTable
}
// store.namespace must only contain letters, numbers and underscores
reg, err := regexp.Compile("[^a-zA-Z0-9]+")
if err != nil {
return errors.New("error compiling regex for namespace")
}
database = reg.ReplaceAllString(database, "_")
table = reg.ReplaceAllString(table, "_")
source := s.options.Nodes[0]
// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
_, err = url.Parse(source)
_, err := url.Parse(source)
if err != nil {
if !strings.Contains(source, " ") {
source = fmt.Sprintf("host=%s", source)
@ -142,8 +145,11 @@ func (s *sqlStore) configure() error {
// save the values
s.db = db
// get DB
database, table := s.getDB(s.options.Database, s.options.Table)
// initialise the database
return s.initDB(s.options.Database, s.options.Table)
return s.initDB(database, table)
}
func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
if !ok {
return nil, errors.New("unsupported statement")
}
if len(database) == 0 {
database = s.options.Database
}
if len(table) == 0 {
table = s.options.Table
}
// get DB
database, table = s.getDB(database, table)
q := fmt.Sprintf(st, database, table)
stmt, err := s.db.Prepare(q)
@ -425,7 +428,7 @@ func (s *sqlStore) String() string {
func NewStore(opts ...store.Option) store.Store {
options := store.Options{
Database: DefaultDatabase,
Table: DefaultTable,
Table: DefaultTable,
}
for _, o := range opts {