logger/unwrap: fix for nested tagged/untagged #183

Merged
vtolstov merged 1 commits from logger/unwrap into v3 2023-02-08 14:56:51 +03:00

View File

@ -46,6 +46,11 @@ var (
closeMapBytes = []byte("}") closeMapBytes = []byte("}")
) )
type protoMessage interface {
Reset()
ProtoMessage()
}
type Wrapper struct { type Wrapper struct {
val interface{} val interface{}
s fmt.State s fmt.State
@ -53,7 +58,7 @@ type Wrapper struct {
opts *Options opts *Options
depth int depth int
ignoreNextType bool ignoreNextType bool
takeAll map[int]bool takeMap map[int]bool
protoWrapperType bool protoWrapperType bool
sqlWrapperType bool sqlWrapperType bool
} }
@ -111,7 +116,7 @@ func Tagged(b bool) Option {
func Unwrap(val interface{}, opts ...Option) *Wrapper { func Unwrap(val interface{}, opts ...Option) *Wrapper {
options := NewOptions(opts...) options := NewOptions(opts...)
return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeAll: make(map[int]bool)} return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeMap: make(map[int]bool)}
} }
func (w *Wrapper) unpackValue(v reflect.Value) reflect.Value { func (w *Wrapper) unpackValue(v reflect.Value) reflect.Value {
@ -237,9 +242,6 @@ func (w *Wrapper) format(v reflect.Value) {
_, _ = w.s.Write(buf) _, _ = w.s.Write(buf)
return return
} }
if w.opts.Tagged {
w.checkTakeAll(v, 1)
}
// Handle invalid reflect values immediately. // Handle invalid reflect values immediately.
kind := v.Kind() kind := v.Kind()
@ -256,6 +258,10 @@ func (w *Wrapper) format(v reflect.Value) {
w.protoWrapperType = true w.protoWrapperType = true
} else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { } else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") {
w.sqlWrapperType = true w.sqlWrapperType = true
} else if v.CanInterface() {
if _, ok := v.Interface().(protoMessage); ok {
w.protoWrapperType = true
}
} }
} }
w.formatPtr(v) w.formatPtr(v)
@ -378,6 +384,12 @@ func (w *Wrapper) format(v reflect.Value) {
prevSkip := false prevSkip := false
for i := 0; i < numFields; i++ { for i := 0; i < numFields; i++ {
switch vt.Field(i).Type.PkgPath() {
case "google.golang.org/protobuf/internal/impl", "google.golang.org/protobuf/internal/pragma":
w.protoWrapperType = true
prevSkip = true
continue
}
if w.protoWrapperType && !vt.Field(i).IsExported() { if w.protoWrapperType && !vt.Field(i).IsExported() {
prevSkip = true prevSkip = true
continue continue
@ -385,6 +397,9 @@ func (w *Wrapper) format(v reflect.Value) {
prevSkip = true prevSkip = true
continue continue
} }
if _, ok := vt.Field(i).Tag.Lookup("protobuf"); ok && !w.protoWrapperType {
w.protoWrapperType = true
}
sv, ok := vt.Field(i).Tag.Lookup("logger") sv, ok := vt.Field(i).Tag.Lookup("logger")
switch { switch {
case ok: case ok:
@ -395,12 +410,17 @@ func (w *Wrapper) format(v reflect.Value) {
case "take": case "take":
break break
} }
case w.takeAll[w.depth]:
break
case !ok && w.opts.Tagged: case !ok && w.opts.Tagged:
// skip top level untagged
if w.depth == 1 {
prevSkip = true prevSkip = true
continue continue
} }
if tv, ok := w.takeMap[w.depth]; ok && !tv {
prevSkip = true
continue
}
}
if prevSkip { if prevSkip {
prevSkip = false prevSkip = false
@ -416,9 +436,7 @@ func (w *Wrapper) format(v reflect.Value) {
_, _ = w.s.Write([]byte(vt.Name)) _, _ = w.s.Write([]byte(vt.Name))
_, _ = w.s.Write(colonBytes) _, _ = w.s.Write(colonBytes)
} }
unpackValue := w.unpackValue(v.Field(i)) w.format(w.unpackValue(v.Field(i)))
w.checkTakeAll(unpackValue, w.depth)
w.format(unpackValue)
numWritten++ numWritten++
} }
w.depth-- w.depth--
@ -461,6 +479,10 @@ func (w *Wrapper) Format(s fmt.State, verb rune) {
return return
} }
if w.opts.Tagged {
w.buildTakeMap(reflect.ValueOf(w.val), 1)
}
w.format(reflect.ValueOf(w.val)) w.format(reflect.ValueOf(w.val))
} }
@ -615,24 +637,28 @@ func (w *Wrapper) constructOrigFormat(verb rune) string {
return buf.String() return buf.String()
} }
func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) { func (w *Wrapper) buildTakeMap(v reflect.Value, depth int) {
if _, ok := w.takeAll[depth]; ok {
return
}
if !v.IsValid() || v.IsZero() { if !v.IsValid() || v.IsZero() {
return return
} }
switch v.Kind() { switch v.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
w.buildTakeMap(v.Index(i), depth+1)
}
w.takeMap[depth] = true
return
case reflect.Struct: case reflect.Struct:
break break
case reflect.Ptr: case reflect.Ptr:
v = v.Elem() v = v.Elem()
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
w.takeAll[depth] = true w.takeMap[depth] = true
return return
} }
default: default:
w.takeAll[depth] = true w.takeMap[depth] = true
return return
} }
@ -641,8 +667,15 @@ func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) {
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
sv, ok := vt.Field(i).Tag.Lookup("logger") sv, ok := vt.Field(i).Tag.Lookup("logger")
if ok && sv == "take" { if ok && sv == "take" {
w.takeAll[depth] = false w.takeMap[depth] = false
} }
w.checkTakeAll(v.Field(i), depth+1) if v.Kind() == reflect.Struct ||
(v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct) {
w.buildTakeMap(v.Field(i), depth+1)
}
}
if _, ok := w.takeMap[depth]; !ok {
w.takeMap[depth] = true
} }
} }