integrate request builder into HTTP client for googleapis support (#157)
This commit is contained in:
116
builder/body.go
Normal file
116
builder/body.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
func buildSingleFieldBody(msg proto.Message, fieldName string) (proto.Message, error) {
|
||||
msgReflect := msg.ProtoReflect()
|
||||
|
||||
fd, found := findFieldByName(msgReflect, fieldName)
|
||||
if !found || fd == nil {
|
||||
return nil, fmt.Errorf("field %s not found", fieldName)
|
||||
}
|
||||
if !msgReflect.Has(fd) {
|
||||
return nil, fmt.Errorf("field %s is not set", fieldName)
|
||||
}
|
||||
|
||||
val := msgReflect.Get(fd)
|
||||
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
return val.Message().Interface(), nil
|
||||
}
|
||||
|
||||
newMsg := proto.Clone(msg)
|
||||
newMsgReflect := newMsg.ProtoReflect()
|
||||
newMsgReflect.Range(func(f protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
||||
if f != fd {
|
||||
newMsgReflect.Clear(f)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return newMsg, nil
|
||||
}
|
||||
|
||||
func buildFullBody(msg proto.Message, usedFieldsPath *usedFields) (proto.Message, error) {
|
||||
var (
|
||||
msgReflect = msg.ProtoReflect()
|
||||
newMsg = msgReflect.New().Interface()
|
||||
newMsgReflect = newMsg.ProtoReflect()
|
||||
)
|
||||
|
||||
fields := msgReflect.Descriptor().Fields()
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
fd := fields.Get(i)
|
||||
fieldName := fd.JSONName()
|
||||
|
||||
if usedFieldsPath.hasTopLevelKey(fieldName) {
|
||||
continue
|
||||
}
|
||||
|
||||
val := msgReflect.Get(fd)
|
||||
if !val.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Note: order of the cases is important!
|
||||
switch {
|
||||
case fd.IsList():
|
||||
list := val.List()
|
||||
newList := newMsgReflect.Mutable(fd).List()
|
||||
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
elem, err := buildFullBody(list.Get(j).Message().Interface(), usedFieldsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("recursive build full body: %w", err)
|
||||
}
|
||||
newList.Append(protoreflect.ValueOfMessage(elem.ProtoReflect()))
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
newList.Append(list.Get(j))
|
||||
}
|
||||
}
|
||||
|
||||
case fd.IsMap():
|
||||
var (
|
||||
m = val.Map()
|
||||
newMap = newMsgReflect.Mutable(fd).Map()
|
||||
rangeErr error
|
||||
)
|
||||
m.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||
if fd.MapValue().Kind() == protoreflect.MessageKind {
|
||||
elem, err := buildFullBody(v.Message().Interface(), usedFieldsPath)
|
||||
if err != nil {
|
||||
rangeErr = fmt.Errorf("recursive build full body: %w", err)
|
||||
return false
|
||||
}
|
||||
newMap.Set(k, protoreflect.ValueOfMessage(elem.ProtoReflect()))
|
||||
} else {
|
||||
newMap.Set(k, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if rangeErr != nil {
|
||||
return nil, fmt.Errorf("map range error: %w", rangeErr)
|
||||
}
|
||||
|
||||
case fd.Kind() == protoreflect.MessageKind:
|
||||
elem, err := buildFullBody(val.Message().Interface(), usedFieldsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("recursive build full body: %w", err)
|
||||
}
|
||||
newMsgReflect.Set(fd, protoreflect.ValueOfMessage(elem.ProtoReflect()))
|
||||
|
||||
default:
|
||||
newMsgReflect.Set(fd, val)
|
||||
}
|
||||
}
|
||||
|
||||
return newMsg, nil
|
||||
}
|
22
builder/body_option.go
Normal file
22
builder/body_option.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package builder
|
||||
|
||||
const (
|
||||
singleWildcard string = "*"
|
||||
doubleWildcard string = "**"
|
||||
)
|
||||
|
||||
type bodyOption string
|
||||
|
||||
func (o bodyOption) String() string { return string(o) }
|
||||
|
||||
func (o bodyOption) isFullBody() bool {
|
||||
return o.String() == singleWildcard
|
||||
}
|
||||
|
||||
func (o bodyOption) isWithoutBody() bool {
|
||||
return o == ""
|
||||
}
|
||||
|
||||
func (o bodyOption) isSingleField() bool {
|
||||
return o != "" && o.String() != singleWildcard
|
||||
}
|
79
builder/body_option_test.go
Normal file
79
builder/body_option_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBodyOption_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt bodyOption
|
||||
want string
|
||||
}{
|
||||
{"empty", bodyOption(""), ""},
|
||||
{"star", bodyOption("*"), "*"},
|
||||
{"field", bodyOption("field"), "field"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.opt.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyOption_isFullBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt bodyOption
|
||||
want bool
|
||||
}{
|
||||
{"empty", bodyOption(""), false},
|
||||
{"star", bodyOption("*"), true},
|
||||
{"field", bodyOption("field"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.opt.isFullBody())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyOption_isWithoutBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt bodyOption
|
||||
want bool
|
||||
}{
|
||||
{"empty", bodyOption(""), true},
|
||||
{"star", bodyOption("*"), false},
|
||||
{"field", bodyOption("field"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.opt.isWithoutBody())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyOption_isSingleField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt bodyOption
|
||||
want bool
|
||||
}{
|
||||
{"empty", bodyOption(""), false},
|
||||
{"star", bodyOption("*"), false},
|
||||
{"field", bodyOption("field"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, tt.opt.isSingleField())
|
||||
})
|
||||
}
|
||||
}
|
114
builder/helpers.go
Normal file
114
builder/helpers.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
// findFieldByPath resolves a dot-separated field path in a protobuf message and returns the protoreflect value and its descriptor.
|
||||
func findFieldByPath(msg protoreflect.Message, fieldPath string) (protoreflect.Value, protoreflect.FieldDescriptor, bool) {
|
||||
var (
|
||||
current = msg
|
||||
parts = strings.Split(fieldPath, ".")
|
||||
partsCount = len(parts) - 1
|
||||
)
|
||||
|
||||
for i, part := range parts {
|
||||
fd, ok := findFieldByName(current, part)
|
||||
if !ok {
|
||||
return protoreflect.Value{}, nil, false
|
||||
}
|
||||
|
||||
val := current.Get(fd)
|
||||
if i == partsCount { // it's last part
|
||||
return val, fd, true
|
||||
}
|
||||
|
||||
if fd.Kind() != protoreflect.MessageKind {
|
||||
return protoreflect.Value{}, nil, false
|
||||
}
|
||||
current = val.Message()
|
||||
}
|
||||
|
||||
return protoreflect.Value{}, nil, false
|
||||
}
|
||||
|
||||
// findFieldByName find a field name in a protobuf message and returns the protoreflect field descriptor.
|
||||
func findFieldByName(msg protoreflect.Message, fieldName string) (protoreflect.FieldDescriptor, bool) {
|
||||
fields := msg.Descriptor().Fields()
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
fd := fields.Get(i)
|
||||
if fd.JSONName() == fieldName {
|
||||
return fd, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// isZeroValue checks if protoreflect.Value is zero for the field.
|
||||
func isZeroValue(val protoreflect.Value, fd protoreflect.FieldDescriptor) bool {
|
||||
if fd.IsList() {
|
||||
return val.List().Len() == 0
|
||||
}
|
||||
if fd.IsMap() {
|
||||
return val.Map().Len() == 0
|
||||
}
|
||||
|
||||
switch fd.Kind() {
|
||||
case protoreflect.StringKind:
|
||||
return val.String() == ""
|
||||
case protoreflect.BytesKind:
|
||||
return len(val.Bytes()) == 0
|
||||
case protoreflect.BoolKind:
|
||||
return !val.Bool()
|
||||
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
|
||||
protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
|
||||
return val.Int() == 0
|
||||
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
|
||||
protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
|
||||
return val.Uint() == 0
|
||||
case protoreflect.FloatKind, protoreflect.DoubleKind:
|
||||
return val.Float() == 0
|
||||
case protoreflect.EnumKind:
|
||||
return val.Enum() == 0
|
||||
case protoreflect.MessageKind:
|
||||
return !val.Message().IsValid()
|
||||
default:
|
||||
return !val.IsValid()
|
||||
}
|
||||
}
|
||||
|
||||
// stringifyValue converts protoreflect.Value to string for path/query substitution.
|
||||
func stringifyValue(val protoreflect.Value, fd protoreflect.FieldDescriptor) (string, error) {
|
||||
switch fd.Kind() {
|
||||
case protoreflect.StringKind:
|
||||
return val.String(), nil
|
||||
case protoreflect.BytesKind:
|
||||
return base64.StdEncoding.EncodeToString(val.Bytes()), nil
|
||||
case protoreflect.BoolKind:
|
||||
if val.Bool() {
|
||||
return "true", nil
|
||||
}
|
||||
return "false", nil
|
||||
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
|
||||
protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
|
||||
return fmt.Sprintf("%d", val.Int()), nil
|
||||
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
|
||||
protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
|
||||
return fmt.Sprintf("%d", val.Uint()), nil
|
||||
case protoreflect.FloatKind, protoreflect.DoubleKind:
|
||||
return strconv.FormatFloat(val.Float(), 'g', -1, 64), nil
|
||||
case protoreflect.EnumKind:
|
||||
ed := fd.Enum().Values().ByNumber(val.Enum())
|
||||
if ed != nil {
|
||||
return string(ed.Name()), nil
|
||||
}
|
||||
return fmt.Sprintf("%d", val.Enum()), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported field kind: %s", fd.Kind())
|
||||
}
|
||||
}
|
313
builder/path_template.go
Normal file
313
builder/path_template.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
// -------------------------- Path template representation -------------------------
|
||||
|
||||
// pathSegment is a helper interface for elements of a path.
|
||||
type pathSegment interface {
|
||||
isSegment() bool
|
||||
}
|
||||
|
||||
// pathTemplate represents a parsed URL path template.
|
||||
type pathTemplate struct {
|
||||
// literalPrefix is the fixed part of the path before the first {var},
|
||||
// e.g. "/v1/users/" for "/v1/users/{user_id}/orders:get".
|
||||
// It is removed from segments, so segments contain only the remaining path literals and variables.
|
||||
literalPrefix string
|
||||
|
||||
// segments is a sequence of pathLiteral or pathVar representing the rest of the path after literalPrefix.
|
||||
segments []pathSegment
|
||||
|
||||
// customVerb is an optional ":verb" suffix, e.g. ":get".
|
||||
customVerb string
|
||||
}
|
||||
|
||||
// pathLiteral represents a fixed literal segment in a path template, e.g., "/v1/users/".
|
||||
type pathLiteral struct {
|
||||
text string
|
||||
}
|
||||
|
||||
func (p pathLiteral) isSegment() bool { return true }
|
||||
|
||||
// pathVar represents a variable segment in a path template, e.g., "{user.id}".
|
||||
type pathVar struct {
|
||||
// fieldPath is the dotted path to the field in the struct, e.g., "user.id".
|
||||
fieldPath string
|
||||
|
||||
// pattern is the optional pattern after '=', e.g., "*" or "**/orders".
|
||||
// It specifies how the variable can match parts of the URL path.
|
||||
pattern string
|
||||
|
||||
// multiSegment is true if the pattern can match multiple path segments
|
||||
// (contains '/' or "**").
|
||||
multiSegment bool
|
||||
}
|
||||
|
||||
func (p pathVar) isSegment() bool { return true }
|
||||
|
||||
// ----------------------------- Path template parsing -----------------------------
|
||||
|
||||
// parsePathTemplate parses a URL path template into a pathTemplate.
|
||||
// It extracts:
|
||||
// 1. literalPrefix — fixed part before the first variable,
|
||||
// 2. segments — sequence of pathLiteral and pathVar,
|
||||
// 3. customVerb — optional ":verb" suffix.
|
||||
//
|
||||
// Complexity: time O(n), memory O(n).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// input: "/v1/users/{user_id}/orders:get"
|
||||
// output: pathTemplate{
|
||||
// literalPrefix: "/v1/users/",
|
||||
// segments: [{user_id}, "/orders"],
|
||||
// customVerb: ":get",
|
||||
// }
|
||||
func parsePathTemplate(input string) (*pathTemplate, error) {
|
||||
// Step 1: extract custom verb after the last colon, e.g. ":get"
|
||||
var customVerb string
|
||||
if i := strings.LastIndex(input, ":"); i >= 0 && i > strings.LastIndex(input, "/") {
|
||||
customVerb = input[i:]
|
||||
input = input[:i]
|
||||
}
|
||||
|
||||
var (
|
||||
segments []pathSegment
|
||||
buf strings.Builder
|
||||
)
|
||||
|
||||
// Step 2: iterate over the input and split into segments
|
||||
for i := 0; i < len(input); {
|
||||
if input[i] != '{' {
|
||||
buf.WriteByte(input[i])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Add literal before '{' if any
|
||||
if buf.Len() > 0 {
|
||||
segments = append(segments, pathLiteral{text: buf.String()})
|
||||
buf.Reset()
|
||||
}
|
||||
|
||||
// Find closing '}'
|
||||
start := i + 1
|
||||
offset := strings.IndexByte(input[start:], '}') // relative offset from start
|
||||
if offset < 0 {
|
||||
return nil, fmt.Errorf("unclosed '{' in path: %s", input)
|
||||
}
|
||||
end := start + offset
|
||||
|
||||
token := input[start:end]
|
||||
i = end + 1 // jump past '}'
|
||||
|
||||
// Split field path and optional pattern
|
||||
var fieldPath, pattern string
|
||||
if k := strings.IndexByte(token, '='); k >= 0 {
|
||||
fieldPath = strings.TrimSpace(token[:k])
|
||||
pattern = strings.TrimSpace(token[k+1:])
|
||||
} else {
|
||||
fieldPath = strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
if fieldPath == "" {
|
||||
return nil, fmt.Errorf("empty variable in path: %s", input)
|
||||
}
|
||||
|
||||
pv := pathVar{
|
||||
fieldPath: fieldPath,
|
||||
pattern: pattern,
|
||||
multiSegment: isMultiSegmentPattern(pattern),
|
||||
}
|
||||
segments = append(segments, pv)
|
||||
}
|
||||
|
||||
// Step 3: add any trailing literal after last '}'
|
||||
if buf.Len() > 0 {
|
||||
segments = append(segments, pathLiteral{text: buf.String()})
|
||||
}
|
||||
|
||||
// Step 4: extract literalPrefix if the first segment is a literal
|
||||
var literalPrefix string
|
||||
if len(segments) > 0 {
|
||||
if pl, ok := segments[0].(pathLiteral); ok {
|
||||
literalPrefix = pl.text
|
||||
segments = segments[1:] // remove from segments to avoid duplication
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: return fully parsed pathTemplate
|
||||
return &pathTemplate{
|
||||
literalPrefix: literalPrefix,
|
||||
segments: segments,
|
||||
customVerb: customVerb,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isMultiSegmentPattern returns true if pattern can match multiple path segments (contains '/' or '**').
|
||||
// Examples:
|
||||
// | Pattern | Result | Usecase |
|
||||
// |----------------|--------|------------------------------------------|
|
||||
// | "" | false | {var} => single segment |
|
||||
// | "*" | false | {var=*} => single segment |
|
||||
// | "**" | true | {var=**} => multiple segments |
|
||||
// | "foo/*" | true | {var=foo/*} => multiple segments |
|
||||
// | "foo/**" | true | {var=foo/**} => multiple segments |
|
||||
// | "users/*/orders"| true | {users/*/orders} => multiple segments |
|
||||
func isMultiSegmentPattern(pattern string) bool {
|
||||
if pattern == "" {
|
||||
return false
|
||||
}
|
||||
if pattern == singleWildcard {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(pattern, "/") || strings.Contains(pattern, doubleWildcard)
|
||||
}
|
||||
|
||||
// ----------------------------- Path template resolving -----------------------------
|
||||
|
||||
// resolvePathPlaceholders expands placeholders in a path template using values from proto.Message.
|
||||
// Placeholders must be bound to non-repeated scalar fields (not lists, maps, or messages).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// tmpl: "/v1/users/{user_id}/orders:get"
|
||||
// msg: &pb.Message{UserId: 12345}
|
||||
//
|
||||
// path: "/v1/users/12345/orders:get"
|
||||
// usedFields: {"user_id"}
|
||||
func resolvePathPlaceholders(tmpl *pathTemplate, msg proto.Message) (path string, usedFields *usedFields, err error) {
|
||||
usedFields = newUsedFields()
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(tmpl.literalPrefix)
|
||||
|
||||
msgReflect := msg.ProtoReflect()
|
||||
|
||||
for _, segment := range tmpl.segments {
|
||||
switch s := segment.(type) {
|
||||
case pathLiteral:
|
||||
sb.WriteString(s.text)
|
||||
|
||||
case pathVar:
|
||||
val, fd, ok := findFieldByPath(msgReflect, s.fieldPath)
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("path placeholder %s not found", s.fieldPath)
|
||||
}
|
||||
if isZeroValue(val, fd) {
|
||||
// it's the only case that allows zero-value matches.
|
||||
if s.pattern == doubleWildcard {
|
||||
usedFields.add(s.fieldPath)
|
||||
continue
|
||||
}
|
||||
return "", nil, fmt.Errorf("path placeholder %s has zero value", s.fieldPath)
|
||||
}
|
||||
|
||||
// must be scalar (non-repeated, non-map, non-message)
|
||||
if fd.IsList() || fd.IsMap() || fd.Kind() == protoreflect.MessageKind {
|
||||
return "", nil, fmt.Errorf("path placeholder %s must be scalar", s.fieldPath)
|
||||
}
|
||||
|
||||
usedFields.add(s.fieldPath)
|
||||
|
||||
var strVal string
|
||||
strVal, err = stringifyValue(val, fd)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("stringify placeholder %s: %w", s.fieldPath, err)
|
||||
}
|
||||
|
||||
if err = validatePattern(s.pattern, strVal); err != nil {
|
||||
return "", nil, fmt.Errorf("validate pattern, %s:%s: %w", s.fieldPath, strVal, err)
|
||||
}
|
||||
|
||||
parts := strings.Split(strVal, "/")
|
||||
for i := range parts {
|
||||
parts[i] = url.PathEscape(parts[i])
|
||||
}
|
||||
sb.WriteString(strings.Join(parts, "/"))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(tmpl.customVerb)
|
||||
return sb.String(), usedFields, nil
|
||||
}
|
||||
|
||||
// validatePattern checks whether input matches the given path pattern.
|
||||
//
|
||||
// Rules:
|
||||
// - "" or "*" => exactly one segment, no "/" allowed
|
||||
// - "**" => zero or more segments (may include "/")
|
||||
// - composite patterns like "*/orders/*" must match literally
|
||||
//
|
||||
// Example for composite pattern case:
|
||||
//
|
||||
// pattern: "*/orders/*"
|
||||
// input: "42/orders/123"
|
||||
//
|
||||
// patternSegments = ["*", "orders", "*"]
|
||||
// valueParts = ["42", "orders", "123"]
|
||||
//
|
||||
// Match:
|
||||
// "*" -> "42"
|
||||
// "orders" -> "orders"
|
||||
// "*" -> "123"
|
||||
func validatePattern(pattern, input string) error {
|
||||
var (
|
||||
parts = strings.Split(input, "/")
|
||||
lenParts = len(parts)
|
||||
)
|
||||
|
||||
if pattern == "" || pattern == singleWildcard {
|
||||
if lenParts != 1 {
|
||||
return errors.New("must be a single path segment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if pattern == doubleWildcard {
|
||||
if lenParts < 1 {
|
||||
return errors.New("must contain at least one segment")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
patternSegments = strings.Split(pattern, "/")
|
||||
patternIndex int
|
||||
)
|
||||
|
||||
for i := 0; i < len(patternSegments); i++ {
|
||||
switch patternSegments[i] {
|
||||
case singleWildcard:
|
||||
if patternIndex >= lenParts || parts[patternIndex] == "" {
|
||||
return fmt.Errorf("segment %d must not be empty", patternIndex)
|
||||
}
|
||||
patternIndex++
|
||||
case doubleWildcard:
|
||||
if patternIndex >= lenParts {
|
||||
return fmt.Errorf("must contain at least one segment at position %d", patternIndex)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
if patternIndex >= lenParts || parts[patternIndex] != patternSegments[i] {
|
||||
return fmt.Errorf("expected literal %s at position %d", patternSegments[i], patternIndex)
|
||||
}
|
||||
patternIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if patternIndex != lenParts {
|
||||
return errors.New("extra segments in value")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
21
builder/path_template_cache.go
Normal file
21
builder/path_template_cache.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package builder
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
pathTemplateCache = make(map[string]*pathTemplate)
|
||||
pathTemplateCacheMu sync.RWMutex
|
||||
)
|
||||
|
||||
func getCachedPathTemplate(path string) (*pathTemplate, bool) {
|
||||
pathTemplateCacheMu.RLock()
|
||||
defer pathTemplateCacheMu.RUnlock()
|
||||
tmpl, ok := pathTemplateCache[path]
|
||||
return tmpl, ok
|
||||
}
|
||||
|
||||
func setPathTemplateCache(path string, tmpl *pathTemplate) {
|
||||
pathTemplateCacheMu.Lock()
|
||||
defer pathTemplateCacheMu.Unlock()
|
||||
pathTemplateCache[path] = tmpl
|
||||
}
|
15
builder/proto/errors.pb.go
Normal file
15
builder/proto/errors.pb.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package proto
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var marshaler = protojson.MarshalOptions{}
|
||||
|
||||
func (m *Test_Client_Call_DefaultError) Error() string {
|
||||
buf, _ := marshaler.Marshal(m)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (m *Test_Client_Call_SpecialError) Error() string {
|
||||
buf, _ := marshaler.Marshal(m)
|
||||
return string(buf)
|
||||
}
|
3906
builder/proto/test_messages.pb.go
Normal file
3906
builder/proto/test_messages.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
330
builder/proto/test_messages.proto
Normal file
330
builder/proto/test_messages.proto
Normal file
@@ -0,0 +1,330 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package proto;
|
||||
|
||||
option go_package = "go.unistack.org/micro-client-http/v4/proto;proto";
|
||||
|
||||
message TestRequestBuilder {}
|
||||
|
||||
message Test_PathOnly {
|
||||
message PrimitiveCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
}
|
||||
|
||||
message NestedCase {
|
||||
User user = 1;
|
||||
Order order = 2;
|
||||
|
||||
message User {
|
||||
string id = 1;
|
||||
}
|
||||
message Order {
|
||||
int64 id = 1;
|
||||
Product product = 2;
|
||||
|
||||
message Product {
|
||||
int64 id = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message MultipleCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
Order order = 2;
|
||||
|
||||
message Order {
|
||||
string id = 1;
|
||||
}
|
||||
}
|
||||
|
||||
message RepeatedCase {
|
||||
repeated string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
}
|
||||
|
||||
message NonPrimitiveMessageCase {
|
||||
User userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
|
||||
message User {
|
||||
string id = 1;
|
||||
}
|
||||
}
|
||||
|
||||
message NonPrimitiveMapCase {
|
||||
map<string, string> userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
}
|
||||
|
||||
message PatternCase {
|
||||
string pattern = 1;
|
||||
}
|
||||
|
||||
message CompositePatternCase {
|
||||
string pattern = 1;
|
||||
string orderId = 2 [json_name = "order_id"];
|
||||
string productId = 3 [json_name = "product_id"];
|
||||
}
|
||||
}
|
||||
|
||||
message Test_QueryOnly {
|
||||
message PrimitiveCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
bool flag = 3;
|
||||
}
|
||||
|
||||
message RepeatedCase {
|
||||
repeated string strings = 1;
|
||||
repeated int64 integers = 2;
|
||||
}
|
||||
|
||||
message NestedMessageCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
Filter filter = 2;
|
||||
|
||||
message Filter {
|
||||
int64 age = 1;
|
||||
string name = 2;
|
||||
SubFilter subFilter = 3 [json_name = "sub_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message NestedMapCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
map<string, string> firstFilter = 2 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 4 [json_name = "second_filter"];
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
|
||||
message MultipleCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
repeated string strings = 2;
|
||||
Filter firstFilter = 3 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 4 [json_name = "second_filter"];
|
||||
|
||||
message Filter {
|
||||
int64 age = 1;
|
||||
SubFilter subFilter = 2 [json_name = "sub_filter"];
|
||||
}
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
}
|
||||
}
|
||||
|
||||
message RepeatedMessageCase {
|
||||
repeated Filter filters = 1;
|
||||
message Filter {
|
||||
int64 age = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message Test_BodyOnly {
|
||||
message PrimitiveCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
bool flag = 3;
|
||||
repeated string strings = 4;
|
||||
Product product = 6;
|
||||
|
||||
message Product {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message NestedCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
Filter first_filter = 2 [json_name = "first_filter"];
|
||||
Filter second_filter = 3 [json_name = "second_filter"];
|
||||
|
||||
message Filter {
|
||||
int64 age = 1;
|
||||
string name = 2;
|
||||
SubFilter subFilter = 3 [json_name = "sub_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message RepeatedMessageCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
repeated Product products = 2 [json_name = "products"];
|
||||
|
||||
message Product {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message MapCase {
|
||||
map<string, string> firstFilter = 1 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 2 [json_name = "second_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
|
||||
message MultipleCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
repeated SubFilter firstFilter = 2 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 3 [json_name = "second_filter"];
|
||||
SubFilter thirdFilter = 4 [json_name = "third_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message Test_Mixed {
|
||||
message PrimitiveCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
Product product = 3;
|
||||
|
||||
message Product {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message NestedCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
Filter first_filter = 2 [json_name = "first_filter"];
|
||||
Filter second_filter = 3 [json_name = "second_filter"];
|
||||
|
||||
message Filter {
|
||||
int64 age = 1;
|
||||
string name = 2;
|
||||
SubFilter subFilter = 3 [json_name = "sub_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message RepeatedMessageCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
repeated Product products = 2 [json_name = "products"];
|
||||
|
||||
message Product {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message MapCase {
|
||||
map<string, string> firstFilter = 1 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 2 [json_name = "second_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
|
||||
message MultipleCase {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
repeated SubFilter firstFilter = 2 [json_name = "first_filter"];
|
||||
map<string, SubFilter> secondFilter = 3 [json_name = "second_filter"];
|
||||
SubFilter thirdFilter = 4 [json_name = "third_filter"];
|
||||
|
||||
message SubFilter {
|
||||
int64 subAge = 1 [json_name = "sub_age"];
|
||||
string subName = 2 [json_name = "sub_name"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message Benchmark {
|
||||
message Case5 {
|
||||
string field1 = 1;
|
||||
string field2 = 2;
|
||||
string field3 = 3;
|
||||
string field4 = 4;
|
||||
string field5 = 5;
|
||||
}
|
||||
message Case10 {
|
||||
string field1 = 1;
|
||||
string field2 = 2;
|
||||
string field3 = 3;
|
||||
string field4 = 4;
|
||||
string field5 = 5;
|
||||
string field6 = 6;
|
||||
string field7 = 7;
|
||||
string field8 = 8;
|
||||
string field9 = 9;
|
||||
string field10 = 10;
|
||||
}
|
||||
message Case30 {
|
||||
string field1 = 1;
|
||||
string field2 = 2;
|
||||
string field3 = 3;
|
||||
string field4 = 4;
|
||||
string field5 = 5;
|
||||
string field6 = 6;
|
||||
string field7 = 7;
|
||||
string field8 = 8;
|
||||
string field9 = 9;
|
||||
string field10 = 10;
|
||||
string field11 = 11;
|
||||
string field12 = 12;
|
||||
string field13 = 13;
|
||||
string field14 = 14;
|
||||
string field15 = 15;
|
||||
string field16 = 16;
|
||||
string field17 = 17;
|
||||
string field18 = 18;
|
||||
string field19 = 19;
|
||||
string field20 = 20;
|
||||
string field21 = 21;
|
||||
string field22 = 22;
|
||||
string field23 = 23;
|
||||
string field24 = 24;
|
||||
string field25 = 25;
|
||||
string field26 = 26;
|
||||
string field27 = 27;
|
||||
string field28 = 28;
|
||||
string field29 = 29;
|
||||
string field30 = 30;
|
||||
}
|
||||
}
|
||||
|
||||
message Test_Client_Call {
|
||||
message Request {
|
||||
string userId = 1 [json_name = "user_id"];
|
||||
int64 orderId = 2 [json_name = "order_id"];
|
||||
}
|
||||
message Response {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
}
|
||||
message DefaultError {
|
||||
string code = 1;
|
||||
string msg = 2;
|
||||
}
|
||||
message SpecialError {
|
||||
string code = 1;
|
||||
string msg = 2;
|
||||
string warning = 3;
|
||||
}
|
||||
}
|
194
builder/query.go
Normal file
194
builder/query.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
func buildQuery(msg proto.Message, usedFieldsPath *usedFields, usedFieldBody string) (url.Values, error) {
|
||||
var (
|
||||
query = url.Values{}
|
||||
msgReflect = msg.ProtoReflect()
|
||||
)
|
||||
|
||||
fields := msgReflect.Descriptor().Fields()
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
var (
|
||||
fd = fields.Get(i)
|
||||
fieldName = fd.JSONName()
|
||||
)
|
||||
|
||||
if usedFieldsPath.hasFullKey(fieldName) {
|
||||
continue
|
||||
}
|
||||
if fieldName == usedFieldBody {
|
||||
continue
|
||||
}
|
||||
|
||||
val := msgReflect.Get(fd)
|
||||
if isZeroValue(val, fd) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Note: order of the cases is important!
|
||||
switch {
|
||||
case fd.IsList():
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
return nil, fmt.Errorf("repeated message field %s cannot be mapped to URL query parameters", fieldName)
|
||||
}
|
||||
list := val.List()
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
strVal, err := stringifyValue(list.Get(j), fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stringify value for query %s: %w", fieldName, err)
|
||||
}
|
||||
query.Add(fieldName, strVal)
|
||||
}
|
||||
|
||||
case fd.IsMap():
|
||||
var (
|
||||
m = val.Map()
|
||||
rangeErr error
|
||||
)
|
||||
m.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||
key := fmt.Sprintf("%s.%s", fieldName, k.String())
|
||||
|
||||
if fd.MapValue().Kind() == protoreflect.MessageKind {
|
||||
flattened, err := flattenMsgForQuery(key, v.Message())
|
||||
if err != nil {
|
||||
rangeErr = fmt.Errorf("flatten msg for query %s: %w", fieldName, err)
|
||||
return false
|
||||
}
|
||||
for _, item := range flattened {
|
||||
if item.val == "" {
|
||||
continue
|
||||
}
|
||||
if usedFieldsPath.hasFullKey(item.key) {
|
||||
continue
|
||||
}
|
||||
query.Add(item.key, item.val)
|
||||
}
|
||||
} else {
|
||||
strVal, err := stringifyValue(v, fd.MapValue())
|
||||
if err != nil {
|
||||
rangeErr = fmt.Errorf("stringify value for map %s: %w", fieldName, err)
|
||||
return false
|
||||
}
|
||||
query.Add(key, strVal)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if rangeErr != nil {
|
||||
return nil, fmt.Errorf("map range error: %w", rangeErr)
|
||||
}
|
||||
|
||||
case fd.Kind() == protoreflect.MessageKind:
|
||||
flattened, err := flattenMsgForQuery(fieldName, val.Message())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("flatten msg for query %s: %w", fieldName, err)
|
||||
}
|
||||
for _, item := range flattened {
|
||||
if item.val == "" {
|
||||
continue
|
||||
}
|
||||
if usedFieldsPath.hasFullKey(item.key) {
|
||||
continue
|
||||
}
|
||||
query.Add(item.key, item.val)
|
||||
}
|
||||
|
||||
default:
|
||||
strVal, err := stringifyValue(val, fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stringify value for primitive %s: %w", fieldName, err)
|
||||
}
|
||||
query.Add(fieldName, strVal)
|
||||
}
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
type flattenItem struct {
|
||||
key string
|
||||
val string
|
||||
}
|
||||
|
||||
// flattenMsgForQuery flattens a non-repeated message value under a given prefix.
|
||||
func flattenMsgForQuery(prefix string, msg protoreflect.Message) ([]flattenItem, error) {
|
||||
var out []flattenItem
|
||||
|
||||
fields := msg.Descriptor().Fields()
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
var (
|
||||
fd = fields.Get(i)
|
||||
val = msg.Get(fd)
|
||||
)
|
||||
|
||||
if isZeroValue(val, fd) {
|
||||
continue
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s.%s", prefix, fd.JSONName())
|
||||
|
||||
switch {
|
||||
case fd.IsList():
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
return nil, fmt.Errorf("repeated message field %s cannot be flattened for query", key)
|
||||
}
|
||||
list := val.List()
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
strVal, err := stringifyValue(list.Get(j), fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stringify query %s: %w", key, err)
|
||||
}
|
||||
out = append(out, flattenItem{key: key, val: strVal})
|
||||
}
|
||||
|
||||
case fd.Kind() == protoreflect.MessageKind:
|
||||
nested, err := flattenMsgForQuery(key, val.Message())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("flatten msg for query %s: %w", key, err)
|
||||
}
|
||||
out = append(out, nested...)
|
||||
|
||||
case fd.IsMap():
|
||||
var mapErr error
|
||||
val.Map().Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||
keyStr := k.String()
|
||||
|
||||
if fd.MapValue().Kind() == protoreflect.MessageKind {
|
||||
child, err := flattenMsgForQuery(keyStr, v.Message())
|
||||
if err != nil {
|
||||
mapErr = fmt.Errorf("flatten map value %s: %w", key, err)
|
||||
return false
|
||||
}
|
||||
out = append(out, child...)
|
||||
} else {
|
||||
strVal, err := stringifyValue(v, fd.MapValue())
|
||||
if err != nil {
|
||||
mapErr = fmt.Errorf("stringify query %s: %w", keyStr, err)
|
||||
return false
|
||||
}
|
||||
out = append(out, flattenItem{key: keyStr, val: strVal})
|
||||
}
|
||||
return true
|
||||
})
|
||||
if mapErr != nil {
|
||||
return nil, mapErr
|
||||
}
|
||||
|
||||
default:
|
||||
strVal, err := stringifyValue(val, fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stringify query %s: %w", key, err)
|
||||
}
|
||||
out = append(out, flattenItem{key: key, val: strVal})
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
180
builder/request_builder.go
Normal file
180
builder/request_builder.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// Package builder implements google.api.http-style request building (gRPC JSON transcoding)
|
||||
// for HTTP requests, closely following the google.api.http spec.
|
||||
// See full spec for details: https://github.com/googleapis/googleapis/blob/master/google/api/http.proto
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
type RequestBuilder struct {
|
||||
path string // e.g. "/v1/{name=projects/*/topics/*}:publish" or "/users/{user.id}"
|
||||
method string // GET, POST, PATCH, etc. (not used in mapping rules, but convenient for callers)
|
||||
bodyOption bodyOption // "", "*", or top-level field name
|
||||
msg proto.Message // request struct
|
||||
}
|
||||
|
||||
func NewRequestBuilder(
|
||||
path string,
|
||||
method string,
|
||||
bodyOpt string,
|
||||
msg proto.Message,
|
||||
) (
|
||||
*RequestBuilder,
|
||||
error,
|
||||
) {
|
||||
rb := &RequestBuilder{
|
||||
path: path,
|
||||
method: method,
|
||||
bodyOption: bodyOption(bodyOpt),
|
||||
msg: msg,
|
||||
}
|
||||
|
||||
if err := rb.validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
|
||||
return rb, nil
|
||||
}
|
||||
|
||||
// Build applies mapping rules and returns:
|
||||
//
|
||||
// resolvedPath — path with placeholders substituted and query appended
|
||||
// newMsg — same concrete type as input, filtered to contain only the body fields
|
||||
// err — if mapping/validation failed
|
||||
func (b *RequestBuilder) Build() (resolvedPath string, newMsg proto.Message, err error) {
|
||||
tmpl, isCached := getCachedPathTemplate(b.path)
|
||||
if !isCached {
|
||||
tmpl, err = parsePathTemplate(b.path)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("parse path template: %w", err)
|
||||
}
|
||||
setPathTemplateCache(b.path, tmpl)
|
||||
}
|
||||
|
||||
var usedFieldsPath *usedFields
|
||||
resolvedPath, usedFieldsPath, err = resolvePathPlaceholders(tmpl, b.msg)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("resolve path placeholders: %w", err)
|
||||
}
|
||||
|
||||
// if all set fields are already used in path, no need to process query/body
|
||||
if allFieldsUsed(b.msg, usedFieldsPath) {
|
||||
return resolvedPath, initZeroMsg(b.msg), nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case b.bodyOption.isWithoutBody():
|
||||
var query url.Values
|
||||
query, err = buildQuery(b.msg, usedFieldsPath, "")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("build query: %w", err)
|
||||
}
|
||||
|
||||
return resolvedPath + encodeQuery(query), initZeroMsg(b.msg), nil
|
||||
|
||||
case b.bodyOption.isSingleField():
|
||||
fieldBody := b.bodyOption.String()
|
||||
|
||||
newMsg, err = buildSingleFieldBody(b.msg, fieldBody)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("build single field body: %w", err)
|
||||
}
|
||||
|
||||
var query url.Values
|
||||
query, err = buildQuery(b.msg, usedFieldsPath, fieldBody)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("build query: %w", err)
|
||||
}
|
||||
|
||||
return resolvedPath + encodeQuery(query), newMsg, nil
|
||||
|
||||
case b.bodyOption.isFullBody():
|
||||
newMsg, err = buildFullBody(b.msg, usedFieldsPath)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("build full body: %w", err)
|
||||
}
|
||||
|
||||
return resolvedPath, newMsg, nil
|
||||
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported body option %s", b.bodyOption.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RequestBuilder) validate() error {
|
||||
if b.path == "" {
|
||||
return errors.New("path is empty")
|
||||
}
|
||||
if err := validateHTTPMethod(b.method); err != nil {
|
||||
return fmt.Errorf("validate http method: %w", err)
|
||||
}
|
||||
if err := validateHTTPMethodAndBody(b.method, b.bodyOption); err != nil {
|
||||
return fmt.Errorf("validate http method and body: %w", err)
|
||||
}
|
||||
if b.msg == nil {
|
||||
return errors.New("msg is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHTTPMethod(method string) error {
|
||||
switch strings.ToUpper(method) {
|
||||
case http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodPatch,
|
||||
http.MethodDelete,
|
||||
http.MethodConnect,
|
||||
http.MethodOptions,
|
||||
http.MethodTrace:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid http method")
|
||||
}
|
||||
}
|
||||
|
||||
func validateHTTPMethodAndBody(method string, bodyOpt bodyOption) error {
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodDelete, http.MethodHead, http.MethodOptions:
|
||||
if !bodyOpt.isWithoutBody() {
|
||||
return fmt.Errorf("%s method must not have a body", method)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func allFieldsUsed(msg proto.Message, used *usedFields) bool {
|
||||
if used.len() == 0 {
|
||||
return false
|
||||
}
|
||||
count := 0
|
||||
msg.ProtoReflect().Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return used.len() == count
|
||||
}
|
||||
|
||||
func encodeQuery(query url.Values) string {
|
||||
if len(query) == 0 {
|
||||
return ""
|
||||
}
|
||||
enc := query.Encode()
|
||||
if enc == "" {
|
||||
return ""
|
||||
}
|
||||
return "?" + enc
|
||||
}
|
||||
|
||||
func initZeroMsg(msg proto.Message) proto.Message {
|
||||
return msg.ProtoReflect().New().Interface()
|
||||
}
|
149
builder/request_builder_bench_test.go
Normal file
149
builder/request_builder_bench_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
pb "go.unistack.org/micro-client-http/v4/builder/proto"
|
||||
)
|
||||
|
||||
// sink prevents the compiler from optimizing away parsePathTemplate results.
|
||||
var sink *pathTemplate
|
||||
|
||||
func BenchmarkParsePathTemplate(b *testing.B) {
|
||||
r := rand.New(rand.NewSource(1))
|
||||
|
||||
benchInput := func(size int) string {
|
||||
sb := strings.Builder{}
|
||||
sb.Grow(size * 10)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
name := fmt.Sprintf("var%d", r.Intn(1000))
|
||||
|
||||
if r.Intn(5) == 0 {
|
||||
sb.WriteString(fmt.Sprintf("{%s=**}", name))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("{%s}", name))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
sizes := []int{1_000, 10_000, 50_000, 100_000}
|
||||
for _, size := range sizes {
|
||||
input := benchInput(size)
|
||||
b.Run(fmt.Sprintf("N=%d", size), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var err error
|
||||
sink, err = parsePathTemplate(input)
|
||||
if err != nil && testing.Verbose() {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRequestBuilder(b *testing.B) {
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
makeMsg := func(fieldCount int) proto.Message {
|
||||
switch fieldCount {
|
||||
case 5:
|
||||
return &pb.Benchmark_Case5{
|
||||
Field1: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field2: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field3: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field4: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field5: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
}
|
||||
case 10:
|
||||
return &pb.Benchmark_Case10{
|
||||
Field1: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field2: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field3: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field4: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field5: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field6: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field7: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field8: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field9: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field10: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
}
|
||||
case 30:
|
||||
return &pb.Benchmark_Case30{
|
||||
Field1: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field2: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field3: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field4: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field5: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field6: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field7: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field8: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field9: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field10: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field11: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field12: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field13: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field14: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field15: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field16: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field17: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field18: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field19: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field20: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field21: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field22: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field23: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field24: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field25: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field26: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field27: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field28: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field29: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
Field30: fmt.Sprintf("value%d", r.Intn(1000)),
|
||||
}
|
||||
default:
|
||||
b.Fatal("undefined field count")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pathTmpl string
|
||||
bodyOption string
|
||||
}{
|
||||
{"all fields in path", "/resource/{field1}/{field2}", ""},
|
||||
{"single field body", "/resource/{field1}", "field4"},
|
||||
{"full body", "/resource", "*"},
|
||||
}
|
||||
|
||||
for _, fields := range []int{5, 10, 30} {
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s_%d_fields", tt.name, fields), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg := makeMsg(fields)
|
||||
rb, err := NewRequestBuilder(tt.pathTmpl, "POST", tt.bodyOption, msg)
|
||||
if err != nil {
|
||||
b.Fatalf("new request builder: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = rb.Build()
|
||||
if err != nil {
|
||||
b.Fatalf("build: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
1302
builder/request_builder_test.go
Normal file
1302
builder/request_builder_test.go
Normal file
File diff suppressed because it is too large
Load Diff
44
builder/used_fields.go
Normal file
44
builder/used_fields.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package builder
|
||||
|
||||
import "strings"
|
||||
|
||||
// usedFields stores keys and their top-level parts,
|
||||
// turning top-level lookups from O(N) into O(1).
|
||||
type usedFields struct {
|
||||
full map[string]struct{}
|
||||
top map[string]struct{}
|
||||
}
|
||||
|
||||
func newUsedFields() *usedFields {
|
||||
return &usedFields{
|
||||
full: make(map[string]struct{}),
|
||||
top: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add inserts a new key and updates the top-level index.
|
||||
func (u *usedFields) add(key string) {
|
||||
u.full[key] = struct{}{}
|
||||
top := key
|
||||
if i := strings.IndexByte(key, '.'); i != -1 {
|
||||
top = key[:i]
|
||||
}
|
||||
u.top[top] = struct{}{}
|
||||
}
|
||||
|
||||
// hasTopLevelKey checks if a top-level key exists.
|
||||
func (u *usedFields) hasTopLevelKey(top string) bool {
|
||||
_, ok := u.top[top]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasFullKey checks if an exact key exists.
|
||||
func (u *usedFields) hasFullKey(key string) bool {
|
||||
_, ok := u.full[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// len returns the number of full keys stored in the set.
|
||||
func (u *usedFields) len() int {
|
||||
return len(u.full)
|
||||
}
|
78
builder/used_fields_test.go
Normal file
78
builder/used_fields_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewUsedFields(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
require.NotNil(t, u)
|
||||
require.NotNil(t, u.full)
|
||||
require.NotNil(t, u.top)
|
||||
require.Len(t, u.full, 0)
|
||||
require.Len(t, u.top, 0)
|
||||
}
|
||||
|
||||
func TestUsedFields_Add(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
|
||||
u.add("user.name")
|
||||
u.add("profile")
|
||||
|
||||
_, ok := u.full["user.name"]
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = u.full["profile"]
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = u.top["user"]
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = u.top["profile"]
|
||||
require.True(t, ok)
|
||||
|
||||
require.Len(t, u.full, 2)
|
||||
require.Len(t, u.top, 2)
|
||||
}
|
||||
|
||||
func TestUsedFields_HasFullKey(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
u.add("user.name")
|
||||
|
||||
require.True(t, u.hasFullKey("user.name"))
|
||||
require.False(t, u.hasFullKey("user.email"))
|
||||
}
|
||||
|
||||
func TestUsedFields_HasTopLevelKey(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
u.add("user.name")
|
||||
u.add("settings.theme")
|
||||
|
||||
require.True(t, u.hasTopLevelKey("user"))
|
||||
require.True(t, u.hasTopLevelKey("settings"))
|
||||
require.False(t, u.hasTopLevelKey("profile"))
|
||||
}
|
||||
|
||||
func TestUsedFields_AddDuplicate(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
u.add("user.name")
|
||||
u.add("user.name")
|
||||
|
||||
require.True(t, u.hasFullKey("user.name"))
|
||||
require.True(t, u.hasTopLevelKey("user"))
|
||||
require.Len(t, u.full, 1)
|
||||
require.Len(t, u.top, 1)
|
||||
}
|
||||
|
||||
func TestUsedFields_Len(t *testing.T) {
|
||||
u := newUsedFields()
|
||||
|
||||
u.add("user.name")
|
||||
u.add("profile")
|
||||
u.add("user.name")
|
||||
u.add("profile")
|
||||
|
||||
require.Equal(t, u.len(), 2)
|
||||
}
|
Reference in New Issue
Block a user