Account for missing options database/table in cockroach store
This commit is contained in:
		| @@ -23,6 +23,8 @@ var ( | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	re = regexp.MustCompile("[^a-zA-Z0-9]+") | ||||
|  | ||||
| 	statements = map[string]string{ | ||||
| 		"list":       "SELECT key, value, expiry FROM %s.%s;", | ||||
| 		"read":       "SELECT key, value, expiry FROM %s.%s WHERE key = $1;", | ||||
| @@ -42,14 +44,33 @@ type sqlStore struct { | ||||
| 	databases map[string]bool | ||||
| } | ||||
|  | ||||
| func (s *sqlStore) createDB(database, table string) { | ||||
| func (s *sqlStore) getDB(database, table string) (string, string) { | ||||
| 	if len(database) == 0 { | ||||
| 		database = s.options.Database | ||||
| 		if len(s.options.Database) > 0 { | ||||
| 			database = s.options.Database | ||||
| 		} else { | ||||
| 			database = DefaultDatabase | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(table) == 0 { | ||||
| 		table = s.options.Table | ||||
| 		if len(s.options.Table) > 0 { | ||||
| 			table = s.options.Table | ||||
| 		} else { | ||||
| 			database = DefaultTable | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// store.namespace must only contain letters, numbers and underscores | ||||
| 	database = re.ReplaceAllString(database, "_") | ||||
| 	table = re.ReplaceAllString(table, "_") | ||||
|  | ||||
| 	return database, table | ||||
| } | ||||
|  | ||||
| func (s *sqlStore) createDB(database, table string) { | ||||
| 	database, table = s.getDB(database, table) | ||||
|  | ||||
| 	s.Lock() | ||||
| 	_, ok := s.databases[database+":"+table] | ||||
| 	if !ok { | ||||
| @@ -97,28 +118,10 @@ func (s *sqlStore) configure() error { | ||||
| 		s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"} | ||||
| 	} | ||||
|  | ||||
| 	database := s.options.Database | ||||
| 	if len(database) == 0 { | ||||
| 		s.options.Database = DefaultDatabase | ||||
| 	} | ||||
|  | ||||
| 	table := s.options.Table | ||||
| 	if len(table) == 0 { | ||||
| 		s.options.Table = DefaultTable | ||||
| 	} | ||||
|  | ||||
| 	// store.namespace must only contain letters, numbers and underscores | ||||
| 	reg, err := regexp.Compile("[^a-zA-Z0-9]+") | ||||
| 	if err != nil { | ||||
| 		return errors.New("error compiling regex for namespace") | ||||
| 	} | ||||
| 	database = reg.ReplaceAllString(database, "_") | ||||
| 	table = reg.ReplaceAllString(table, "_") | ||||
|  | ||||
| 	source := s.options.Nodes[0] | ||||
| 	// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable | ||||
| 	// if err is nil which means it would be a URL like postgre://xxxx?yy=zz | ||||
| 	_, err = url.Parse(source) | ||||
| 	_, err := url.Parse(source) | ||||
| 	if err != nil { | ||||
| 		if !strings.Contains(source, " ") { | ||||
| 			source = fmt.Sprintf("host=%s", source) | ||||
| @@ -142,8 +145,11 @@ func (s *sqlStore) configure() error { | ||||
| 	// save the values | ||||
| 	s.db = db | ||||
|  | ||||
| 	// get DB | ||||
| 	database, table := s.getDB(s.options.Database, s.options.Table) | ||||
|  | ||||
| 	// initialise the database | ||||
| 	return s.initDB(s.options.Database, s.options.Table) | ||||
| 	return s.initDB(database, table) | ||||
| } | ||||
|  | ||||
| func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) { | ||||
| @@ -151,12 +157,9 @@ func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) { | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("unsupported statement") | ||||
| 	} | ||||
| 	if len(database) == 0 { | ||||
| 		database = s.options.Database | ||||
| 	} | ||||
| 	if len(table) == 0 { | ||||
| 		table = s.options.Table | ||||
| 	} | ||||
|  | ||||
| 	// get DB | ||||
| 	database, table = s.getDB(database, table) | ||||
|  | ||||
| 	q := fmt.Sprintf(st, database, table) | ||||
| 	stmt, err := s.db.Prepare(q) | ||||
| @@ -425,7 +428,7 @@ func (s *sqlStore) String() string { | ||||
| func NewStore(opts ...store.Option) store.Store { | ||||
| 	options := store.Options{ | ||||
| 		Database: DefaultDatabase, | ||||
| 		Table: DefaultTable, | ||||
| 		Table:    DefaultTable, | ||||
| 	} | ||||
|  | ||||
| 	for _, o := range opts { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user