logger/unwrap: check nested in case of Tagged

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2023-02-06 22:36:24 +03:00
parent 84024f7713
commit 9b387312da

View File

@ -46,14 +46,14 @@ var (
closeMapBytes = []byte("}") closeMapBytes = []byte("}")
) )
type unwrap struct { type Wrapper struct {
val interface{} val interface{}
s fmt.State s fmt.State
pointers map[uintptr]int pointers map[uintptr]int
opts *Options opts *Options
depth int depth int
ignoreNextType bool ignoreNextType bool
takeAll bool takeAll map[int]bool
protoWrapperType bool protoWrapperType bool
sqlWrapperType bool sqlWrapperType bool
} }
@ -109,14 +109,14 @@ func Tagged(b bool) Option {
} }
} }
func Unwrap(val interface{}, opts ...Option) *unwrap { func Unwrap(val interface{}, opts ...Option) *Wrapper {
options := NewOptions(opts...) options := NewOptions(opts...)
return &unwrap{val: val, opts: &options, pointers: make(map[uintptr]int)} return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeAll: make(map[int]bool)}
} }
func (f *unwrap) unpackValue(v reflect.Value) reflect.Value { func (w *Wrapper) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface { if v.Kind() == reflect.Interface {
f.ignoreNextType = false w.ignoreNextType = false
if !v.IsNil() { if !v.IsNil() {
v = v.Elem() v = v.Elem()
} }
@ -125,19 +125,19 @@ func (f *unwrap) unpackValue(v reflect.Value) reflect.Value {
} }
// formatPtr handles formatting of pointers by indirecting them as necessary. // formatPtr handles formatting of pointers by indirecting them as necessary.
func (f *unwrap) formatPtr(v reflect.Value) { func (w *Wrapper) formatPtr(v reflect.Value) {
// Display nil if top level pointer is nil. // Display nil if top level pointer is nil.
showTypes := f.s.Flag('#') showTypes := w.s.Flag('#')
if v.IsNil() && (!showTypes || f.ignoreNextType) { if v.IsNil() && (!showTypes || w.ignoreNextType) {
_, _ = f.s.Write(nilAngleBytes) _, _ = w.s.Write(nilAngleBytes)
return return
} }
// Remove pointers at or below the current depth from map used to detect // Remove pointers at or below the current depth from map used to detect
// circular refs. // circular refs.
for k, depth := range f.pointers { for k, depth := range w.pointers {
if depth >= f.depth { if depth >= w.depth {
delete(f.pointers, k) delete(w.pointers, k)
} }
} }
@ -159,12 +159,12 @@ func (f *unwrap) formatPtr(v reflect.Value) {
indirects++ indirects++
addr := ve.Pointer() addr := ve.Pointer()
pointerChain = append(pointerChain, addr) pointerChain = append(pointerChain, addr)
if pd, ok := f.pointers[addr]; ok && pd < f.depth { if pd, ok := w.pointers[addr]; ok && pd < w.depth {
cycleFound = true cycleFound = true
indirects-- indirects--
break break
} }
f.pointers[addr] = f.depth w.pointers[addr] = w.depth
ve = ve.Elem() ve = ve.Elem()
if ve.Kind() == reflect.Interface { if ve.Kind() == reflect.Interface {
@ -177,49 +177,49 @@ func (f *unwrap) formatPtr(v reflect.Value) {
} }
// Display type or indirection level depending on flags. // Display type or indirection level depending on flags.
if showTypes && !f.ignoreNextType { if showTypes && !w.ignoreNextType {
if f.depth > 0 { if w.depth > 0 {
_, _ = f.s.Write(openParenBytes) _, _ = w.s.Write(openParenBytes)
} }
if f.depth > 0 { if w.depth > 0 {
_, _ = f.s.Write(bytes.Repeat(asteriskBytes, indirects)) _, _ = w.s.Write(bytes.Repeat(asteriskBytes, indirects))
} else { } else {
_, _ = f.s.Write(bytes.Repeat(ampBytes, indirects)) _, _ = w.s.Write(bytes.Repeat(ampBytes, indirects))
} }
_, _ = f.s.Write([]byte(ve.Type().String())) _, _ = w.s.Write([]byte(ve.Type().String()))
if f.depth > 0 { if w.depth > 0 {
_, _ = f.s.Write(closeParenBytes) _, _ = w.s.Write(closeParenBytes)
} }
} else { } else {
if nilFound || cycleFound { if nilFound || cycleFound {
indirects += strings.Count(ve.Type().String(), "*") indirects += strings.Count(ve.Type().String(), "*")
} }
_, _ = f.s.Write(openAngleBytes) _, _ = w.s.Write(openAngleBytes)
_, _ = f.s.Write([]byte(strings.Repeat("*", indirects))) _, _ = w.s.Write([]byte(strings.Repeat("*", indirects)))
_, _ = f.s.Write(closeAngleBytes) _, _ = w.s.Write(closeAngleBytes)
} }
// Display pointer information depending on flags. // Display pointer information depending on flags.
if f.s.Flag('+') && (len(pointerChain) > 0) { if w.s.Flag('+') && (len(pointerChain) > 0) {
_, _ = f.s.Write(openParenBytes) _, _ = w.s.Write(openParenBytes)
for i, addr := range pointerChain { for i, addr := range pointerChain {
if i > 0 { if i > 0 {
_, _ = f.s.Write(pointerChainBytes) _, _ = w.s.Write(pointerChainBytes)
} }
getHexPtr(f.s, addr) getHexPtr(w.s, addr)
} }
_, _ = f.s.Write(closeParenBytes) _, _ = w.s.Write(closeParenBytes)
} }
// Display dereferenced value. // Display dereferenced value.
switch { switch {
case nilFound: case nilFound:
_, _ = f.s.Write(nilAngleBytes) _, _ = w.s.Write(nilAngleBytes)
case cycleFound: case cycleFound:
_, _ = f.s.Write(circularShortBytes) _, _ = w.s.Write(circularShortBytes)
default: default:
f.ignoreNextType = true w.ignoreNextType = true
f.format(ve) w.format(ve)
} }
} }
@ -227,20 +227,24 @@ func (f *unwrap) formatPtr(v reflect.Value) {
// uses the passed reflect value to figure out what kind of object we are // uses the passed reflect value to figure out what kind of object we are
// dealing with and formats it appropriately. It is a recursive function, // dealing with and formats it appropriately. It is a recursive function,
// however circular data structures are detected and handled properly. // however circular data structures are detected and handled properly.
func (f *unwrap) format(v reflect.Value) { func (w *Wrapper) format(v reflect.Value) {
if f.opts.Codec != nil { if w.opts.Codec != nil {
buf, err := f.opts.Codec.Marshal(v.Interface()) buf, err := w.opts.Codec.Marshal(v.Interface())
if err != nil { if err != nil {
_, _ = f.s.Write(invalidAngleBytes) _, _ = w.s.Write(invalidAngleBytes)
return return
} }
_, _ = f.s.Write(buf) _, _ = w.s.Write(buf)
return return
} }
if w.opts.Tagged {
w.checkTakeAll(v, 1)
}
// Handle invalid reflect values immediately. // Handle invalid reflect values immediately.
kind := v.Kind() kind := v.Kind()
if kind == reflect.Invalid { if kind == reflect.Invalid {
_, _ = f.s.Write(invalidAngleBytes) _, _ = w.s.Write(invalidAngleBytes)
return return
} }
@ -249,46 +253,46 @@ func (f *unwrap) format(v reflect.Value) {
case reflect.Ptr: case reflect.Ptr:
if !v.IsZero() { if !v.IsZero() {
if strings.HasPrefix(reflect.Indirect(v).Type().String(), "wrapperspb.") { if strings.HasPrefix(reflect.Indirect(v).Type().String(), "wrapperspb.") {
f.protoWrapperType = true w.protoWrapperType = true
} else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { } else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") {
f.sqlWrapperType = true w.sqlWrapperType = true
} }
} }
f.formatPtr(v) w.formatPtr(v)
return return
case reflect.Struct: case reflect.Struct:
if !v.IsZero() { if !v.IsZero() {
if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") {
f.sqlWrapperType = true w.sqlWrapperType = true
} }
} }
} }
// get type information unless already handled elsewhere. // get type information unless already handled elsewhere.
if !f.ignoreNextType && f.s.Flag('#') { if !w.ignoreNextType && w.s.Flag('#') {
if v.Type().Kind() != reflect.Map && if v.Type().Kind() != reflect.Map &&
v.Type().Kind() != reflect.String && v.Type().Kind() != reflect.String &&
v.Type().Kind() != reflect.Array && v.Type().Kind() != reflect.Array &&
v.Type().Kind() != reflect.Slice { v.Type().Kind() != reflect.Slice {
_, _ = f.s.Write(openParenBytes) _, _ = w.s.Write(openParenBytes)
} }
if v.Kind() != reflect.String { if v.Kind() != reflect.String {
_, _ = f.s.Write([]byte(v.Type().String())) _, _ = w.s.Write([]byte(v.Type().String()))
} }
if v.Type().Kind() != reflect.Map && if v.Type().Kind() != reflect.Map &&
v.Type().Kind() != reflect.String && v.Type().Kind() != reflect.String &&
v.Type().Kind() != reflect.Array && v.Type().Kind() != reflect.Array &&
v.Type().Kind() != reflect.Slice { v.Type().Kind() != reflect.Slice {
_, _ = f.s.Write(closeParenBytes) _, _ = w.s.Write(closeParenBytes)
} }
} }
f.ignoreNextType = false w.ignoreNextType = false
// Call Stringer/error interfaces if they exist and the handle methods // Call Stringer/error interfaces if they exist and the handle methods
// flag is enabled. // flag is enabled.
if f.opts.Methods { if w.opts.Methods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) { if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(f.opts, f.s, v); handled { if handled := handleMethods(w.opts, w.s, v); handled {
return return
} }
} }
@ -296,48 +300,48 @@ func (f *unwrap) format(v reflect.Value) {
switch kind { switch kind {
case reflect.Invalid: case reflect.Invalid:
_, _ = f.s.Write(invalidAngleBytes) _, _ = w.s.Write(invalidAngleBytes)
case reflect.Bool: case reflect.Bool:
getBool(f.s, v.Bool()) getBool(w.s, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
getInt(f.s, v.Int(), 10) getInt(w.s, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
getUint(f.s, v.Uint(), 10) getUint(w.s, v.Uint(), 10)
case reflect.Float32: case reflect.Float32:
getFloat(f.s, v.Float(), 32) getFloat(w.s, v.Float(), 32)
case reflect.Float64: case reflect.Float64:
getFloat(f.s, v.Float(), 64) getFloat(w.s, v.Float(), 64)
case reflect.Complex64: case reflect.Complex64:
getComplex(f.s, v.Complex(), 32) getComplex(w.s, v.Complex(), 32)
case reflect.Complex128: case reflect.Complex128:
getComplex(f.s, v.Complex(), 64) getComplex(w.s, v.Complex(), 64)
case reflect.Slice: case reflect.Slice:
if v.IsNil() { if v.IsNil() {
_, _ = f.s.Write(nilAngleBytes) _, _ = w.s.Write(nilAngleBytes)
break break
} }
fallthrough fallthrough
case reflect.Array: case reflect.Array:
_, _ = f.s.Write(openBraceBytes) _, _ = w.s.Write(openBraceBytes)
f.depth++ w.depth++
numEntries := v.Len() numEntries := v.Len()
for i := 0; i < numEntries; i++ { for i := 0; i < numEntries; i++ {
if i > 0 { if i > 0 {
_, _ = f.s.Write(commaBytes) _, _ = w.s.Write(commaBytes)
_, _ = f.s.Write(spaceBytes) _, _ = w.s.Write(spaceBytes)
} }
f.ignoreNextType = true w.ignoreNextType = true
f.format(f.unpackValue(v.Index(i))) w.format(w.unpackValue(v.Index(i)))
} }
f.depth-- w.depth--
_, _ = f.s.Write(closeBraceBytes) _, _ = w.s.Write(closeBraceBytes)
case reflect.String: case reflect.String:
_, _ = f.s.Write([]byte(`"` + v.String() + `"`)) _, _ = w.s.Write([]byte(`"` + v.String() + `"`))
case reflect.Interface: case reflect.Interface:
// The only time we should get here is for nil interfaces due to // The only time we should get here is for nil interfaces due to
// unpackValue calls. // unpackValue calls.
if v.IsNil() { if v.IsNil() {
_, _ = f.s.Write(nilAngleBytes) _, _ = w.s.Write(nilAngleBytes)
} }
case reflect.Ptr: case reflect.Ptr:
// Do nothing. We should never get here since pointers have already // Do nothing. We should never get here since pointers have already
@ -345,38 +349,39 @@ func (f *unwrap) format(v reflect.Value) {
case reflect.Map: case reflect.Map:
// nil maps should be indicated as different than empty maps // nil maps should be indicated as different than empty maps
if v.IsNil() { if v.IsNil() {
_, _ = f.s.Write(nilAngleBytes) _, _ = w.s.Write(nilAngleBytes)
break break
} }
_, _ = f.s.Write(openMapBytes) _, _ = w.s.Write(openMapBytes)
f.depth++ w.depth++
keys := v.MapKeys() keys := v.MapKeys()
for i, key := range keys { for i, key := range keys {
if i > 0 { if i > 0 {
_, _ = f.s.Write(spaceBytes) _, _ = w.s.Write(spaceBytes)
} }
f.ignoreNextType = true w.ignoreNextType = true
f.format(f.unpackValue(key)) w.format(w.unpackValue(key))
_, _ = f.s.Write(colonBytes) _, _ = w.s.Write(colonBytes)
f.ignoreNextType = true w.ignoreNextType = true
f.format(f.unpackValue(v.MapIndex(key))) w.format(w.unpackValue(v.MapIndex(key)))
} }
f.depth-- w.depth--
_, _ = f.s.Write(closeMapBytes) _, _ = w.s.Write(closeMapBytes)
case reflect.Struct: case reflect.Struct:
numFields := v.NumField() numFields := v.NumField()
numWritten := 0 numWritten := 0
_, _ = f.s.Write(openBraceBytes) _, _ = w.s.Write(openBraceBytes)
f.depth++ w.depth++
vt := v.Type() vt := v.Type()
prevSkip := false prevSkip := false
for i := 0; i < numFields; i++ { for i := 0; i < numFields; i++ {
f.takeAll = false if w.protoWrapperType && !vt.Field(i).IsExported() {
if f.protoWrapperType && !vt.Field(i).IsExported() {
prevSkip = true prevSkip = true
continue continue
} else if f.sqlWrapperType && vt.Field(i).Name == "Valid" { } else if w.sqlWrapperType && vt.Field(i).Name == "Valid" {
prevSkip = true prevSkip = true
continue continue
} }
@ -390,9 +395,9 @@ func (f *unwrap) format(v reflect.Value) {
case "take": case "take":
break break
} }
case f.takeAll: case w.takeAll[w.depth]:
break break
case !ok && f.opts.Tagged: case !ok && w.opts.Tagged:
prevSkip = true prevSkip = true
continue continue
} }
@ -402,52 +407,53 @@ func (f *unwrap) format(v reflect.Value) {
} }
if numWritten > 0 { if numWritten > 0 {
_, _ = f.s.Write(commaBytes) _, _ = w.s.Write(commaBytes)
_, _ = f.s.Write(spaceBytes) _, _ = w.s.Write(spaceBytes)
} }
vtf := vt.Field(i) vt := vt.Field(i)
if f.s.Flag('+') || f.s.Flag('#') { if w.s.Flag('+') || w.s.Flag('#') {
_, _ = f.s.Write([]byte(vtf.Name)) _, _ = w.s.Write([]byte(vt.Name))
_, _ = f.s.Write(colonBytes) _, _ = w.s.Write(colonBytes)
} }
unpackValue := f.unpackValue(v.Field(i)) unpackValue := w.unpackValue(v.Field(i))
f.takeAll = f.checkTakeAll(unpackValue) w.checkTakeAll(unpackValue, w.depth)
f.format(unpackValue) w.format(unpackValue)
numWritten++ numWritten++
} }
f.depth-- w.depth--
if numWritten == 0 && f.depth < 0 {
_, _ = f.s.Write(filteredBytes) if numWritten == 0 && w.depth < 0 {
_, _ = w.s.Write(filteredBytes)
} }
_, _ = f.s.Write(closeBraceBytes) _, _ = w.s.Write(closeBraceBytes)
case reflect.Uintptr: case reflect.Uintptr:
getHexPtr(f.s, uintptr(v.Uint())) getHexPtr(w.s, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func: case reflect.UnsafePointer, reflect.Chan, reflect.Func:
getHexPtr(f.s, v.Pointer()) getHexPtr(w.s, v.Pointer())
// There were not any other types at the time this code was written, but // There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it if any get added. // fall back to letting the default fmt package handle it if any get added.
default: default:
format := f.buildDefaultFormat() format := w.buildDefaultFormat()
if v.CanInterface() { if v.CanInterface() {
_, _ = fmt.Fprintf(f.s, format, v.Interface()) _, _ = fmt.Fprintf(w.s, format, v.Interface())
} else { } else {
_, _ = fmt.Fprintf(f.s, format, v.String()) _, _ = fmt.Fprintf(w.s, format, v.String())
} }
} }
} }
func (f *unwrap) Format(s fmt.State, verb rune) { func (w *Wrapper) Format(s fmt.State, verb rune) {
f.s = s w.s = s
// Use standard formatting for verbs that are not v. // Use standard formatting for verbs that are not v.
if verb != 'v' { if verb != 'v' {
format := f.constructOrigFormat(verb) format := w.constructOrigFormat(verb)
_, _ = fmt.Fprintf(s, format, f.val) _, _ = fmt.Fprintf(s, format, w.val)
return return
} }
if f.val == nil { if w.val == nil {
if s.Flag('#') { if s.Flag('#') {
_, _ = s.Write(interfaceBytes) _, _ = s.Write(interfaceBytes)
} }
@ -455,7 +461,7 @@ func (f *unwrap) Format(s fmt.State, verb rune) {
return return
} }
f.format(reflect.ValueOf(f.val)) w.format(reflect.ValueOf(w.val))
} }
// handle special methods like error.Error() or fmt.Stringer interface // handle special methods like error.Error() or fmt.Stringer interface
@ -571,11 +577,11 @@ func catchPanic(w io.Writer, _ reflect.Value) {
} }
} }
func (f *unwrap) buildDefaultFormat() (format string) { func (w *Wrapper) buildDefaultFormat() (format string) {
buf := bytes.NewBuffer(percentBytes) buf := bytes.NewBuffer(percentBytes)
for _, flag := range sf { for _, flag := range sf {
if f.s.Flag(int(flag)) { if w.s.Flag(int(flag)) {
_, _ = buf.WriteRune(flag) _, _ = buf.WriteRune(flag)
} }
} }
@ -586,43 +592,48 @@ func (f *unwrap) buildDefaultFormat() (format string) {
return format return format
} }
func (f *unwrap) constructOrigFormat(verb rune) (format string) { func (w *Wrapper) constructOrigFormat(verb rune) string {
buf := bytes.NewBuffer(percentBytes) buf := bytes.NewBuffer(percentBytes)
for _, flag := range sf { for _, flag := range sf {
if f.s.Flag(int(flag)) { if w.s.Flag(int(flag)) {
_, _ = buf.WriteRune(flag) _, _ = buf.WriteRune(flag)
} }
} }
if width, ok := f.s.Width(); ok { if width, ok := w.s.Width(); ok {
_, _ = buf.WriteString(strconv.Itoa(width)) _, _ = buf.WriteString(strconv.Itoa(width))
} }
if precision, ok := f.s.Precision(); ok { if precision, ok := w.s.Precision(); ok {
_, _ = buf.Write(precisionBytes) _, _ = buf.Write(precisionBytes)
_, _ = buf.WriteString(strconv.Itoa(precision)) _, _ = buf.WriteString(strconv.Itoa(precision))
} }
_, _ = buf.WriteRune(verb) _, _ = buf.WriteRune(verb)
format = buf.String() return buf.String()
return format
} }
func (f *unwrap) checkTakeAll(v reflect.Value) bool { func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) {
takeAll := true if _, ok := w.takeAll[depth]; ok {
return
}
if !v.IsValid() || v.IsZero() {
return
}
switch v.Kind() { switch v.Kind() {
case reflect.Struct: case reflect.Struct:
break break
case reflect.Ptr: case reflect.Ptr:
v = v.Elem() v = v.Elem()
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
return true w.takeAll[depth] = true
return
} }
default: default:
return true w.takeAll[depth] = true
return
} }
vt := v.Type() vt := v.Type()
@ -630,10 +641,8 @@ func (f *unwrap) checkTakeAll(v reflect.Value) bool {
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
sv, ok := vt.Field(i).Tag.Lookup("logger") sv, ok := vt.Field(i).Tag.Lookup("logger")
if ok && sv == "take" { if ok && sv == "take" {
return false w.takeAll[depth] = false
} }
takeAll = f.checkTakeAll(v.Field(i)) w.checkTakeAll(v.Field(i), depth+1)
} }
return takeAll
} }