Refactor to use publicsuffix

This commit is contained in:
Ben Toogood
2020-04-07 10:28:39 +01:00
parent 11e1e9120a
commit 501fc5c059
3 changed files with 27 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/micro/go-micro/v2/api/resolver/path"
"github.com/micro/go-micro/v2/auth"
"github.com/micro/go-micro/v2/logger"
"golang.org/x/net/publicsuffix"
)
// CombinedAuthHandler wraps a server and authenticates requests
@@ -39,7 +40,7 @@ type authHandler struct {
func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Determine the namespace and set it in the header
namespace := h.namespaceFromRequest(req)
namespace := h.NamespaceFromRequest(req)
req.Header.Set(auth.NamespaceKey, namespace)
// Extract the token from the request
@@ -131,7 +132,7 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, loginWithRedirect, http.StatusTemporaryRedirect)
}
func (h authHandler) namespaceFromRequest(req *http.Request) string {
func (h authHandler) NamespaceFromRequest(req *http.Request) string {
// check to see what the provided namespace is, we only do
// domain mapping if the namespace is set to 'domain'
if h.namespace != "domain" {
@@ -161,18 +162,27 @@ func (h authHandler) namespaceFromRequest(req *http.Request) string {
return auth.DefaultNamespace
}
// TODO: this logic needs to be replaced with usage of publicsuffix
// if host is not a subdomain, deturn default namespace
comps := strings.Split(host, ".")
if len(comps) < 3 {
// extract the top level domain plus one (e.g. 'myapp.com')
domain, err := publicsuffix.EffectiveTLDPlusOne(host)
if err != nil {
logger.Debugf("Unable to extract domain from %v", host)
return auth.DefaultNamespace
}
// return the reversed subdomain as the namespace
nComps := comps[0 : len(comps)-2]
for i := len(nComps)/2 - 1; i >= 0; i-- {
opp := len(nComps) - 1 - i
nComps[i], nComps[opp] = nComps[opp], nComps[i]
// check to see if the domain is the host, in this
// case we return the default namespace
if domain == host {
return auth.DefaultNamespace
}
return strings.Join(nComps, ".")
// remove the domain from the host, leaving the subdomain
subdomain := strings.TrimSuffix(host, "."+domain)
// return the reversed subdomain as the namespace
comps := strings.Split(subdomain, ".")
for i := len(comps)/2 - 1; i >= 0; i-- {
opp := len(comps) - 1 - i
comps[i], comps[opp] = comps[opp], comps[i]
}
return strings.Join(comps, ".")
}

View File

@@ -13,6 +13,7 @@ func TestNamespaceFromRequest(t *testing.T) {
Namespace string
}{
{Host: "micro.mu", Namespace: auth.DefaultNamespace},
{Host: "micro.com.au", Namespace: auth.DefaultNamespace},
{Host: "web.micro.mu", Namespace: auth.DefaultNamespace},
{Host: "api.micro.mu", Namespace: auth.DefaultNamespace},
{Host: "myapp.com", Namespace: auth.DefaultNamespace},
@@ -23,9 +24,11 @@ func TestNamespaceFromRequest(t *testing.T) {
{Host: "81.151.101.146", Namespace: auth.DefaultNamespace},
}
h := &authHandler{namespace: "domain"}
for _, tc := range tt {
t.Run(tc.Host, func(t *testing.T) {
ns := namespaceFromRequest(&http.Request{Host: tc.Host})
ns := h.NamespaceFromRequest(&http.Request{Host: tc.Host})
if ns != tc.Namespace {
t.Errorf("Expected namespace %v for host %v, actually got %v", tc.Namespace, tc.Host, ns)
}