diff --git a/zaptest/observer/observer.go b/zaptest/observer/observer.go index 6ae58f5d6..f3f6116b6 100644 --- a/zaptest/observer/observer.go +++ b/zaptest/observer/observer.go @@ -78,23 +78,30 @@ func (o *ObservedLogs) AllUntimed() []LoggedEntry { return ret } +// FilterLevelExact filters entries to those logged at exactly the given level. +func (o *ObservedLogs) FilterLevelExact(level zapcore.Level) *ObservedLogs { + return o.Filter(func(e LoggedEntry) bool { + return e.Level == level + }) +} + // FilterMessage filters entries to those that have the specified message. func (o *ObservedLogs) FilterMessage(msg string) *ObservedLogs { - return o.filter(func(e LoggedEntry) bool { + return o.Filter(func(e LoggedEntry) bool { return e.Message == msg }) } // FilterMessageSnippet filters entries to those that have a message containing the specified snippet. func (o *ObservedLogs) FilterMessageSnippet(snippet string) *ObservedLogs { - return o.filter(func(e LoggedEntry) bool { + return o.Filter(func(e LoggedEntry) bool { return strings.Contains(e.Message, snippet) }) } // FilterField filters entries to those that have the specified field. func (o *ObservedLogs) FilterField(field zapcore.Field) *ObservedLogs { - return o.filter(func(e LoggedEntry) bool { + return o.Filter(func(e LoggedEntry) bool { for _, ctxField := range e.Context { if ctxField.Equals(field) { return true @@ -106,7 +113,7 @@ func (o *ObservedLogs) FilterField(field zapcore.Field) *ObservedLogs { // FilterFieldKey filters entries to those that have the specified key. func (o *ObservedLogs) FilterFieldKey(key string) *ObservedLogs { - return o.filter(func(e LoggedEntry) bool { + return o.Filter(func(e LoggedEntry) bool { for _, ctxField := range e.Context { if ctxField.Key == key { return true @@ -116,13 +123,15 @@ func (o *ObservedLogs) FilterFieldKey(key string) *ObservedLogs { }) } -func (o *ObservedLogs) filter(match func(LoggedEntry) bool) *ObservedLogs { +// Filter returns a copy of this ObservedLogs containing only those entries +// for which the provided function returns true. +func (o *ObservedLogs) Filter(keep func(LoggedEntry) bool) *ObservedLogs { o.mu.RLock() defer o.mu.RUnlock() var filtered []LoggedEntry for _, entry := range o.logs { - if match(entry) { + if keep(entry) { filtered = append(filtered, entry) } } diff --git a/zaptest/observer/observer_test.go b/zaptest/observer/observer_test.go index 4e8139493..9f179d02e 100644 --- a/zaptest/observer/observer_test.go +++ b/zaptest/observer/observer_test.go @@ -55,7 +55,7 @@ func TestObserver(t *testing.T) { assert.Equal(t, want, logs.AllUntimed(), "Unexpected contents from AllUntimed.") all := logs.All() - require.Equal(t, 1, len(all), "Unexpected numbed of LoggedEntries returned from All.") + require.Equal(t, 1, len(all), "Unexpected number of LoggedEntries returned from All.") assert.NotEqual(t, time.Time{}, all[0].Time, "Expected non-zero time on LoggedEntry.") // copy & zero time for stable assertions @@ -157,6 +157,14 @@ func TestFilters(t *testing.T) { Entry: zapcore.Entry{Level: zap.InfoLevel, Message: "any slice"}, Context: []zapcore.Field{zap.Any("filterMe", []string{"b"})}, }, + { + Entry: zapcore.Entry{Level: zap.WarnLevel, Message: "danger will robinson"}, + Context: []zapcore.Field{zap.Int("b", 42)}, + }, + { + Entry: zapcore.Entry{Level: zap.ErrorLevel, Message: "warp core breach"}, + Context: []zapcore.Field{zap.Int("b", 42)}, + }, } logger, sink := New(zap.InfoLevel) @@ -219,6 +227,24 @@ func TestFilters(t *testing.T) { filtered: sink.FilterFieldKey("filterMe"), want: logs[7:9], }, + { + msg: "filter by arbitrary function", + filtered: sink.Filter(func(e LoggedEntry) bool { + return len(e.Context) > 1 + }), + want: func() []LoggedEntry { + // Do not modify logs slice. + w := make([]LoggedEntry, 0, len(logs)) + w = append(w, logs[0:5]...) + w = append(w, logs[7]) + return w + }(), + }, + { + msg: "filter level", + filtered: sink.FilterLevelExact(zap.WarnLevel), + want: logs[9:10], + }, } for _, tt := range tests {