From 08a2de1ef50a7cd0e59b92e250f4dda41987de8d Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Fri, 1 May 2020 15:31:55 +0100 Subject: [PATCH] Account for missing options database/table in cockroach store --- store/cockroach/cockroach.go | 63 +++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/store/cockroach/cockroach.go b/store/cockroach/cockroach.go index 397db67b..f55abbb4 100644 --- a/store/cockroach/cockroach.go +++ b/store/cockroach/cockroach.go @@ -23,6 +23,8 @@ var ( ) var ( + re = regexp.MustCompile("[^a-zA-Z0-9]+") + statements = map[string]string{ "list": "SELECT key, value, expiry FROM %s.%s;", "read": "SELECT key, value, expiry FROM %s.%s WHERE key = $1;", @@ -42,14 +44,33 @@ type sqlStore struct { databases map[string]bool } -func (s *sqlStore) createDB(database, table string) { +func (s *sqlStore) getDB(database, table string) (string, string) { if len(database) == 0 { - database = s.options.Database + if len(s.options.Database) > 0 { + database = s.options.Database + } else { + database = DefaultDatabase + } } + if len(table) == 0 { - table = s.options.Table + if len(s.options.Table) > 0 { + table = s.options.Table + } else { + database = DefaultTable + } } + // store.namespace must only contain letters, numbers and underscores + database = re.ReplaceAllString(database, "_") + table = re.ReplaceAllString(table, "_") + + return database, table +} + +func (s *sqlStore) createDB(database, table string) { + database, table = s.getDB(database, table) + s.Lock() _, ok := s.databases[database+":"+table] if !ok { @@ -97,28 +118,10 @@ func (s *sqlStore) configure() error { s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"} } - database := s.options.Database - if len(database) == 0 { - s.options.Database = DefaultDatabase - } - - table := s.options.Table - if len(table) == 0 { - s.options.Table = DefaultTable - } - - // store.namespace must only contain letters, numbers and underscores - reg, err := regexp.Compile("[^a-zA-Z0-9]+") - if err != nil { - return errors.New("error compiling regex for namespace") - } - database = reg.ReplaceAllString(database, "_") - table = reg.ReplaceAllString(table, "_") - source := s.options.Nodes[0] // check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable // if err is nil which means it would be a URL like postgre://xxxx?yy=zz - _, err = url.Parse(source) + _, err := url.Parse(source) if err != nil { if !strings.Contains(source, " ") { source = fmt.Sprintf("host=%s", source) @@ -142,8 +145,11 @@ func (s *sqlStore) configure() error { // save the values s.db = db + // get DB + database, table := s.getDB(s.options.Database, s.options.Table) + // initialise the database - return s.initDB(s.options.Database, s.options.Table) + return s.initDB(database, table) } func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) { @@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) { if !ok { return nil, errors.New("unsupported statement") } - if len(database) == 0 { - database = s.options.Database - } - if len(table) == 0 { - table = s.options.Table - } + + // get DB + database, table = s.getDB(database, table) q := fmt.Sprintf(st, database, table) stmt, err := s.db.Prepare(q) @@ -425,7 +428,7 @@ func (s *sqlStore) String() string { func NewStore(opts ...store.Option) store.Store { options := store.Options{ Database: DefaultDatabase, - Table: DefaultTable, + Table: DefaultTable, } for _, o := range opts {