// Package cockroach implements the cockroach store
package cockroach

import (
	"database/sql"
	"fmt"
	"strings"
	"time"
	"unicode"

	"github.com/lib/pq"
	"github.com/micro/go-micro/store"
	"github.com/micro/go-micro/util/log"
	"github.com/pkg/errors"
)

// DefaultNamespace is the namespace that the sql store
// will use if no namespace is provided.
var (
	DefaultNamespace = "micro"
	DefaultPrefix    = "micro"
)

type sqlStore struct {
	db *sql.DB

	database string
	table    string

	options store.Options
}

// List all the known records
func (s *sqlStore) List() ([]*store.Record, error) {
	rows, err := s.db.Query(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s;", s.database, s.table))
	var records []*store.Record
	var timehelper pq.NullTime
	if err != nil {
		if err == sql.ErrNoRows {
			return records, nil
		}
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		record := &store.Record{}
		if err := rows.Scan(&record.Key, &record.Value, &timehelper); err != nil {
			return records, err
		}
		if timehelper.Valid {
			if timehelper.Time.Before(time.Now()) {
				// record has expired
				go s.Delete(record.Key)
			} else {
				record.Expiry = time.Until(timehelper.Time)
				records = append(records, record)
			}
		} else {
			records = append(records, record)
		}

	}
	rowErr := rows.Close()
	if rowErr != nil {
		// transaction rollback or something
		return records, rowErr
	}
	if err := rows.Err(); err != nil {
		return records, err
	}
	return records, nil
}

// Read all records with keys
func (s *sqlStore) Read(keys ...string) ([]*store.Record, error) {
	q, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key = $1;", s.database, s.table))
	if err != nil {
		return nil, err
	}
	var records []*store.Record
	var timehelper pq.NullTime
	for _, key := range keys {
		row := q.QueryRow(key)
		record := &store.Record{}
		if err := row.Scan(&record.Key, &record.Value, &timehelper); err != nil {
			if err == sql.ErrNoRows {
				return records, store.ErrNotFound
			}
			return records, err
		}
		if timehelper.Valid {
			if timehelper.Time.Before(time.Now()) {
				// record has expired
				go s.Delete(key)
				return records, store.ErrNotFound
			}
			record.Expiry = time.Until(timehelper.Time)
			records = append(records, record)
		} else {
			records = append(records, record)
		}
	}
	return records, nil
}

// Write records
func (s *sqlStore) Write(rec ...*store.Record) error {
	q, err := s.db.Prepare(fmt.Sprintf(`INSERT INTO %s.%s(key, value, expiry)
		VALUES ($1, $2::bytea, $3)
		ON CONFLICT (key)
		DO UPDATE
		SET value = EXCLUDED.value, expiry = EXCLUDED.expiry;`, s.database, s.table))
	if err != nil {
		return err
	}
	for _, r := range rec {
		var err error
		if r.Expiry != 0 {
			_, err = q.Exec(r.Key, r.Value, time.Now().Add(r.Expiry))
		} else {
			_, err = q.Exec(r.Key, r.Value, nil)
		}
		if err != nil {
			return errors.Wrap(err, "Couldn't insert record "+r.Key)
		}
	}

	return nil
}

// Delete records with keys
func (s *sqlStore) Delete(keys ...string) error {
	q, err := s.db.Prepare(fmt.Sprintf("DELETE FROM %s.%s WHERE key = $1;", s.database, s.table))
	if err != nil {
		return err
	}
	for _, key := range keys {
		result, err := q.Exec(key)
		if err != nil {
			return err
		}
		_, err = result.RowsAffected()
		if err != nil {
			return err
		}
	}
	return nil
}

func (s *sqlStore) initDB() {
	// Create the namespace's database
	_, err := s.db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s ;", s.database))
	if err != nil {
		log.Fatal(err)
	}

	_, err = s.db.Exec(fmt.Sprintf("SET DATABASE = %s ;", s.database))
	if err != nil {
		log.Fatal(errors.Wrap(err, "Couldn't set database"))
	}

	// Create a table for the namespace's prefix
	_, err = s.db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s
	(
		key text NOT NULL,
		value bytea,
		expiry timestamp with time zone,
		CONSTRAINT %s_pkey PRIMARY KEY (key)
	);`, s.table, s.table))
	if err != nil {
		log.Fatal(errors.Wrap(err, "Couldn't create table"))
	}
}

// New returns a new micro Store backed by sql
func New(opts ...store.Option) store.Store {
	var options store.Options
	for _, o := range opts {
		o(&options)
	}

	nodes := options.Nodes
	if len(nodes) == 0 {
		nodes = []string{"localhost:26257"}
	}

	namespace := options.Namespace
	if len(namespace) == 0 {
		namespace = DefaultNamespace
	}

	prefix := options.Prefix
	if len(prefix) == 0 {
		prefix = DefaultPrefix
	}

	for _, r := range namespace {
		if !unicode.IsLetter(r) {
			log.Fatal("store.namespace must only contain letters")
		}
	}

	source := nodes[0]
	if !strings.Contains(source, " ") {
		source = fmt.Sprintf("host=%s", source)
	}
	// create source from first node
	db, err := sql.Open("postgres", source)
	if err != nil {
		log.Fatal(err)
	}

	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}

	s := &sqlStore{
		db:       db,
		database: namespace,
		table:    prefix,
	}
	s.initDB()
	return s
}