356 lines
10 KiB
Go
356 lines
10 KiB
Go
/*Package cmp provides Comparisons for Assert and Check*/
|
|
package cmp // import "gotest.tools/assert/cmp"
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"gotest.tools/internal/format"
|
|
)
|
|
|
|
// Comparison is a function which compares values and returns ResultSuccess if
|
|
// the actual value matches the expected value. If the values do not match the
|
|
// Result will contain a message about why it failed.
|
|
type Comparison func() Result
|
|
|
|
// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
|
|
// and succeeds if the values are equal.
|
|
//
|
|
// The comparison can be customized using comparison Options.
|
|
// Package https://godoc.org/gotest.tools/assert/opt provides some additional
|
|
// commonly used Options.
|
|
func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
|
|
return func() (result Result) {
|
|
defer func() {
|
|
if panicmsg, handled := handleCmpPanic(recover()); handled {
|
|
result = ResultFailure(panicmsg)
|
|
}
|
|
}()
|
|
diff := cmp.Diff(x, y, opts...)
|
|
if diff == "" {
|
|
return ResultSuccess
|
|
}
|
|
return multiLineDiffResult(diff)
|
|
}
|
|
}
|
|
|
|
func handleCmpPanic(r interface{}) (string, bool) {
|
|
if r == nil {
|
|
return "", false
|
|
}
|
|
panicmsg, ok := r.(string)
|
|
if !ok {
|
|
panic(r)
|
|
}
|
|
switch {
|
|
case strings.HasPrefix(panicmsg, "cannot handle unexported field"):
|
|
return panicmsg, true
|
|
}
|
|
panic(r)
|
|
}
|
|
|
|
func toResult(success bool, msg string) Result {
|
|
if success {
|
|
return ResultSuccess
|
|
}
|
|
return ResultFailure(msg)
|
|
}
|
|
|
|
// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
|
|
// regexp pattern.
|
|
type RegexOrPattern interface{}
|
|
|
|
// Regexp succeeds if value v matches regular expression re.
|
|
//
|
|
// Example:
|
|
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
|
|
// r := regexp.MustCompile("^[0-9a-f]{32}$")
|
|
// assert.Assert(t, cmp.Regexp(r, str))
|
|
func Regexp(re RegexOrPattern, v string) Comparison {
|
|
match := func(re *regexp.Regexp) Result {
|
|
return toResult(
|
|
re.MatchString(v),
|
|
fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
|
|
}
|
|
|
|
return func() Result {
|
|
switch regex := re.(type) {
|
|
case *regexp.Regexp:
|
|
return match(regex)
|
|
case string:
|
|
re, err := regexp.Compile(regex)
|
|
if err != nil {
|
|
return ResultFailure(err.Error())
|
|
}
|
|
return match(re)
|
|
default:
|
|
return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Equal succeeds if x == y. See assert.Equal for full documentation.
|
|
func Equal(x, y interface{}) Comparison {
|
|
return func() Result {
|
|
switch {
|
|
case x == y:
|
|
return ResultSuccess
|
|
case isMultiLineStringCompare(x, y):
|
|
diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
|
|
return multiLineDiffResult(diff)
|
|
}
|
|
return ResultFailureTemplate(`
|
|
{{- .Data.x}} (
|
|
{{- with callArg 0 }}{{ formatNode . }} {{end -}}
|
|
{{- printf "%T" .Data.x -}}
|
|
) != {{ .Data.y}} (
|
|
{{- with callArg 1 }}{{ formatNode . }} {{end -}}
|
|
{{- printf "%T" .Data.y -}}
|
|
)`,
|
|
map[string]interface{}{"x": x, "y": y})
|
|
}
|
|
}
|
|
|
|
func isMultiLineStringCompare(x, y interface{}) bool {
|
|
strX, ok := x.(string)
|
|
if !ok {
|
|
return false
|
|
}
|
|
strY, ok := y.(string)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
|
|
}
|
|
|
|
func multiLineDiffResult(diff string) Result {
|
|
return ResultFailureTemplate(`
|
|
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
|
|
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
|
|
{{ .Data.diff }}`,
|
|
map[string]interface{}{"diff": diff})
|
|
}
|
|
|
|
// Len succeeds if the sequence has the expected length.
|
|
func Len(seq interface{}, expected int) Comparison {
|
|
return func() (result Result) {
|
|
defer func() {
|
|
if e := recover(); e != nil {
|
|
result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
|
|
}
|
|
}()
|
|
value := reflect.ValueOf(seq)
|
|
length := value.Len()
|
|
if length == expected {
|
|
return ResultSuccess
|
|
}
|
|
msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
|
|
return ResultFailure(msg)
|
|
}
|
|
}
|
|
|
|
// Contains succeeds if item is in collection. Collection may be a string, map,
|
|
// slice, or array.
|
|
//
|
|
// If collection is a string, item must also be a string, and is compared using
|
|
// strings.Contains().
|
|
// If collection is a Map, contains will succeed if item is a key in the map.
|
|
// If collection is a slice or array, item is compared to each item in the
|
|
// sequence using reflect.DeepEqual().
|
|
func Contains(collection interface{}, item interface{}) Comparison {
|
|
return func() Result {
|
|
colValue := reflect.ValueOf(collection)
|
|
if !colValue.IsValid() {
|
|
return ResultFailure(fmt.Sprintf("nil does not contain items"))
|
|
}
|
|
msg := fmt.Sprintf("%v does not contain %v", collection, item)
|
|
|
|
itemValue := reflect.ValueOf(item)
|
|
switch colValue.Type().Kind() {
|
|
case reflect.String:
|
|
if itemValue.Type().Kind() != reflect.String {
|
|
return ResultFailure("string may only contain strings")
|
|
}
|
|
return toResult(
|
|
strings.Contains(colValue.String(), itemValue.String()),
|
|
fmt.Sprintf("string %q does not contain %q", collection, item))
|
|
|
|
case reflect.Map:
|
|
if itemValue.Type() != colValue.Type().Key() {
|
|
return ResultFailure(fmt.Sprintf(
|
|
"%v can not contain a %v key", colValue.Type(), itemValue.Type()))
|
|
}
|
|
return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
for i := 0; i < colValue.Len(); i++ {
|
|
if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
|
|
return ResultSuccess
|
|
}
|
|
}
|
|
return ResultFailure(msg)
|
|
default:
|
|
return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Panics succeeds if f() panics.
|
|
func Panics(f func()) Comparison {
|
|
return func() (result Result) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
result = ResultSuccess
|
|
}
|
|
}()
|
|
f()
|
|
return ResultFailure("did not panic")
|
|
}
|
|
}
|
|
|
|
// Error succeeds if err is a non-nil error, and the error message equals the
|
|
// expected message.
|
|
func Error(err error, message string) Comparison {
|
|
return func() Result {
|
|
switch {
|
|
case err == nil:
|
|
return ResultFailure("expected an error, got nil")
|
|
case err.Error() != message:
|
|
return ResultFailure(fmt.Sprintf(
|
|
"expected error %q, got %s", message, formatErrorMessage(err)))
|
|
}
|
|
return ResultSuccess
|
|
}
|
|
}
|
|
|
|
// ErrorContains succeeds if err is a non-nil error, and the error message contains
|
|
// the expected substring.
|
|
func ErrorContains(err error, substring string) Comparison {
|
|
return func() Result {
|
|
switch {
|
|
case err == nil:
|
|
return ResultFailure("expected an error, got nil")
|
|
case !strings.Contains(err.Error(), substring):
|
|
return ResultFailure(fmt.Sprintf(
|
|
"expected error to contain %q, got %s", substring, formatErrorMessage(err)))
|
|
}
|
|
return ResultSuccess
|
|
}
|
|
}
|
|
|
|
func formatErrorMessage(err error) string {
|
|
if _, ok := err.(interface {
|
|
Cause() error
|
|
}); ok {
|
|
return fmt.Sprintf("%q\n%+v", err, err)
|
|
}
|
|
// This error was not wrapped with github.com/pkg/errors
|
|
return fmt.Sprintf("%q", err)
|
|
}
|
|
|
|
// Nil succeeds if obj is a nil interface, pointer, or function.
|
|
//
|
|
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
|
|
// maps, and channels.
|
|
func Nil(obj interface{}) Comparison {
|
|
msgFunc := func(value reflect.Value) string {
|
|
return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
|
|
}
|
|
return isNil(obj, msgFunc)
|
|
}
|
|
|
|
func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
|
|
return func() Result {
|
|
if obj == nil {
|
|
return ResultSuccess
|
|
}
|
|
value := reflect.ValueOf(obj)
|
|
kind := value.Type().Kind()
|
|
if kind >= reflect.Chan && kind <= reflect.Slice {
|
|
if value.IsNil() {
|
|
return ResultSuccess
|
|
}
|
|
return ResultFailure(msgFunc(value))
|
|
}
|
|
|
|
return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
|
|
}
|
|
}
|
|
|
|
// ErrorType succeeds if err is not nil and is of the expected type.
|
|
//
|
|
// Expected can be one of:
|
|
// a func(error) bool which returns true if the error is the expected type,
|
|
// an instance of (or a pointer to) a struct of the expected type,
|
|
// a pointer to an interface the error is expected to implement,
|
|
// a reflect.Type of the expected struct or interface.
|
|
func ErrorType(err error, expected interface{}) Comparison {
|
|
return func() Result {
|
|
switch expectedType := expected.(type) {
|
|
case func(error) bool:
|
|
return cmpErrorTypeFunc(err, expectedType)
|
|
case reflect.Type:
|
|
if expectedType.Kind() == reflect.Interface {
|
|
return cmpErrorTypeImplementsType(err, expectedType)
|
|
}
|
|
return cmpErrorTypeEqualType(err, expectedType)
|
|
case nil:
|
|
return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
|
|
}
|
|
|
|
expectedType := reflect.TypeOf(expected)
|
|
switch {
|
|
case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
|
|
return cmpErrorTypeEqualType(err, expectedType)
|
|
case isPtrToInterface(expectedType):
|
|
return cmpErrorTypeImplementsType(err, expectedType.Elem())
|
|
}
|
|
return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
|
|
}
|
|
}
|
|
|
|
func cmpErrorTypeFunc(err error, f func(error) bool) Result {
|
|
if f(err) {
|
|
return ResultSuccess
|
|
}
|
|
actual := "nil"
|
|
if err != nil {
|
|
actual = fmt.Sprintf("%s (%T)", err, err)
|
|
}
|
|
return ResultFailureTemplate(`error is {{ .Data.actual }}
|
|
{{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
|
|
map[string]interface{}{"actual": actual})
|
|
}
|
|
|
|
func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
|
|
if err == nil {
|
|
return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
|
|
}
|
|
errValue := reflect.ValueOf(err)
|
|
if errValue.Type() == expectedType {
|
|
return ResultSuccess
|
|
}
|
|
return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
|
|
}
|
|
|
|
func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
|
|
if err == nil {
|
|
return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
|
|
}
|
|
errValue := reflect.ValueOf(err)
|
|
if errValue.Type().Implements(expectedType) {
|
|
return ResultSuccess
|
|
}
|
|
return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
|
|
}
|
|
|
|
func isPtrToInterface(typ reflect.Type) bool {
|
|
return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
|
|
}
|
|
|
|
func isPtrToStruct(typ reflect.Type) bool {
|
|
return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
|
|
}
|