Merge branch 'master' into jwt-auth

This commit is contained in:
ben-toogood 2020-04-29 13:22:09 +01:00 committed by GitHub
commit e57b20c1f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 236 additions and 65 deletions

View File

@ -53,5 +53,5 @@ See the [docs](https://micro.mu/docs/framework.html) for detailed information on
## License ## License
Go Micro is Apache 2.0 licensed Go Micro is Apache 2.0 licensed.

View File

@ -87,7 +87,7 @@ func (d *discordInput) Start() error {
} }
var err error var err error
d.session, err = discordgo.New(d.token) d.session, err = discordgo.New("Bot " + d.token)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,11 +3,8 @@ package auth
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"time" "time"
"github.com/micro/go-micro/v2/metadata"
) )
var ( var (
@ -90,44 +87,24 @@ type Token struct {
const ( const (
// DefaultNamespace used for auth // DefaultNamespace used for auth
DefaultNamespace = "go.micro" DefaultNamespace = "go.micro"
// MetadataKey is the key used when storing the account in metadata
MetadataKey = "auth-account"
// TokenCookieName is the name of the cookie which stores the auth token // TokenCookieName is the name of the cookie which stores the auth token
TokenCookieName = "micro-token" TokenCookieName = "micro-token"
// SecretCookieName is the name of the cookie which stores the auth secret
SecretCookieName = "micro-secret"
// BearerScheme used for Authorization header // BearerScheme used for Authorization header
BearerScheme = "Bearer " BearerScheme = "Bearer "
) )
type accountKey struct{}
// AccountFromContext gets the account from the context, which // AccountFromContext gets the account from the context, which
// is set by the auth wrapper at the start of a call. If the account // is set by the auth wrapper at the start of a call. If the account
// is not set, a nil account will be returned. The error is only returned // is not set, a nil account will be returned. The error is only returned
// when there was a problem retrieving an account // when there was a problem retrieving an account
func AccountFromContext(ctx context.Context) (*Account, error) { func AccountFromContext(ctx context.Context) (*Account, bool) {
str, ok := metadata.Get(ctx, MetadataKey) acc, ok := ctx.Value(accountKey{}).(*Account)
// there was no account set return acc, ok
if !ok {
return nil, nil
}
var acc *Account
// metadata is stored as a string, so unmarshal to an account
if err := json.Unmarshal([]byte(str), &acc); err != nil {
return nil, err
}
return acc, nil
} }
// ContextWithAccount sets the account in the context // ContextWithAccount sets the account in the context
func ContextWithAccount(ctx context.Context, account *Account) (context.Context, error) { func ContextWithAccount(ctx context.Context, account *Account) context.Context {
// metadata is stored as a string, so marshal to bytes return context.WithValue(ctx, accountKey{}, account)
bytes, err := json.Marshal(account)
if err != nil {
return ctx, err
}
// generate a new context with the MetadataKey set
return metadata.Set(ctx, MetadataKey, string(bytes)), nil
} }

View File

@ -49,6 +49,13 @@ type Option func(*Options)
type PublishOption func(*PublishOptions) type PublishOption func(*PublishOptions)
// PublishContext set context
func PublishContext(ctx context.Context) PublishOption {
return func(o *PublishOptions) {
o.Context = ctx
}
}
type SubscribeOption func(*SubscribeOptions) type SubscribeOption func(*SubscribeOptions)
func NewSubscribeOptions(opts ...SubscribeOption) SubscribeOptions { func NewSubscribeOptions(opts ...SubscribeOption) SubscribeOptions {

View File

@ -653,7 +653,7 @@ func (g *grpcClient) Publish(ctx context.Context, p client.Message, opts ...clie
return g.opts.Broker.Publish(topic, &broker.Message{ return g.opts.Broker.Publish(topic, &broker.Message{
Header: md, Header: md,
Body: body, Body: body,
}) }, broker.PublishContext(options.Context))
} }
func (g *grpcClient) String() string { func (g *grpcClient) String() string {

View File

@ -252,6 +252,13 @@ func WithExchange(e string) PublishOption {
} }
} }
// PublishContext sets the context in publish options
func PublishContext(ctx context.Context) PublishOption {
return func(o *PublishOptions) {
o.Context = ctx
}
}
// WithAddress sets the remote addresses to use rather than using service discovery // WithAddress sets the remote addresses to use rather than using service discovery
func WithAddress(a ...string) CallOption { func WithAddress(a ...string) CallOption {
return func(o *CallOptions) { return func(o *CallOptions) {

View File

@ -628,7 +628,7 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
return r.opts.Broker.Publish(topic, &broker.Message{ return r.opts.Broker.Publish(topic, &broker.Message{
Header: md, Header: md,
Body: body, Body: body,
}) }, broker.PublishContext(options.Context))
} }
func (r *rpcClient) NewMessage(topic string, message interface{}, opts ...MessageOption) Message { func (r *rpcClient) NewMessage(topic string, message interface{}, opts ...MessageOption) Message {

View File

@ -55,11 +55,17 @@ func NewRuntime(opts ...Option) Runtime {
// @todo move this to runtime default // @todo move this to runtime default
func (r *runtime) checkoutSourceIfNeeded(s *Service) error { func (r *runtime) checkoutSourceIfNeeded(s *Service) error {
// Runtime service like config have no source.
// Skip checkout in that case
if len(s.Source) == 0 {
return nil
}
source, err := git.ParseSourceLocal("", s.Source) source, err := git.ParseSourceLocal("", s.Source)
if err != nil { if err != nil {
return err return err
} }
source.Ref = s.Version source.Ref = s.Version
err = git.CheckoutSource(os.TempDir(), source) err = git.CheckoutSource(os.TempDir(), source)
if err != nil { if err != nil {
return err return err
@ -209,7 +215,10 @@ func serviceKey(s *Service) string {
// Create creates a new service which is then started by runtime // Create creates a new service which is then started by runtime
func (r *runtime) Create(s *Service, opts ...CreateOption) error { func (r *runtime) Create(s *Service, opts ...CreateOption) error {
r.checkoutSourceIfNeeded(s) err := r.checkoutSourceIfNeeded(s)
if err != nil {
return err
}
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
@ -251,6 +260,18 @@ func (r *runtime) Create(s *Service, opts ...CreateOption) error {
return nil return nil
} }
// exists returns whether the given file or directory exists
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return true, err
}
// @todo: Getting existing lines is not supported yet. // @todo: Getting existing lines is not supported yet.
// The reason for this is because it's hard to calculate line offset // The reason for this is because it's hard to calculate line offset
// as opposed to character offset. // as opposed to character offset.
@ -265,18 +286,53 @@ func (r *runtime) Logs(s *Service, options ...LogsOption) (LogStream, error) {
stream: make(chan LogRecord), stream: make(chan LogRecord),
stop: make(chan bool), stop: make(chan bool),
} }
t, err := tail.TailFile(logFile(s.Name), tail.Config{Follow: true, Location: &tail.SeekInfo{
Whence: 2, fpath := logFile(s.Name)
Offset: 0, if ex, err := exists(fpath); err != nil {
return nil, err
} else if !ex {
return nil, fmt.Errorf("Log file %v does not exists", fpath)
}
// have to check file size to avoid too big of a seek
fi, err := os.Stat(fpath)
if err != nil {
return nil, err
}
size := fi.Size()
whence := 2
// Multiply by length of an average line of log in bytes
offset := lopts.Count * 200
if offset > size {
offset = size
}
offset *= -1
t, err := tail.TailFile(fpath, tail.Config{Follow: lopts.Stream, Location: &tail.SeekInfo{
Whence: whence,
Offset: int64(offset),
}, Logger: tail.DiscardingLogger}) }, Logger: tail.DiscardingLogger})
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret.tail = t ret.tail = t
go func() { go func() {
for line := range t.Lines { for {
ret.stream <- LogRecord{Message: line.Text} select {
case line, ok := <-t.Lines:
if !ok {
ret.Stop()
return
}
ret.stream <- LogRecord{Message: line.Text}
case <-ret.stop:
return
}
} }
}() }()
return ret, nil return ret, nil
} }
@ -301,16 +357,18 @@ func (l *logStream) Error() error {
func (l *logStream) Stop() error { func (l *logStream) Stop() error {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
// @todo seems like this is causing a hangup
//err := l.tail.Stop()
//if err != nil {
// return err
//}
select { select {
case <-l.stop: case <-l.stop:
return nil return nil
default: default:
close(l.stop) close(l.stop)
close(l.stream)
err := l.tail.Stop()
if err != nil {
logger.Errorf("Error stopping tail: %v", err)
return err
}
} }
return nil return nil
} }
@ -353,14 +411,17 @@ func (r *runtime) Read(opts ...ReadOption) ([]*Service, error) {
// Update attemps to update the service // Update attemps to update the service
func (r *runtime) Update(s *Service, opts ...UpdateOption) error { func (r *runtime) Update(s *Service, opts ...UpdateOption) error {
r.checkoutSourceIfNeeded(s) err := r.checkoutSourceIfNeeded(s)
if err != nil {
return err
}
r.Lock() r.Lock()
service, ok := r.services[serviceKey(s)] service, ok := r.services[serviceKey(s)]
r.Unlock() r.Unlock()
if !ok { if !ok {
return errors.New("Service not found") return errors.New("Service not found")
} }
err := service.Stop() err = service.Stop()
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package git package git
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
@ -105,7 +106,7 @@ type binaryGitter struct {
} }
func (g binaryGitter) Clone(repo string) error { func (g binaryGitter) Clone(repo string) error {
fold := filepath.Join(g.folder, dirifyRepo(repo)) fold := filepath.Join(g.folder, dirifyRepo(repo), ".git")
exists, err := pathExists(fold) exists, err := pathExists(fold)
if err != nil { if err != nil {
return err return err
@ -113,6 +114,7 @@ func (g binaryGitter) Clone(repo string) error {
if exists { if exists {
return nil return nil
} }
fold = filepath.Join(g.folder, dirifyRepo(repo))
cmd := exec.Command("git", "clone", repo, ".") cmd := exec.Command("git", "clone", repo, ".")
err = os.MkdirAll(fold, 0777) err = os.MkdirAll(fold, 0777)
@ -130,9 +132,9 @@ func (g binaryGitter) Clone(repo string) error {
func (g binaryGitter) FetchAll(repo string) error { func (g binaryGitter) FetchAll(repo string) error {
cmd := exec.Command("git", "fetch", "--all") cmd := exec.Command("git", "fetch", "--all")
cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo)) cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo))
_, err := cmd.Output() outp, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return err return errors.New(string(outp))
} }
return err return err
} }
@ -143,9 +145,9 @@ func (g binaryGitter) Checkout(repo, branchOrCommit string) error {
} }
cmd := exec.Command("git", "checkout", "-f", branchOrCommit) cmd := exec.Command("git", "checkout", "-f", branchOrCommit)
cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo)) cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo))
_, err := cmd.Output() outp, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return err return errors.New(string(outp))
} }
return nil return nil
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/micro/go-micro/v2/client" "github.com/micro/go-micro/v2/client"
"github.com/micro/go-micro/v2/runtime" "github.com/micro/go-micro/v2/runtime"
pb "github.com/micro/go-micro/v2/runtime/service/proto" pb "github.com/micro/go-micro/v2/runtime/service/proto"
"github.com/micro/go-micro/v2/util/log"
) )
type svc struct { type svc struct {
@ -72,14 +71,15 @@ func (s *svc) Logs(service *runtime.Service, opts ...runtime.LogsOption) (runtim
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
if options.Context == nil { if options.Context == nil {
options.Context = context.Background() options.Context = context.Background()
} }
ls, err := s.runtime.Logs(options.Context, &pb.LogsRequest{ ls, err := s.runtime.Logs(options.Context, &pb.LogsRequest{
Service: service.Name, Service: service.Name,
Stream: true, Stream: options.Stream,
Count: 10, // @todo pass in actual options Count: options.Count,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,14 +89,39 @@ func (s *svc) Logs(service *runtime.Service, opts ...runtime.LogsOption) (runtim
stream: make(chan runtime.LogRecord), stream: make(chan runtime.LogRecord),
stop: make(chan bool), stop: make(chan bool),
} }
go func() { go func() {
for { for {
record := runtime.LogRecord{} select {
err := ls.RecvMsg(&record) // @todo this never seems to return, investigate
if err != nil { case <-ls.Context().Done():
log.Error(err) logStream.Stop()
}
}
}()
go func() {
for {
select {
// @todo this never seems to return, investigate
case <-ls.Context().Done():
return
case _, ok := <-logStream.stream:
if !ok {
return
}
default:
record := pb.LogRecord{}
err := ls.RecvMsg(&record)
if err != nil {
logStream.Stop()
return
}
logStream.stream <- runtime.LogRecord{
Message: record.GetMessage(),
Metadata: record.GetMetadata(),
}
} }
logStream.stream <- record
} }
}() }()
return logStream, nil return logStream, nil
@ -125,6 +150,7 @@ func (l *serviceLogStream) Stop() error {
case <-l.stop: case <-l.stop:
return nil return nil
default: default:
close(l.stream)
close(l.stop) close(l.stop)
} }
return nil return nil

View File

@ -178,10 +178,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
} }
// There is an account, set it in the context // There is an account, set it in the context
ctx, err = auth.ContextWithAccount(ctx, account) ctx = auth.ContextWithAccount(ctx, account)
if err != nil {
return err
}
// The user is authorised, allow the call // The user is authorised, allow the call
return h(ctx, req, rsp) return h(ctx, req, rsp)

View File

@ -111,6 +111,9 @@ func (s *service) run(exit chan bool) {
} }
func (s *service) register() error { func (s *service) register() error {
s.RLock()
defer s.RUnlock()
if s.srv == nil { if s.srv == nil {
return nil return nil
} }
@ -138,6 +141,9 @@ func (s *service) register() error {
} }
func (s *service) deregister() error { func (s *service) deregister() error {
s.RLock()
defer s.RUnlock()
if s.srv == nil { if s.srv == nil {
return nil return nil
} }
@ -280,18 +286,22 @@ func (s *service) Client() *http.Client {
func (s *service) Handle(pattern string, handler http.Handler) { func (s *service) Handle(pattern string, handler http.Handler) {
var seen bool var seen bool
s.RLock()
for _, ep := range s.srv.Endpoints { for _, ep := range s.srv.Endpoints {
if ep.Name == pattern { if ep.Name == pattern {
seen = true seen = true
break break
} }
} }
s.RUnlock()
// if its unseen then add an endpoint // if its unseen then add an endpoint
if !seen { if !seen {
s.Lock()
s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{ s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
Name: pattern, Name: pattern,
}) })
s.Unlock()
} }
// disable static serving // disable static serving
@ -306,17 +316,23 @@ func (s *service) Handle(pattern string, handler http.Handler) {
} }
func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
var seen bool var seen bool
s.RLock()
for _, ep := range s.srv.Endpoints { for _, ep := range s.srv.Endpoints {
if ep.Name == pattern { if ep.Name == pattern {
seen = true seen = true
break break
} }
} }
s.RUnlock()
if !seen { if !seen {
s.Lock()
s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{ s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
Name: pattern, Name: pattern,
}) })
s.Unlock()
} }
// disable static serving // disable static serving
@ -331,7 +347,6 @@ func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *
func (s *service) Init(opts ...Option) error { func (s *service) Init(opts ...Option) error {
s.Lock() s.Lock()
defer s.Unlock()
for _, o := range opts { for _, o := range opts {
o(&s.opts) o(&s.opts)
@ -347,6 +362,8 @@ func (s *service) Init(opts ...Option) error {
serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry)) serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
} }
s.Unlock()
serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error { serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -386,14 +403,19 @@ func (s *service) Init(opts ...Option) error {
return nil return nil
})) }))
s.RLock()
// pass in own name and version // pass in own name and version
serviceOpts = append(serviceOpts, micro.Name(s.opts.Name)) serviceOpts = append(serviceOpts, micro.Name(s.opts.Name))
serviceOpts = append(serviceOpts, micro.Version(s.opts.Version)) serviceOpts = append(serviceOpts, micro.Version(s.opts.Version))
s.RUnlock()
s.opts.Service.Init(serviceOpts...) s.opts.Service.Init(serviceOpts...)
s.Lock()
srv := s.genSrv() srv := s.genSrv()
srv.Endpoints = s.srv.Endpoints srv.Endpoints = s.srv.Endpoints
s.srv = srv s.srv = srv
s.Unlock()
return nil return nil
} }

72
web/web_test.go Normal file
View File

@ -0,0 +1,72 @@
package web_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/micro/cli/v2"
"github.com/micro/go-micro/v2"
"github.com/micro/go-micro/v2/logger"
"github.com/micro/go-micro/v2/web"
)
func TestWeb(t *testing.T) {
for i := 0; i < 10; i++ {
fmt.Println("Test nr", i)
testFunc()
}
}
func testFunc() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*250)
defer cancel()
s := micro.NewService(
micro.Name("test"),
micro.Context(ctx),
micro.HandleSignal(false),
micro.Flags(
&cli.StringFlag{
Name: "test.timeout",
},
&cli.BoolFlag{
Name: "test.v",
},
&cli.StringFlag{
Name: "test.run",
},
&cli.StringFlag{
Name: "test.testlogfile",
},
),
)
w := web.NewService(
web.MicroService(s),
web.Context(ctx),
web.HandleSignal(false),
)
//s.Init()
//w.Init()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
err := s.Run()
if err != nil {
logger.Errorf("micro run error: %v", err)
}
}()
go func() {
defer wg.Done()
err := w.Run()
if err != nil {
logger.Errorf("web run error: %v", err)
}
}()
wg.Wait()
}