Account for missing options database/table in cockroach store

This commit is contained in:
Asim Aslam 2020-05-01 15:31:55 +01:00
parent 7a2dea6cc2
commit 08a2de1ef5

View File

@ -23,6 +23,8 @@ var (
) )
var ( var (
re = regexp.MustCompile("[^a-zA-Z0-9]+")
statements = map[string]string{ statements = map[string]string{
"list": "SELECT key, value, expiry FROM %s.%s;", "list": "SELECT key, value, expiry FROM %s.%s;",
"read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;", "read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;",
@ -42,14 +44,33 @@ type sqlStore struct {
databases map[string]bool databases map[string]bool
} }
func (s *sqlStore) createDB(database, table string) { func (s *sqlStore) getDB(database, table string) (string, string) {
if len(database) == 0 { if len(database) == 0 {
database = s.options.Database if len(s.options.Database) > 0 {
database = s.options.Database
} else {
database = DefaultDatabase
}
} }
if len(table) == 0 { if len(table) == 0 {
table = s.options.Table if len(s.options.Table) > 0 {
table = s.options.Table
} else {
database = DefaultTable
}
} }
// store.namespace must only contain letters, numbers and underscores
database = re.ReplaceAllString(database, "_")
table = re.ReplaceAllString(table, "_")
return database, table
}
func (s *sqlStore) createDB(database, table string) {
database, table = s.getDB(database, table)
s.Lock() s.Lock()
_, ok := s.databases[database+":"+table] _, ok := s.databases[database+":"+table]
if !ok { if !ok {
@ -97,28 +118,10 @@ func (s *sqlStore) configure() error {
s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"} s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
} }
database := s.options.Database
if len(database) == 0 {
s.options.Database = DefaultDatabase
}
table := s.options.Table
if len(table) == 0 {
s.options.Table = DefaultTable
}
// store.namespace must only contain letters, numbers and underscores
reg, err := regexp.Compile("[^a-zA-Z0-9]+")
if err != nil {
return errors.New("error compiling regex for namespace")
}
database = reg.ReplaceAllString(database, "_")
table = reg.ReplaceAllString(table, "_")
source := s.options.Nodes[0] source := s.options.Nodes[0]
// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable // check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
// if err is nil which means it would be a URL like postgre://xxxx?yy=zz // if err is nil which means it would be a URL like postgre://xxxx?yy=zz
_, err = url.Parse(source) _, err := url.Parse(source)
if err != nil { if err != nil {
if !strings.Contains(source, " ") { if !strings.Contains(source, " ") {
source = fmt.Sprintf("host=%s", source) source = fmt.Sprintf("host=%s", source)
@ -142,8 +145,11 @@ func (s *sqlStore) configure() error {
// save the values // save the values
s.db = db s.db = db
// get DB
database, table := s.getDB(s.options.Database, s.options.Table)
// initialise the database // initialise the database
return s.initDB(s.options.Database, s.options.Table) return s.initDB(database, table)
} }
func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) { func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
if !ok { if !ok {
return nil, errors.New("unsupported statement") return nil, errors.New("unsupported statement")
} }
if len(database) == 0 {
database = s.options.Database // get DB
} database, table = s.getDB(database, table)
if len(table) == 0 {
table = s.options.Table
}
q := fmt.Sprintf(st, database, table) q := fmt.Sprintf(st, database, table)
stmt, err := s.db.Prepare(q) stmt, err := s.db.Prepare(q)
@ -425,7 +428,7 @@ func (s *sqlStore) String() string {
func NewStore(opts ...store.Option) store.Store { func NewStore(opts ...store.Option) store.Store {
options := store.Options{ options := store.Options{
Database: DefaultDatabase, Database: DefaultDatabase,
Table: DefaultTable, Table: DefaultTable,
} }
for _, o := range opts { for _, o := range opts {