Merge branch 'master' into jwt-auth

This commit is contained in:
ben-toogood 2020-04-29 13:22:09 +01:00 committed by GitHub
commit e57b20c1f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 236 additions and 65 deletions

View File

@ -53,5 +53,5 @@ See the [docs](https://micro.mu/docs/framework.html) for detailed information on
## License
Go Micro is Apache 2.0 licensed
Go Micro is Apache 2.0 licensed.

View File

@ -87,7 +87,7 @@ func (d *discordInput) Start() error {
}
var err error
d.session, err = discordgo.New(d.token)
d.session, err = discordgo.New("Bot " + d.token)
if err != nil {
return err
}

View File

@ -3,11 +3,8 @@ package auth
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/micro/go-micro/v2/metadata"
)
var (
@ -90,44 +87,24 @@ type Token struct {
const (
// DefaultNamespace used for auth
DefaultNamespace = "go.micro"
// MetadataKey is the key used when storing the account in metadata
MetadataKey = "auth-account"
// TokenCookieName is the name of the cookie which stores the auth token
TokenCookieName = "micro-token"
// SecretCookieName is the name of the cookie which stores the auth secret
SecretCookieName = "micro-secret"
// BearerScheme used for Authorization header
BearerScheme = "Bearer "
)
type accountKey struct{}
// AccountFromContext gets the account from the context, which
// is set by the auth wrapper at the start of a call. If the account
// is not set, a nil account will be returned. The error is only returned
// when there was a problem retrieving an account
func AccountFromContext(ctx context.Context) (*Account, error) {
str, ok := metadata.Get(ctx, MetadataKey)
// there was no account set
if !ok {
return nil, nil
}
var acc *Account
// metadata is stored as a string, so unmarshal to an account
if err := json.Unmarshal([]byte(str), &acc); err != nil {
return nil, err
}
return acc, nil
func AccountFromContext(ctx context.Context) (*Account, bool) {
acc, ok := ctx.Value(accountKey{}).(*Account)
return acc, ok
}
// ContextWithAccount sets the account in the context
func ContextWithAccount(ctx context.Context, account *Account) (context.Context, error) {
// metadata is stored as a string, so marshal to bytes
bytes, err := json.Marshal(account)
if err != nil {
return ctx, err
}
// generate a new context with the MetadataKey set
return metadata.Set(ctx, MetadataKey, string(bytes)), nil
func ContextWithAccount(ctx context.Context, account *Account) context.Context {
return context.WithValue(ctx, accountKey{}, account)
}

View File

@ -49,6 +49,13 @@ type Option func(*Options)
type PublishOption func(*PublishOptions)
// PublishContext set context
func PublishContext(ctx context.Context) PublishOption {
return func(o *PublishOptions) {
o.Context = ctx
}
}
type SubscribeOption func(*SubscribeOptions)
func NewSubscribeOptions(opts ...SubscribeOption) SubscribeOptions {

View File

@ -653,7 +653,7 @@ func (g *grpcClient) Publish(ctx context.Context, p client.Message, opts ...clie
return g.opts.Broker.Publish(topic, &broker.Message{
Header: md,
Body: body,
})
}, broker.PublishContext(options.Context))
}
func (g *grpcClient) String() string {

View File

@ -252,6 +252,13 @@ func WithExchange(e string) PublishOption {
}
}
// PublishContext sets the context in publish options
func PublishContext(ctx context.Context) PublishOption {
return func(o *PublishOptions) {
o.Context = ctx
}
}
// WithAddress sets the remote addresses to use rather than using service discovery
func WithAddress(a ...string) CallOption {
return func(o *CallOptions) {

View File

@ -628,7 +628,7 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
return r.opts.Broker.Publish(topic, &broker.Message{
Header: md,
Body: body,
})
}, broker.PublishContext(options.Context))
}
func (r *rpcClient) NewMessage(topic string, message interface{}, opts ...MessageOption) Message {

View File

@ -55,11 +55,17 @@ func NewRuntime(opts ...Option) Runtime {
// @todo move this to runtime default
func (r *runtime) checkoutSourceIfNeeded(s *Service) error {
// Runtime service like config have no source.
// Skip checkout in that case
if len(s.Source) == 0 {
return nil
}
source, err := git.ParseSourceLocal("", s.Source)
if err != nil {
return err
}
source.Ref = s.Version
err = git.CheckoutSource(os.TempDir(), source)
if err != nil {
return err
@ -209,7 +215,10 @@ func serviceKey(s *Service) string {
// Create creates a new service which is then started by runtime
func (r *runtime) Create(s *Service, opts ...CreateOption) error {
r.checkoutSourceIfNeeded(s)
err := r.checkoutSourceIfNeeded(s)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
@ -251,6 +260,18 @@ func (r *runtime) Create(s *Service, opts ...CreateOption) error {
return nil
}
// exists returns whether the given file or directory exists
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return true, err
}
// @todo: Getting existing lines is not supported yet.
// The reason for this is because it's hard to calculate line offset
// as opposed to character offset.
@ -265,18 +286,53 @@ func (r *runtime) Logs(s *Service, options ...LogsOption) (LogStream, error) {
stream: make(chan LogRecord),
stop: make(chan bool),
}
t, err := tail.TailFile(logFile(s.Name), tail.Config{Follow: true, Location: &tail.SeekInfo{
Whence: 2,
Offset: 0,
fpath := logFile(s.Name)
if ex, err := exists(fpath); err != nil {
return nil, err
} else if !ex {
return nil, fmt.Errorf("Log file %v does not exists", fpath)
}
// have to check file size to avoid too big of a seek
fi, err := os.Stat(fpath)
if err != nil {
return nil, err
}
size := fi.Size()
whence := 2
// Multiply by length of an average line of log in bytes
offset := lopts.Count * 200
if offset > size {
offset = size
}
offset *= -1
t, err := tail.TailFile(fpath, tail.Config{Follow: lopts.Stream, Location: &tail.SeekInfo{
Whence: whence,
Offset: int64(offset),
}, Logger: tail.DiscardingLogger})
if err != nil {
return nil, err
}
ret.tail = t
go func() {
for line := range t.Lines {
ret.stream <- LogRecord{Message: line.Text}
for {
select {
case line, ok := <-t.Lines:
if !ok {
ret.Stop()
return
}
ret.stream <- LogRecord{Message: line.Text}
case <-ret.stop:
return
}
}
}()
return ret, nil
}
@ -301,16 +357,18 @@ func (l *logStream) Error() error {
func (l *logStream) Stop() error {
l.Lock()
defer l.Unlock()
// @todo seems like this is causing a hangup
//err := l.tail.Stop()
//if err != nil {
// return err
//}
select {
case <-l.stop:
return nil
default:
close(l.stop)
close(l.stream)
err := l.tail.Stop()
if err != nil {
logger.Errorf("Error stopping tail: %v", err)
return err
}
}
return nil
}
@ -353,14 +411,17 @@ func (r *runtime) Read(opts ...ReadOption) ([]*Service, error) {
// Update attemps to update the service
func (r *runtime) Update(s *Service, opts ...UpdateOption) error {
r.checkoutSourceIfNeeded(s)
err := r.checkoutSourceIfNeeded(s)
if err != nil {
return err
}
r.Lock()
service, ok := r.services[serviceKey(s)]
r.Unlock()
if !ok {
return errors.New("Service not found")
}
err := service.Stop()
err = service.Stop()
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package git
import (
"errors"
"fmt"
"os"
"os/exec"
@ -105,7 +106,7 @@ type binaryGitter struct {
}
func (g binaryGitter) Clone(repo string) error {
fold := filepath.Join(g.folder, dirifyRepo(repo))
fold := filepath.Join(g.folder, dirifyRepo(repo), ".git")
exists, err := pathExists(fold)
if err != nil {
return err
@ -113,6 +114,7 @@ func (g binaryGitter) Clone(repo string) error {
if exists {
return nil
}
fold = filepath.Join(g.folder, dirifyRepo(repo))
cmd := exec.Command("git", "clone", repo, ".")
err = os.MkdirAll(fold, 0777)
@ -130,9 +132,9 @@ func (g binaryGitter) Clone(repo string) error {
func (g binaryGitter) FetchAll(repo string) error {
cmd := exec.Command("git", "fetch", "--all")
cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo))
_, err := cmd.Output()
outp, err := cmd.CombinedOutput()
if err != nil {
return err
return errors.New(string(outp))
}
return err
}
@ -143,9 +145,9 @@ func (g binaryGitter) Checkout(repo, branchOrCommit string) error {
}
cmd := exec.Command("git", "checkout", "-f", branchOrCommit)
cmd.Dir = filepath.Join(g.folder, dirifyRepo(repo))
_, err := cmd.Output()
outp, err := cmd.CombinedOutput()
if err != nil {
return err
return errors.New(string(outp))
}
return nil
}

View File

@ -7,7 +7,6 @@ import (
"github.com/micro/go-micro/v2/client"
"github.com/micro/go-micro/v2/runtime"
pb "github.com/micro/go-micro/v2/runtime/service/proto"
"github.com/micro/go-micro/v2/util/log"
)
type svc struct {
@ -72,14 +71,15 @@ func (s *svc) Logs(service *runtime.Service, opts ...runtime.LogsOption) (runtim
for _, o := range opts {
o(&options)
}
if options.Context == nil {
options.Context = context.Background()
}
ls, err := s.runtime.Logs(options.Context, &pb.LogsRequest{
Service: service.Name,
Stream: true,
Count: 10, // @todo pass in actual options
Stream: options.Stream,
Count: options.Count,
})
if err != nil {
return nil, err
@ -89,14 +89,39 @@ func (s *svc) Logs(service *runtime.Service, opts ...runtime.LogsOption) (runtim
stream: make(chan runtime.LogRecord),
stop: make(chan bool),
}
go func() {
for {
record := runtime.LogRecord{}
err := ls.RecvMsg(&record)
if err != nil {
log.Error(err)
select {
// @todo this never seems to return, investigate
case <-ls.Context().Done():
logStream.Stop()
}
}
}()
go func() {
for {
select {
// @todo this never seems to return, investigate
case <-ls.Context().Done():
return
case _, ok := <-logStream.stream:
if !ok {
return
}
default:
record := pb.LogRecord{}
err := ls.RecvMsg(&record)
if err != nil {
logStream.Stop()
return
}
logStream.stream <- runtime.LogRecord{
Message: record.GetMessage(),
Metadata: record.GetMetadata(),
}
}
logStream.stream <- record
}
}()
return logStream, nil
@ -125,6 +150,7 @@ func (l *serviceLogStream) Stop() error {
case <-l.stop:
return nil
default:
close(l.stream)
close(l.stop)
}
return nil

View File

@ -178,10 +178,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
}
// There is an account, set it in the context
ctx, err = auth.ContextWithAccount(ctx, account)
if err != nil {
return err
}
ctx = auth.ContextWithAccount(ctx, account)
// The user is authorised, allow the call
return h(ctx, req, rsp)

View File

@ -111,6 +111,9 @@ func (s *service) run(exit chan bool) {
}
func (s *service) register() error {
s.RLock()
defer s.RUnlock()
if s.srv == nil {
return nil
}
@ -138,6 +141,9 @@ func (s *service) register() error {
}
func (s *service) deregister() error {
s.RLock()
defer s.RUnlock()
if s.srv == nil {
return nil
}
@ -280,18 +286,22 @@ func (s *service) Client() *http.Client {
func (s *service) Handle(pattern string, handler http.Handler) {
var seen bool
s.RLock()
for _, ep := range s.srv.Endpoints {
if ep.Name == pattern {
seen = true
break
}
}
s.RUnlock()
// if its unseen then add an endpoint
if !seen {
s.Lock()
s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
Name: pattern,
})
s.Unlock()
}
// disable static serving
@ -306,17 +316,23 @@ func (s *service) Handle(pattern string, handler http.Handler) {
}
func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
var seen bool
s.RLock()
for _, ep := range s.srv.Endpoints {
if ep.Name == pattern {
seen = true
break
}
}
s.RUnlock()
if !seen {
s.Lock()
s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
Name: pattern,
})
s.Unlock()
}
// disable static serving
@ -331,7 +347,6 @@ func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *
func (s *service) Init(opts ...Option) error {
s.Lock()
defer s.Unlock()
for _, o := range opts {
o(&s.opts)
@ -347,6 +362,8 @@ func (s *service) Init(opts ...Option) error {
serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
}
s.Unlock()
serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error {
s.Lock()
defer s.Unlock()
@ -386,14 +403,19 @@ func (s *service) Init(opts ...Option) error {
return nil
}))
s.RLock()
// pass in own name and version
serviceOpts = append(serviceOpts, micro.Name(s.opts.Name))
serviceOpts = append(serviceOpts, micro.Version(s.opts.Version))
s.RUnlock()
s.opts.Service.Init(serviceOpts...)
s.Lock()
srv := s.genSrv()
srv.Endpoints = s.srv.Endpoints
s.srv = srv
s.Unlock()
return nil
}

72
web/web_test.go Normal file
View File

@ -0,0 +1,72 @@
package web_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/micro/cli/v2"
"github.com/micro/go-micro/v2"
"github.com/micro/go-micro/v2/logger"
"github.com/micro/go-micro/v2/web"
)
func TestWeb(t *testing.T) {
for i := 0; i < 10; i++ {
fmt.Println("Test nr", i)
testFunc()
}
}
func testFunc() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*250)
defer cancel()
s := micro.NewService(
micro.Name("test"),
micro.Context(ctx),
micro.HandleSignal(false),
micro.Flags(
&cli.StringFlag{
Name: "test.timeout",
},
&cli.BoolFlag{
Name: "test.v",
},
&cli.StringFlag{
Name: "test.run",
},
&cli.StringFlag{
Name: "test.testlogfile",
},
),
)
w := web.NewService(
web.MicroService(s),
web.Context(ctx),
web.HandleSignal(false),
)
//s.Init()
//w.Init()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
err := s.Run()
if err != nil {
logger.Errorf("micro run error: %v", err)
}
}()
go func() {
defer wg.Done()
err := w.Run()
if err != nil {
logger.Errorf("web run error: %v", err)
}
}()
wg.Wait()
}