From 9b387312dab29d01ae1c0b84b674259da63e766c Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Mon, 6 Feb 2023 22:36:24 +0300 Subject: [PATCH] logger/unwrap: check nested in case of Tagged Signed-off-by: Vasiliy Tolstov --- logger/unwrap/unwrap.go | 271 +++++++++++++++++++++------------------- 1 file changed, 140 insertions(+), 131 deletions(-) diff --git a/logger/unwrap/unwrap.go b/logger/unwrap/unwrap.go index c5469b47..7d219e6e 100644 --- a/logger/unwrap/unwrap.go +++ b/logger/unwrap/unwrap.go @@ -46,14 +46,14 @@ var ( closeMapBytes = []byte("}") ) -type unwrap struct { +type Wrapper struct { val interface{} s fmt.State pointers map[uintptr]int opts *Options depth int ignoreNextType bool - takeAll bool + takeAll map[int]bool protoWrapperType bool sqlWrapperType bool } @@ -109,14 +109,14 @@ func Tagged(b bool) Option { } } -func Unwrap(val interface{}, opts ...Option) *unwrap { +func Unwrap(val interface{}, opts ...Option) *Wrapper { options := NewOptions(opts...) - return &unwrap{val: val, opts: &options, pointers: make(map[uintptr]int)} + return &Wrapper{val: val, opts: &options, pointers: make(map[uintptr]int), takeAll: make(map[int]bool)} } -func (f *unwrap) unpackValue(v reflect.Value) reflect.Value { +func (w *Wrapper) unpackValue(v reflect.Value) reflect.Value { if v.Kind() == reflect.Interface { - f.ignoreNextType = false + w.ignoreNextType = false if !v.IsNil() { v = v.Elem() } @@ -125,19 +125,19 @@ func (f *unwrap) unpackValue(v reflect.Value) reflect.Value { } // formatPtr handles formatting of pointers by indirecting them as necessary. -func (f *unwrap) formatPtr(v reflect.Value) { +func (w *Wrapper) formatPtr(v reflect.Value) { // Display nil if top level pointer is nil. - showTypes := f.s.Flag('#') - if v.IsNil() && (!showTypes || f.ignoreNextType) { - _, _ = f.s.Write(nilAngleBytes) + showTypes := w.s.Flag('#') + if v.IsNil() && (!showTypes || w.ignoreNextType) { + _, _ = w.s.Write(nilAngleBytes) return } // Remove pointers at or below the current depth from map used to detect // circular refs. - for k, depth := range f.pointers { - if depth >= f.depth { - delete(f.pointers, k) + for k, depth := range w.pointers { + if depth >= w.depth { + delete(w.pointers, k) } } @@ -159,12 +159,12 @@ func (f *unwrap) formatPtr(v reflect.Value) { indirects++ addr := ve.Pointer() pointerChain = append(pointerChain, addr) - if pd, ok := f.pointers[addr]; ok && pd < f.depth { + if pd, ok := w.pointers[addr]; ok && pd < w.depth { cycleFound = true indirects-- break } - f.pointers[addr] = f.depth + w.pointers[addr] = w.depth ve = ve.Elem() if ve.Kind() == reflect.Interface { @@ -177,49 +177,49 @@ func (f *unwrap) formatPtr(v reflect.Value) { } // Display type or indirection level depending on flags. - if showTypes && !f.ignoreNextType { - if f.depth > 0 { - _, _ = f.s.Write(openParenBytes) + if showTypes && !w.ignoreNextType { + if w.depth > 0 { + _, _ = w.s.Write(openParenBytes) } - if f.depth > 0 { - _, _ = f.s.Write(bytes.Repeat(asteriskBytes, indirects)) + if w.depth > 0 { + _, _ = w.s.Write(bytes.Repeat(asteriskBytes, indirects)) } else { - _, _ = f.s.Write(bytes.Repeat(ampBytes, indirects)) + _, _ = w.s.Write(bytes.Repeat(ampBytes, indirects)) } - _, _ = f.s.Write([]byte(ve.Type().String())) - if f.depth > 0 { - _, _ = f.s.Write(closeParenBytes) + _, _ = w.s.Write([]byte(ve.Type().String())) + if w.depth > 0 { + _, _ = w.s.Write(closeParenBytes) } } else { if nilFound || cycleFound { indirects += strings.Count(ve.Type().String(), "*") } - _, _ = f.s.Write(openAngleBytes) - _, _ = f.s.Write([]byte(strings.Repeat("*", indirects))) - _, _ = f.s.Write(closeAngleBytes) + _, _ = w.s.Write(openAngleBytes) + _, _ = w.s.Write([]byte(strings.Repeat("*", indirects))) + _, _ = w.s.Write(closeAngleBytes) } // Display pointer information depending on flags. - if f.s.Flag('+') && (len(pointerChain) > 0) { - _, _ = f.s.Write(openParenBytes) + if w.s.Flag('+') && (len(pointerChain) > 0) { + _, _ = w.s.Write(openParenBytes) for i, addr := range pointerChain { if i > 0 { - _, _ = f.s.Write(pointerChainBytes) + _, _ = w.s.Write(pointerChainBytes) } - getHexPtr(f.s, addr) + getHexPtr(w.s, addr) } - _, _ = f.s.Write(closeParenBytes) + _, _ = w.s.Write(closeParenBytes) } // Display dereferenced value. switch { case nilFound: - _, _ = f.s.Write(nilAngleBytes) + _, _ = w.s.Write(nilAngleBytes) case cycleFound: - _, _ = f.s.Write(circularShortBytes) + _, _ = w.s.Write(circularShortBytes) default: - f.ignoreNextType = true - f.format(ve) + w.ignoreNextType = true + w.format(ve) } } @@ -227,20 +227,24 @@ func (f *unwrap) formatPtr(v reflect.Value) { // uses the passed reflect value to figure out what kind of object we are // dealing with and formats it appropriately. It is a recursive function, // however circular data structures are detected and handled properly. -func (f *unwrap) format(v reflect.Value) { - if f.opts.Codec != nil { - buf, err := f.opts.Codec.Marshal(v.Interface()) +func (w *Wrapper) format(v reflect.Value) { + if w.opts.Codec != nil { + buf, err := w.opts.Codec.Marshal(v.Interface()) if err != nil { - _, _ = f.s.Write(invalidAngleBytes) + _, _ = w.s.Write(invalidAngleBytes) return } - _, _ = f.s.Write(buf) + _, _ = w.s.Write(buf) return } + if w.opts.Tagged { + w.checkTakeAll(v, 1) + } + // Handle invalid reflect values immediately. kind := v.Kind() if kind == reflect.Invalid { - _, _ = f.s.Write(invalidAngleBytes) + _, _ = w.s.Write(invalidAngleBytes) return } @@ -249,46 +253,46 @@ func (f *unwrap) format(v reflect.Value) { case reflect.Ptr: if !v.IsZero() { if strings.HasPrefix(reflect.Indirect(v).Type().String(), "wrapperspb.") { - f.protoWrapperType = true + w.protoWrapperType = true } else if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { - f.sqlWrapperType = true + w.sqlWrapperType = true } } - f.formatPtr(v) + w.formatPtr(v) return case reflect.Struct: if !v.IsZero() { if strings.HasPrefix(reflect.Indirect(v).Type().String(), "sql.Null") { - f.sqlWrapperType = true + w.sqlWrapperType = true } } } // get type information unless already handled elsewhere. - if !f.ignoreNextType && f.s.Flag('#') { + if !w.ignoreNextType && w.s.Flag('#') { if v.Type().Kind() != reflect.Map && v.Type().Kind() != reflect.String && v.Type().Kind() != reflect.Array && v.Type().Kind() != reflect.Slice { - _, _ = f.s.Write(openParenBytes) + _, _ = w.s.Write(openParenBytes) } if v.Kind() != reflect.String { - _, _ = f.s.Write([]byte(v.Type().String())) + _, _ = w.s.Write([]byte(v.Type().String())) } if v.Type().Kind() != reflect.Map && v.Type().Kind() != reflect.String && v.Type().Kind() != reflect.Array && v.Type().Kind() != reflect.Slice { - _, _ = f.s.Write(closeParenBytes) + _, _ = w.s.Write(closeParenBytes) } } - f.ignoreNextType = false + w.ignoreNextType = false // Call Stringer/error interfaces if they exist and the handle methods // flag is enabled. - if f.opts.Methods { + if w.opts.Methods { if (kind != reflect.Invalid) && (kind != reflect.Interface) { - if handled := handleMethods(f.opts, f.s, v); handled { + if handled := handleMethods(w.opts, w.s, v); handled { return } } @@ -296,48 +300,48 @@ func (f *unwrap) format(v reflect.Value) { switch kind { case reflect.Invalid: - _, _ = f.s.Write(invalidAngleBytes) + _, _ = w.s.Write(invalidAngleBytes) case reflect.Bool: - getBool(f.s, v.Bool()) + getBool(w.s, v.Bool()) case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: - getInt(f.s, v.Int(), 10) + getInt(w.s, v.Int(), 10) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - getUint(f.s, v.Uint(), 10) + getUint(w.s, v.Uint(), 10) case reflect.Float32: - getFloat(f.s, v.Float(), 32) + getFloat(w.s, v.Float(), 32) case reflect.Float64: - getFloat(f.s, v.Float(), 64) + getFloat(w.s, v.Float(), 64) case reflect.Complex64: - getComplex(f.s, v.Complex(), 32) + getComplex(w.s, v.Complex(), 32) case reflect.Complex128: - getComplex(f.s, v.Complex(), 64) + getComplex(w.s, v.Complex(), 64) case reflect.Slice: if v.IsNil() { - _, _ = f.s.Write(nilAngleBytes) + _, _ = w.s.Write(nilAngleBytes) break } fallthrough case reflect.Array: - _, _ = f.s.Write(openBraceBytes) - f.depth++ + _, _ = w.s.Write(openBraceBytes) + w.depth++ numEntries := v.Len() for i := 0; i < numEntries; i++ { if i > 0 { - _, _ = f.s.Write(commaBytes) - _, _ = f.s.Write(spaceBytes) + _, _ = w.s.Write(commaBytes) + _, _ = w.s.Write(spaceBytes) } - f.ignoreNextType = true - f.format(f.unpackValue(v.Index(i))) + w.ignoreNextType = true + w.format(w.unpackValue(v.Index(i))) } - f.depth-- - _, _ = f.s.Write(closeBraceBytes) + w.depth-- + _, _ = w.s.Write(closeBraceBytes) case reflect.String: - _, _ = f.s.Write([]byte(`"` + v.String() + `"`)) + _, _ = w.s.Write([]byte(`"` + v.String() + `"`)) case reflect.Interface: // The only time we should get here is for nil interfaces due to // unpackValue calls. if v.IsNil() { - _, _ = f.s.Write(nilAngleBytes) + _, _ = w.s.Write(nilAngleBytes) } case reflect.Ptr: // Do nothing. We should never get here since pointers have already @@ -345,38 +349,39 @@ func (f *unwrap) format(v reflect.Value) { case reflect.Map: // nil maps should be indicated as different than empty maps if v.IsNil() { - _, _ = f.s.Write(nilAngleBytes) + _, _ = w.s.Write(nilAngleBytes) break } - _, _ = f.s.Write(openMapBytes) - f.depth++ + _, _ = w.s.Write(openMapBytes) + w.depth++ keys := v.MapKeys() for i, key := range keys { if i > 0 { - _, _ = f.s.Write(spaceBytes) + _, _ = w.s.Write(spaceBytes) } - f.ignoreNextType = true - f.format(f.unpackValue(key)) - _, _ = f.s.Write(colonBytes) - f.ignoreNextType = true - f.format(f.unpackValue(v.MapIndex(key))) + w.ignoreNextType = true + w.format(w.unpackValue(key)) + _, _ = w.s.Write(colonBytes) + w.ignoreNextType = true + w.format(w.unpackValue(v.MapIndex(key))) } - f.depth-- - _, _ = f.s.Write(closeMapBytes) + w.depth-- + _, _ = w.s.Write(closeMapBytes) case reflect.Struct: + numFields := v.NumField() numWritten := 0 - _, _ = f.s.Write(openBraceBytes) - f.depth++ + _, _ = w.s.Write(openBraceBytes) + w.depth++ + vt := v.Type() prevSkip := false for i := 0; i < numFields; i++ { - f.takeAll = false - if f.protoWrapperType && !vt.Field(i).IsExported() { + if w.protoWrapperType && !vt.Field(i).IsExported() { prevSkip = true continue - } else if f.sqlWrapperType && vt.Field(i).Name == "Valid" { + } else if w.sqlWrapperType && vt.Field(i).Name == "Valid" { prevSkip = true continue } @@ -390,9 +395,9 @@ func (f *unwrap) format(v reflect.Value) { case "take": break } - case f.takeAll: + case w.takeAll[w.depth]: break - case !ok && f.opts.Tagged: + case !ok && w.opts.Tagged: prevSkip = true continue } @@ -402,52 +407,53 @@ func (f *unwrap) format(v reflect.Value) { } if numWritten > 0 { - _, _ = f.s.Write(commaBytes) - _, _ = f.s.Write(spaceBytes) + _, _ = w.s.Write(commaBytes) + _, _ = w.s.Write(spaceBytes) } - vtf := vt.Field(i) - if f.s.Flag('+') || f.s.Flag('#') { - _, _ = f.s.Write([]byte(vtf.Name)) - _, _ = f.s.Write(colonBytes) + vt := vt.Field(i) + if w.s.Flag('+') || w.s.Flag('#') { + _, _ = w.s.Write([]byte(vt.Name)) + _, _ = w.s.Write(colonBytes) } - unpackValue := f.unpackValue(v.Field(i)) - f.takeAll = f.checkTakeAll(unpackValue) - f.format(unpackValue) + unpackValue := w.unpackValue(v.Field(i)) + w.checkTakeAll(unpackValue, w.depth) + w.format(unpackValue) numWritten++ } - f.depth-- - if numWritten == 0 && f.depth < 0 { - _, _ = f.s.Write(filteredBytes) + w.depth-- + + if numWritten == 0 && w.depth < 0 { + _, _ = w.s.Write(filteredBytes) } - _, _ = f.s.Write(closeBraceBytes) + _, _ = w.s.Write(closeBraceBytes) case reflect.Uintptr: - getHexPtr(f.s, uintptr(v.Uint())) + getHexPtr(w.s, uintptr(v.Uint())) case reflect.UnsafePointer, reflect.Chan, reflect.Func: - getHexPtr(f.s, v.Pointer()) + getHexPtr(w.s, v.Pointer()) // There were not any other types at the time this code was written, but // fall back to letting the default fmt package handle it if any get added. default: - format := f.buildDefaultFormat() + format := w.buildDefaultFormat() if v.CanInterface() { - _, _ = fmt.Fprintf(f.s, format, v.Interface()) + _, _ = fmt.Fprintf(w.s, format, v.Interface()) } else { - _, _ = fmt.Fprintf(f.s, format, v.String()) + _, _ = fmt.Fprintf(w.s, format, v.String()) } } } -func (f *unwrap) Format(s fmt.State, verb rune) { - f.s = s +func (w *Wrapper) Format(s fmt.State, verb rune) { + w.s = s // Use standard formatting for verbs that are not v. if verb != 'v' { - format := f.constructOrigFormat(verb) - _, _ = fmt.Fprintf(s, format, f.val) + format := w.constructOrigFormat(verb) + _, _ = fmt.Fprintf(s, format, w.val) return } - if f.val == nil { + if w.val == nil { if s.Flag('#') { _, _ = s.Write(interfaceBytes) } @@ -455,7 +461,7 @@ func (f *unwrap) Format(s fmt.State, verb rune) { return } - f.format(reflect.ValueOf(f.val)) + w.format(reflect.ValueOf(w.val)) } // handle special methods like error.Error() or fmt.Stringer interface @@ -571,11 +577,11 @@ func catchPanic(w io.Writer, _ reflect.Value) { } } -func (f *unwrap) buildDefaultFormat() (format string) { +func (w *Wrapper) buildDefaultFormat() (format string) { buf := bytes.NewBuffer(percentBytes) for _, flag := range sf { - if f.s.Flag(int(flag)) { + if w.s.Flag(int(flag)) { _, _ = buf.WriteRune(flag) } } @@ -586,43 +592,48 @@ func (f *unwrap) buildDefaultFormat() (format string) { return format } -func (f *unwrap) constructOrigFormat(verb rune) (format string) { +func (w *Wrapper) constructOrigFormat(verb rune) string { buf := bytes.NewBuffer(percentBytes) for _, flag := range sf { - if f.s.Flag(int(flag)) { + if w.s.Flag(int(flag)) { _, _ = buf.WriteRune(flag) } } - if width, ok := f.s.Width(); ok { + if width, ok := w.s.Width(); ok { _, _ = buf.WriteString(strconv.Itoa(width)) } - if precision, ok := f.s.Precision(); ok { + if precision, ok := w.s.Precision(); ok { _, _ = buf.Write(precisionBytes) _, _ = buf.WriteString(strconv.Itoa(precision)) } _, _ = buf.WriteRune(verb) - format = buf.String() - return format + return buf.String() } -func (f *unwrap) checkTakeAll(v reflect.Value) bool { - takeAll := true - +func (w *Wrapper) checkTakeAll(v reflect.Value, depth int) { + if _, ok := w.takeAll[depth]; ok { + return + } + if !v.IsValid() || v.IsZero() { + return + } switch v.Kind() { case reflect.Struct: break case reflect.Ptr: v = v.Elem() if v.Kind() != reflect.Struct { - return true + w.takeAll[depth] = true + return } default: - return true + w.takeAll[depth] = true + return } vt := v.Type() @@ -630,10 +641,8 @@ func (f *unwrap) checkTakeAll(v reflect.Value) bool { for i := 0; i < v.NumField(); i++ { sv, ok := vt.Field(i).Tag.Lookup("logger") if ok && sv == "take" { - return false + w.takeAll[depth] = false } - takeAll = f.checkTakeAll(v.Field(i)) + w.checkTakeAll(v.Field(i), depth+1) } - - return takeAll }