diff --git a/http.go b/http.go new file mode 100644 index 0000000..ec7b520 --- /dev/null +++ b/http.go @@ -0,0 +1,602 @@ +package http + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/micro/go-micro/v3/network/transport" + maddr "github.com/micro/go-micro/v3/util/addr" + "github.com/micro/go-micro/v3/util/buf" + mnet "github.com/micro/go-micro/v3/util/net" + mls "github.com/micro/go-micro/v3/util/tls" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +type httpTransport struct { + opts transport.Options +} + +type httpTransportClient struct { + ht *httpTransport + addr string + conn net.Conn + dialOpts transport.DialOptions + once sync.Once + + sync.RWMutex + + // request must be stored for response processing + r chan *http.Request + bl []*http.Request + buff *bufio.Reader + + // local/remote ip + local string + remote string +} + +type httpTransportSocket struct { + ht *httpTransport + w http.ResponseWriter + r *http.Request + rw *bufio.ReadWriter + + mtx sync.RWMutex + + // the hijacked when using http 1 + conn net.Conn + // for the first request + ch chan *http.Request + + // h2 things + buf *bufio.Reader + // indicate if socket is closed + closed chan bool + + // local/remote ip + local string + remote string +} + +type httpTransportListener struct { + ht *httpTransport + listener net.Listener +} + +func (h *httpTransportClient) Local() string { + return h.local +} + +func (h *httpTransportClient) Remote() string { + return h.remote +} + +func (h *httpTransportClient) Send(m *transport.Message) error { + header := make(http.Header) + + for k, v := range m.Header { + header.Set(k, v) + } + + b := buf.New(bytes.NewBuffer(m.Body)) + defer b.Close() + + req := &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: h.addr, + }, + Header: header, + Body: b, + ContentLength: int64(b.Len()), + Host: h.addr, + } + + h.Lock() + h.bl = append(h.bl, req) + select { + case h.r <- h.bl[0]: + h.bl = h.bl[1:] + default: + } + h.Unlock() + + // set timeout if its greater than 0 + if h.ht.opts.Timeout > time.Duration(0) { + h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) + } + + return req.Write(h.conn) +} + +func (h *httpTransportClient) Recv(m *transport.Message) error { + if m == nil { + return errors.New("message passed in is nil") + } + + var r *http.Request + if !h.dialOpts.Stream { + rc, ok := <-h.r + if !ok { + return io.EOF + } + r = rc + } + + // set timeout if its greater than 0 + if h.ht.opts.Timeout > time.Duration(0) { + h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) + } + + rsp, err := http.ReadResponse(h.buff, r) + if err != nil { + return err + } + defer rsp.Body.Close() + + b, err := ioutil.ReadAll(rsp.Body) + if err != nil { + return err + } + + if rsp.StatusCode != 200 { + return errors.New(rsp.Status + ": " + string(b)) + } + + m.Body = b + + if m.Header == nil { + m.Header = make(map[string]string, len(rsp.Header)) + } + + for k, v := range rsp.Header { + if len(v) > 0 { + m.Header[k] = v[0] + } else { + m.Header[k] = "" + } + } + + return nil +} + +func (h *httpTransportClient) Close() error { + h.once.Do(func() { + h.Lock() + h.buff.Reset(nil) + h.Unlock() + close(h.r) + }) + return h.conn.Close() +} + +func (h *httpTransportSocket) Local() string { + return h.local +} + +func (h *httpTransportSocket) Remote() string { + return h.remote +} + +func (h *httpTransportSocket) Recv(m *transport.Message) error { + if m == nil { + return errors.New("message passed in is nil") + } + if m.Header == nil { + m.Header = make(map[string]string, len(h.r.Header)) + } + + // process http 1 + if h.r.ProtoMajor == 1 { + // set timeout if its greater than 0 + if h.ht.opts.Timeout > time.Duration(0) { + h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) + } + + var r *http.Request + + select { + // get first request + case r = <-h.ch: + // read next request + default: + rr, err := http.ReadRequest(h.rw.Reader) + if err != nil { + return err + } + r = rr + } + + // read body + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + + // set body + r.Body.Close() + m.Body = b + + // set headers + for k, v := range r.Header { + if len(v) > 0 { + m.Header[k] = v[0] + } else { + m.Header[k] = "" + } + } + + // return early early + return nil + } + + // only process if the socket is open + select { + case <-h.closed: + return io.EOF + default: + // no op + } + + // processing http2 request + // read streaming body + + // set max buffer size + // TODO: adjustable buffer size + buf := make([]byte, 4*1024*1024) + + // read the request body + n, err := h.buf.Read(buf) + // not an eof error + if err != nil { + return err + } + + // check if we have data + if n > 0 { + m.Body = buf[:n] + } + + // set headers + for k, v := range h.r.Header { + if len(v) > 0 { + m.Header[k] = v[0] + } else { + m.Header[k] = "" + } + } + + // set path + m.Header[":path"] = h.r.URL.Path + + return nil +} + +func (h *httpTransportSocket) Send(m *transport.Message) error { + if h.r.ProtoMajor == 1 { + // make copy of header + hdr := make(http.Header) + for k, v := range h.r.Header { + hdr[k] = v + } + + rsp := &http.Response{ + Header: hdr, + Body: ioutil.NopCloser(bytes.NewReader(m.Body)), + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: int64(len(m.Body)), + } + + for k, v := range m.Header { + rsp.Header.Set(k, v) + } + + // set timeout if its greater than 0 + if h.ht.opts.Timeout > time.Duration(0) { + h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) + } + + return rsp.Write(h.conn) + } + + // only process if the socket is open + select { + case <-h.closed: + return io.EOF + default: + // no op + } + + // we need to lock to protect the write + h.mtx.RLock() + defer h.mtx.RUnlock() + + // set headers + for k, v := range m.Header { + h.w.Header().Set(k, v) + } + + // write request + _, err := h.w.Write(m.Body) + + // flush the trailers + h.w.(http.Flusher).Flush() + + return err +} + +func (h *httpTransportSocket) error(m *transport.Message) error { + if h.r.ProtoMajor == 1 { + rsp := &http.Response{ + Header: make(http.Header), + Body: ioutil.NopCloser(bytes.NewReader(m.Body)), + Status: "500 Internal Server Error", + StatusCode: 500, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: int64(len(m.Body)), + } + + for k, v := range m.Header { + rsp.Header.Set(k, v) + } + + return rsp.Write(h.conn) + } + + return nil +} + +func (h *httpTransportSocket) Close() error { + h.mtx.Lock() + defer h.mtx.Unlock() + select { + case <-h.closed: + return nil + default: + // close the channel + close(h.closed) + + // close the buffer + h.r.Body.Close() + + // close the connection + if h.r.ProtoMajor == 1 { + return h.conn.Close() + } + } + + return nil +} + +func (h *httpTransportListener) Addr() string { + return h.listener.Addr().String() +} + +func (h *httpTransportListener) Close() error { + return h.listener.Close() +} + +func (h *httpTransportListener) Accept(fn func(transport.Socket)) error { + // create handler mux + mux := http.NewServeMux() + + // register our transport handler + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + var buf *bufio.ReadWriter + var con net.Conn + + // read a regular request + if r.ProtoMajor == 1 { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + r.Body = ioutil.NopCloser(bytes.NewReader(b)) + // hijack the conn + hj, ok := w.(http.Hijacker) + if !ok { + // we're screwed + http.Error(w, "cannot serve conn", http.StatusInternalServerError) + return + } + + conn, bufrw, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + buf = bufrw + con = conn + } + + // buffered reader + bufr := bufio.NewReader(r.Body) + + // save the request + ch := make(chan *http.Request, 1) + ch <- r + + // create a new transport socket + sock := &httpTransportSocket{ + ht: h.ht, + w: w, + r: r, + rw: buf, + buf: bufr, + ch: ch, + conn: con, + local: h.Addr(), + remote: r.RemoteAddr, + closed: make(chan bool), + } + + // execute the socket + fn(sock) + }) + + // get optional handlers + if h.ht.opts.Context != nil { + handlers, ok := h.ht.opts.Context.Value("http_handlers").(map[string]http.Handler) + if ok { + for pattern, handler := range handlers { + mux.Handle(pattern, handler) + } + } + } + + // default http2 server + srv := &http.Server{ + Handler: mux, + } + + // insecure connection use h2c + if !(h.ht.opts.Secure || h.ht.opts.TLSConfig != nil) { + srv.Handler = h2c.NewHandler(mux, &http2.Server{}) + } + + // begin serving + return srv.Serve(h.listener) +} + +func (h *httpTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { + dopts := transport.DialOptions{ + Timeout: transport.DefaultDialTimeout, + } + + for _, opt := range opts { + opt(&dopts) + } + + var conn net.Conn + var err error + + // TODO: support dial option here rather than using internal config + if h.opts.Secure || h.opts.TLSConfig != nil { + config := h.opts.TLSConfig + if config == nil { + config = &tls.Config{ + InsecureSkipVerify: true, + } + } + config.NextProtos = []string{"http/1.1"} + conn, err = newConn(func(addr string) (net.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: dopts.Timeout}, "tcp", addr, config) + })(addr) + } else { + conn, err = newConn(func(addr string) (net.Conn, error) { + return net.DialTimeout("tcp", addr, dopts.Timeout) + })(addr) + } + + if err != nil { + return nil, err + } + + return &httpTransportClient{ + ht: h, + addr: addr, + conn: conn, + buff: bufio.NewReader(conn), + dialOpts: dopts, + r: make(chan *http.Request, 1), + local: conn.LocalAddr().String(), + remote: conn.RemoteAddr().String(), + }, nil +} + +func (h *httpTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) { + var options transport.ListenOptions + for _, o := range opts { + o(&options) + } + + var l net.Listener + var err error + + // TODO: support use of listen options + if h.opts.Secure || h.opts.TLSConfig != nil { + config := h.opts.TLSConfig + + fn := func(addr string) (net.Listener, error) { + if config == nil { + hosts := []string{addr} + + // check if its a valid host:port + if host, _, err := net.SplitHostPort(addr); err == nil { + if len(host) == 0 { + hosts = maddr.IPs() + } else { + hosts = []string{host} + } + } + + // generate a certificate + cert, err := mls.Certificate(hosts...) + if err != nil { + return nil, err + } + config = &tls.Config{Certificates: []tls.Certificate{cert}} + } + return tls.Listen("tcp", addr, config) + } + + l, err = mnet.Listen(addr, fn) + } else { + fn := func(addr string) (net.Listener, error) { + return net.Listen("tcp", addr) + } + + l, err = mnet.Listen(addr, fn) + } + + if err != nil { + return nil, err + } + + return &httpTransportListener{ + ht: h, + listener: l, + }, nil +} + +func (h *httpTransport) Init(opts ...transport.Option) error { + for _, o := range opts { + o(&h.opts) + } + return nil +} + +func (h *httpTransport) Options() transport.Options { + return h.opts +} + +func (h *httpTransport) String() string { + return "http" +} + +func NewTransport(opts ...transport.Option) transport.Transport { + var options transport.Options + for _, o := range opts { + o(&options) + } + return &httpTransport{opts: options} +} diff --git a/http_proxy.go b/http_proxy.go new file mode 100644 index 0000000..328091b --- /dev/null +++ b/http_proxy.go @@ -0,0 +1,109 @@ +package http + +import ( + "bufio" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" +) + +const ( + proxyAuthHeader = "Proxy-Authorization" +) + +func getURL(addr string) (*url.URL, error) { + r := &http.Request{ + URL: &url.URL{ + Scheme: "https", + Host: addr, + }, + } + return http.ProxyFromEnvironment(r) +} + +type pbuffer struct { + net.Conn + r io.Reader +} + +func (p *pbuffer) Read(b []byte) (int, error) { + return p.r.Read(b) +} + +func proxyDial(conn net.Conn, addr string, proxyURL *url.URL) (_ net.Conn, err error) { + defer func() { + if err != nil { + conn.Close() + } + }() + + r := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Header: map[string][]string{"User-Agent": {"micro/latest"}}, + } + + if user := proxyURL.User; user != nil { + u := user.Username() + p, _ := user.Password() + auth := []byte(u + ":" + p) + basicAuth := base64.StdEncoding.EncodeToString(auth) + r.Header.Add(proxyAuthHeader, "Basic "+basicAuth) + } + + if err := r.Write(conn); err != nil { + return nil, fmt.Errorf("failed to write the HTTP request: %v", err) + } + + br := bufio.NewReader(conn) + rsp, err := http.ReadResponse(br, r) + if err != nil { + return nil, fmt.Errorf("reading server HTTP response: %v", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(rsp, true) + if err != nil { + return nil, fmt.Errorf("failed to do connect handshake, status code: %s", rsp.Status) + } + return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) + } + + return &pbuffer{Conn: conn, r: br}, nil +} + +// Creates a new connection +func newConn(dial func(string) (net.Conn, error)) func(string) (net.Conn, error) { + return func(addr string) (net.Conn, error) { + // get the proxy url + proxyURL, err := getURL(addr) + if err != nil { + return nil, err + } + + // set to addr + callAddr := addr + + // got proxy + if proxyURL != nil { + callAddr = proxyURL.Host + } + + // dial the addr + c, err := dial(callAddr) + if err != nil { + return nil, err + } + + // do proxy connect if we have proxy url + if proxyURL != nil { + c, err = proxyDial(c, addr, proxyURL) + } + + return c, err + } +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..2a40ba0 --- /dev/null +++ b/http_test.go @@ -0,0 +1,138 @@ +package http + +import ( + "sync" + "testing" + + "github.com/micro/go-micro/v3/network/transport" +) + +func call(b *testing.B, c int) { + b.StopTimer() + + tr := NewTransport() + + // server listen + l, err := tr.Listen("localhost:0") + if err != nil { + b.Fatal(err) + } + defer l.Close() + + // socket func + fn := func(sock transport.Socket) { + defer sock.Close() + + for { + var m transport.Message + if err := sock.Recv(&m); err != nil { + return + } + + if err := sock.Send(&m); err != nil { + return + } + } + } + + done := make(chan bool) + + // accept connections + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + b.Fatalf("Unexpected accept err: %v", err) + } + } + }() + + m := transport.Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + + // client connection + client, err := tr.Dial(l.Addr()) + if err != nil { + b.Fatalf("Unexpected dial err: %v", err) + } + + send := func(c transport.Client) { + // send message + if err := c.Send(&m); err != nil { + b.Fatalf("Unexpected send err: %v", err) + } + + var rm transport.Message + // receive message + if err := c.Recv(&rm); err != nil { + b.Fatalf("Unexpected recv err: %v", err) + } + } + + // warm + for i := 0; i < 10; i++ { + send(client) + } + + client.Close() + + ch := make(chan int, c*4) + + var wg sync.WaitGroup + wg.Add(c) + + for i := 0; i < c; i++ { + go func() { + cl, err := tr.Dial(l.Addr()) + if err != nil { + b.Fatalf("Unexpected dial err: %v", err) + } + defer cl.Close() + + for range ch { + send(cl) + } + + wg.Done() + }() + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + ch <- i + } + + b.StopTimer() + close(ch) + + wg.Wait() + + // finish + close(done) +} + +func BenchmarkTransport1(b *testing.B) { + call(b, 1) +} + +func BenchmarkTransport8(b *testing.B) { + call(b, 8) +} + +func BenchmarkTransport16(b *testing.B) { + call(b, 16) +} + +func BenchmarkTransport64(b *testing.B) { + call(b, 64) +} + +func BenchmarkTransport128(b *testing.B) { + call(b, 128) +} diff --git a/http_transport_test.go b/http_transport_test.go new file mode 100644 index 0000000..c5711d7 --- /dev/null +++ b/http_transport_test.go @@ -0,0 +1,248 @@ +package http + +import ( + "io" + "net" + "testing" + "time" + + "github.com/micro/go-micro/v3/network/transport" +) + +func expectedPort(t *testing.T, expected string, lsn transport.Listener) { + _, port, err := net.SplitHostPort(lsn.Addr()) + if err != nil { + t.Errorf("Expected address to be `%s`, got error: %v", expected, err) + } + + if port != expected { + lsn.Close() + t.Errorf("Expected address to be `%s`, got `%s`", expected, port) + } +} + +func TestHTTPTransportPortRange(t *testing.T) { + tp := NewTransport() + + lsn1, err := tp.Listen(":44444-44448") + if err != nil { + t.Errorf("Did not expect an error, got %s", err) + } + expectedPort(t, "44444", lsn1) + + lsn2, err := tp.Listen(":44444-44448") + if err != nil { + t.Errorf("Did not expect an error, got %s", err) + } + expectedPort(t, "44445", lsn2) + + lsn, err := tp.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Did not expect an error, got %s", err) + } + + lsn.Close() + lsn1.Close() + lsn2.Close() +} + +func TestHTTPTransportCommunication(t *testing.T) { + tr := NewTransport() + + l, err := tr.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Unexpected listen err: %v", err) + } + defer l.Close() + + fn := func(sock transport.Socket) { + defer sock.Close() + + for { + var m transport.Message + if err := sock.Recv(&m); err != nil { + return + } + + if err := sock.Send(&m); err != nil { + return + } + } + } + + done := make(chan bool) + + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + t.Errorf("Unexpected accept err: %v", err) + } + } + }() + + c, err := tr.Dial(l.Addr()) + if err != nil { + t.Errorf("Unexpected dial err: %v", err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + + if err := c.Send(&m); err != nil { + t.Errorf("Unexpected send err: %v", err) + } + + var rm transport.Message + + if err := c.Recv(&rm); err != nil { + t.Errorf("Unexpected recv err: %v", err) + } + + if string(rm.Body) != string(m.Body) { + t.Errorf("Expected %v, got %v", m.Body, rm.Body) + } + + close(done) +} + +func TestHTTPTransportError(t *testing.T) { + tr := NewTransport() + + l, err := tr.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Unexpected listen err: %v", err) + } + defer l.Close() + + fn := func(sock transport.Socket) { + defer sock.Close() + + for { + var m transport.Message + if err := sock.Recv(&m); err != nil { + if err == io.EOF { + return + } + t.Fatal(err) + } + + sock.(*httpTransportSocket).error(&transport.Message{ + Body: []byte(`an error occurred`), + }) + } + } + + done := make(chan bool) + + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + t.Errorf("Unexpected accept err: %v", err) + } + } + }() + + c, err := tr.Dial(l.Addr()) + if err != nil { + t.Errorf("Unexpected dial err: %v", err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + + if err := c.Send(&m); err != nil { + t.Errorf("Unexpected send err: %v", err) + } + + var rm transport.Message + + err = c.Recv(&rm) + if err == nil { + t.Fatal("Expected error but got nil") + } + + if err.Error() != "500 Internal Server Error: an error occurred" { + t.Fatalf("Did not receive expected error, got: %v", err) + } + + close(done) +} + +func TestHTTPTransportTimeout(t *testing.T) { + tr := NewTransport(transport.Timeout(time.Millisecond * 100)) + + l, err := tr.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Unexpected listen err: %v", err) + } + defer l.Close() + + done := make(chan bool) + + fn := func(sock transport.Socket) { + defer func() { + sock.Close() + close(done) + }() + + go func() { + select { + case <-done: + return + case <-time.After(time.Second): + t.Fatal("deadline not executed") + } + }() + + for { + var m transport.Message + + if err := sock.Recv(&m); err != nil { + return + } + } + } + + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + t.Errorf("Unexpected accept err: %v", err) + } + } + }() + + c, err := tr.Dial(l.Addr()) + if err != nil { + t.Errorf("Unexpected dial err: %v", err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + + if err := c.Send(&m); err != nil { + t.Errorf("Unexpected send err: %v", err) + } + + <-done +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..0a9f9a2 --- /dev/null +++ b/options.go @@ -0,0 +1,23 @@ +package http + +import ( + "context" + "net/http" + + "github.com/micro/go-micro/v3/network/transport" +) + +// Handle registers the handler for the given pattern. +func Handle(pattern string, handler http.Handler) transport.Option { + return func(o *transport.Options) { + if o.Context == nil { + o.Context = context.Background() + } + handlers, ok := o.Context.Value("http_handlers").(map[string]http.Handler) + if !ok { + handlers = make(map[string]http.Handler) + } + handlers[pattern] = handler + o.Context = context.WithValue(o.Context, "http_handlers", handlers) + } +}