k9s/vendor/github.com/petergtz/pegomock/dsl.go

565 lines
17 KiB
Go

// Copyright 2015 Peter Goetz
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pegomock
import (
"bytes"
"fmt"
"reflect"
"sort"
"sync"
"testing"
"time"
"github.com/onsi/gomega/format"
"github.com/petergtz/pegomock/internal/verify"
)
// GlobalFailHandler is the package-wide handler that Verify reports failed
// verifications to. It starts out nil; Verify panics while it is still nil.
var GlobalFailHandler FailHandler
// RegisterMockFailHandler installs handler as the GlobalFailHandler,
// replacing whatever handler was set before.
func RegisterMockFailHandler(handler FailHandler) {
	GlobalFailHandler = handler
}
func RegisterMockTestingT(t *testing.T) {
RegisterMockFailHandler(BuildTestingTGomegaFailHandler(t))
}
var (
	// lastInvocation is the most recent call recorded by GenericMock.Invoke.
	// When() consumes it and sets it back to nil.
	lastInvocation *invocation
	// lastInvocationMutex is held while lastInvocation is written.
	lastInvocationMutex sync.Mutex
)
// globalArgMatchers collects the matchers registered via RegisterMatcher.
// When and Verify read them and reset the variable to nil when they return.
var globalArgMatchers Matchers
// RegisterMatcher adds matcher to the end of the global list of argument
// matchers.
func RegisterMatcher(matcher Matcher) {
	globalArgMatchers = append(globalArgMatchers, matcher)
}
// invocation describes a single method call on a mock as recorded by
// GenericMock.Invoke. When() reads it to find out which call is being stubbed.
type invocation struct {
	genericMock *GenericMock   // mock the call was made on
	MethodName  string         // name of the called method
	Params      []Param        // arguments the method was called with
	ReturnTypes []reflect.Type // return types, checked by ThenReturn
}
// GenericMock stores the per-method state (invocations and stubbings) of one
// mock. The embedded mutex is held by getOrCreateMockedMethod and Verify while
// they access mockedMethods.
type GenericMock struct {
	sync.Mutex
	mockedMethods map[string]*mockedMethod // keyed by method name
}
// Invoke records a call to methodName with params as the package-wide last
// invocation and then forwards it to the mocked method, returning whatever the
// matching stubbing (if any) produces.
func (genericMock *GenericMock) Invoke(methodName string, params []Param, returnTypes []reflect.Type) ReturnValues {
	recorded := &invocation{
		genericMock: genericMock,
		MethodName:  methodName,
		Params:      params,
		ReturnTypes: returnTypes,
	}

	lastInvocationMutex.Lock()
	lastInvocation = recorded
	lastInvocationMutex.Unlock()

	method := genericMock.getOrCreateMockedMethod(methodName)
	return method.Invoke(params)
}
// stub registers a stubbing for methodName that always answers with
// returnValues when the call's arguments satisfy paramMatchers.
func (genericMock *GenericMock) stub(methodName string, paramMatchers []Matcher, returnValues ReturnValues) {
	constantAnswer := func([]Param) ReturnValues {
		return returnValues
	}
	genericMock.stubWithCallback(methodName, paramMatchers, constantAnswer)
}
// stubWithCallback registers callback as the answer for calls to methodName
// whose arguments satisfy paramMatchers, creating the method entry if needed.
func (genericMock *GenericMock) stubWithCallback(methodName string, paramMatchers []Matcher, callback func([]Param) ReturnValues) {
	method := genericMock.getOrCreateMockedMethod(methodName)
	method.stub(paramMatchers, callback)
}
// getOrCreateMockedMethod returns the entry stored for methodName, adding a
// fresh one first when the name is not in the map yet. The mock's mutex is
// held for the whole lookup.
func (genericMock *GenericMock) getOrCreateMockedMethod(methodName string) *mockedMethod {
	genericMock.Lock()
	defer genericMock.Unlock()

	method, known := genericMock.mockedMethods[methodName]
	if !known {
		method = &mockedMethod{name: methodName}
		genericMock.mockedMethods[methodName] = method
	}
	return method
}
// reset drops the stubbing registered for methodName under paramMatchers,
// creating the method entry if it does not exist yet.
func (genericMock *GenericMock) reset(methodName string, paramMatchers []Matcher) {
	method := genericMock.getOrCreateMockedMethod(methodName)
	method.reset(paramMatchers)
}
// Verify checks how often methodName was invoked on this mock with the given
// params (or with the globally registered argument matchers, if there are any)
// and reports a failure through GlobalFailHandler when invocationCountMatcher
// rejects that count.
//
// inOrderContext may be nil. When it is set, every matching invocation must
// have an ordering number greater than the one stored in the context,
// otherwise a failure is reported; the context is then advanced to the
// matching invocations.
//
// options may carry exactly one time.Duration used as a timeout: while the
// count does not match, verification is retried every 10ms until the timeout
// has elapsed. With any other number of options no timeout is applied.
//
// Verify panics when no GlobalFailHandler is registered or when the single
// option is not a time.Duration. It returns the matching invocations.
func (genericMock *GenericMock) Verify(
	inOrderContext *InOrderContext,
	invocationCountMatcher Matcher,
	methodName string,
	params []Param,
	options ...interface{},
) []MethodInvocation {
	// Registered first so the global argument matchers are cleared no matter
	// how verification ends, including the panics raised by the checks below.
	defer func() { globalArgMatchers = nil }()

	var timeout time.Duration
	if len(options) == 1 {
		duration, isDuration := options[0].(time.Duration)
		if !isDuration {
			panic(fmt.Sprintf("Verify expects its option to be a time.Duration timeout, but got %T", options[0]))
		}
		timeout = duration
	}
	if GlobalFailHandler == nil {
		panic("No GlobalFailHandler set. Please use either RegisterMockFailHandler or RegisterMockTestingT to set a fail handler.")
	}
	if len(globalArgMatchers) != 0 {
		verifyArgMatcherUse(globalArgMatchers, params)
	}
	startTime := time.Now()
	// timeoutLoop:
	for {
		genericMock.Lock()
		methodInvocations := genericMock.methodInvocations(methodName, params, globalArgMatchers)
		genericMock.Unlock()
		if inOrderContext != nil {
			for _, methodInvocation := range methodInvocations {
				if methodInvocation.orderingInvocationNumber <= inOrderContext.invocationCounter {
					// TODO: should introduce the following, in case we decide support "inorder" and "eventually"
					// if time.Since(startTime) < timeout {
					// 	continue timeoutLoop
					// }
					GlobalFailHandler(fmt.Sprintf("Expected function call %v(%v) before function call %v(%v)",
						methodName, formatParams(params), inOrderContext.lastInvokedMethodName, formatParams(inOrderContext.lastInvokedMethodParams)))
				}
				inOrderContext.invocationCounter = methodInvocation.orderingInvocationNumber
				inOrderContext.lastInvokedMethodName = methodName
				inOrderContext.lastInvokedMethodParams = params
			}
		}
		if !invocationCountMatcher.Matches(len(methodInvocations)) {
			// Keep polling until the timeout has elapsed before reporting.
			if time.Since(startTime) < timeout {
				time.Sleep(10 * time.Millisecond)
				continue
			}
			var paramsOrMatchers interface{} = formatParams(params)
			if len(globalArgMatchers) != 0 {
				paramsOrMatchers = formatMatchers(globalArgMatchers)
			}
			timeoutInfo := ""
			if timeout > 0 {
				timeoutInfo = fmt.Sprintf(" after timeout of %v", timeout)
			}
			GlobalFailHandler(fmt.Sprintf(
				"Mock invocation count for %v(%v) does not match expectation%v.\n\n\t%v\n\n\t%v",
				methodName, paramsOrMatchers, timeoutInfo, invocationCountMatcher.FailureMessage(), formatInteractions(genericMock.allInteractions())))
		}
		return methodInvocations
	}
}
// TODO this doesn't need to be a method, can be a free function

// GetInvocationParams transposes the parameters of methodInvocations: element
// [u][i] of the result is parameter u of invocation i. It returns nil when
// there are no invocations.
//
// The outer slice is sized by the largest parameter count found in any
// invocation, so invocations with differing parameter counts are safe; slots
// of invocations that lack parameter u stay nil.
func (genericMock *GenericMock) GetInvocationParams(methodInvocations []MethodInvocation) [][]Param {
	if len(methodInvocations) == 0 {
		return nil
	}
	paramCount := 0
	for _, call := range methodInvocations {
		if len(call.params) > paramCount {
			paramCount = len(call.params)
		}
	}
	result := make([][]Param, paramCount)
	for u := range result {
		result[u] = make([]Param, len(methodInvocations))
	}
	for i, call := range methodInvocations {
		for u, param := range call.params {
			result[u][i] = param
		}
	}
	return result
}
// methodInvocations returns the recorded invocations of methodName that match.
// With a non-empty matchers list an invocation matches when the matchers
// accept its params; otherwise it matches when its params are deeply equal to
// params, or when both param lists are empty. The result is nil for a method
// that has no entry. The method's own mutex is held while its invocations are
// scanned.
func (genericMock *GenericMock) methodInvocations(methodName string, params []Param, matchers []Matcher) []MethodInvocation {
	method, known := genericMock.mockedMethods[methodName]
	if !known {
		return nil
	}

	useMatchers := len(matchers) != 0
	var matching []MethodInvocation

	method.Lock()
	for _, candidate := range method.invocations {
		var accepted bool
		if useMatchers {
			accepted = Matchers(matchers).Matches(candidate.params)
		} else {
			bothEmpty := len(params) == 0 && len(candidate.params) == 0
			accepted = bothEmpty || reflect.DeepEqual(params, candidate.params)
		}
		if accepted {
			matching = append(matching, candidate)
		}
	}
	method.Unlock()

	return matching
}
// formatInteractions renders interactions as the trailing part of a failure
// message: a fixed sentence when there are none, otherwise a header followed
// by every invocation, grouped by method name in sorted order.
func formatInteractions(interactions map[string][]MethodInvocation) string {
	switch len(interactions) {
	case 0:
		return "There were no other interactions with this mock"
	}

	report := "But other interactions with this mock were:\n"
	for _, name := range sortedMethodNames(interactions) {
		invocations := interactions[name]
		report += formatInvocations(name, invocations)
	}
	return report
}
// formatInvocations renders one tab-indented line of the form
// "methodName(params)" per invocation.
func formatInvocations(methodName string, invocations []MethodInvocation) (result string) {
	for i := range invocations {
		line := "\t" + methodName + "(" + formatParams(invocations[i].params) + ")\n"
		result += line
	}
	return result
}
// formatParams renders params in Go syntax (%#v), separated by ", ".
func formatParams(params []Param) (result string) {
	separator := ""
	for _, param := range params {
		result += separator + fmt.Sprintf("%#v", param)
		separator = ", "
	}
	return result
}
// formatMatchers renders matchers with their default formatting, separated
// by ", ".
func formatMatchers(matchers []Matcher) (result string) {
	separator := ""
	for _, matcher := range matchers {
		result += separator + fmt.Sprint(matcher)
		separator = ", "
	}
	return result
}
// sortedMethodNames returns the keys of interactions in ascending order.
func sortedMethodNames(interactions map[string][]MethodInvocation) []string {
	names := make([]string, 0, len(interactions))
	for name := range interactions {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
// allInteractions returns a copy of all recorded invocations, keyed by method
// name. Methods without invocations get no entry.
//
// The mock's mutex is held while the method map is walked and each method's
// mutex is held while its invocations are copied, in the same order
// (mock, then method) that Verify and methodInvocations use. Callers must
// therefore not hold either lock.
func (genericMock *GenericMock) allInteractions() map[string][]MethodInvocation {
	genericMock.Lock()
	defer genericMock.Unlock()

	interactions := make(map[string][]MethodInvocation)
	for methodName, method := range genericMock.mockedMethods {
		method.Lock()
		// Guard against creating empty entries: callers use the map's length
		// to decide whether there were any interactions at all.
		if len(method.invocations) > 0 {
			interactions[methodName] = append(interactions[methodName], method.invocations...)
		}
		method.Unlock()
	}
	return interactions
}
// mockedMethod holds everything recorded for one method of a mock. The
// embedded mutex is held by Invoke and methodInvocations while they access
// invocations.
type mockedMethod struct {
	sync.Mutex
	name        string             // method name, as passed to getOrCreateMockedMethod
	invocations []MethodInvocation // recorded calls, in append order
	stubbings   Stubbings          // configured answers; later entries win in find
}
// Invoke records a call with params and answers it with the matching
// stubbing. Without a matching stubbing it returns empty ReturnValues.
func (method *mockedMethod) Invoke(params []Param) ReturnValues {
	// The ordering number is drawn while the lock is held so that numbers
	// within invocations stay in append order.
	method.Lock()
	recorded := MethodInvocation{params, globalInvocationCounter.nextNumber()}
	method.invocations = append(method.invocations, recorded)
	method.Unlock()

	if stubbing := method.stubbings.find(params); stubbing != nil {
		return stubbing.Invoke(params)
	}
	return ReturnValues{}
}
// stub appends callback to the callback sequence of the stubbing registered
// under paramMatchers, creating that stubbing when none exists yet.
func (method *mockedMethod) stub(paramMatchers Matchers, callback func([]Param) ReturnValues) {
	if existing := method.stubbings.findByMatchers(paramMatchers); existing != nil {
		existing.callbackSequence = append(existing.callbackSequence, callback)
		return
	}
	method.stubbings = append(method.stubbings, &Stubbing{
		paramMatchers:    paramMatchers,
		callbackSequence: []func([]Param) ReturnValues{callback},
	})
}
// removeLastInvocation drops the most recently recorded invocation. It
// panics when nothing has been recorded.
func (method *mockedMethod) removeLastInvocation() {
	lastIndex := len(method.invocations) - 1
	method.invocations = method.invocations[:lastIndex]
}
// reset removes the stubbing registered under paramMatchers from this method.
func (method *mockedMethod) reset(paramMatchers Matchers) {
	method.stubbings.removeByMatchers(paramMatchers)
}
// Counter hands out consecutive integers; the embedded mutex serializes
// access to count.
type Counter struct {
	count int
	sync.Mutex
}

// nextNumber returns the counter's current value and increments it by one.
func (counter *Counter) nextNumber() int {
	counter.Lock()
	defer counter.Unlock()

	current := counter.count
	counter.count = current + 1
	return current
}

// globalInvocationCounter numbers every recorded invocation, starting at 1.
var globalInvocationCounter = Counter{count: 1}
// MethodInvocation is one recorded call of a mocked method.
type MethodInvocation struct {
	params []Param // arguments of the call
	// orderingInvocationNumber is drawn from globalInvocationCounter when the
	// call is recorded; Verify compares it against an InOrderContext.
	orderingInvocationNumber int
}
// Stubbings is the list of stubbings configured for one mocked method.
type Stubbings []*Stubbing
// find returns the stubbing whose matchers accept params, searching from the
// most recently added stubbing backwards, or nil when none accepts them.
func (stubbings Stubbings) find(params []Param) *Stubbing {
	newest := len(stubbings) - 1
	for offset := range stubbings {
		candidate := stubbings[newest-offset]
		if candidate.paramMatchers.Matches(params) {
			return candidate
		}
	}
	return nil
}
// findByMatchers returns the first stubbing whose matchers equal
// paramMatchers (per matchersEqual), or nil when there is none.
func (stubbings Stubbings) findByMatchers(paramMatchers Matchers) *Stubbing {
	for i := 0; i < len(stubbings); i++ {
		if candidate := stubbings[i]; matchersEqual(candidate.paramMatchers, paramMatchers) {
			return candidate
		}
	}
	return nil
}
// removeByMatchers removes every stubbing whose matchers equal paramMatchers
// (per matchersEqual), keeping the remaining stubbings in their order.
//
// The list is filtered in place instead of deleting elements while ranging
// over it, which would skip entries and could slice out of bounds when more
// than one stubbing matches.
func (stubbings *Stubbings) removeByMatchers(paramMatchers Matchers) {
	kept := (*stubbings)[:0]
	for _, stubbing := range *stubbings {
		if !matchersEqual(stubbing.paramMatchers, paramMatchers) {
			kept = append(kept, stubbing)
		}
	}
	// Clear the unused tail so removed stubbings are not kept reachable
	// through the shared backing array.
	for i := len(kept); i < len(*stubbings); i++ {
		(*stubbings)[i] = nil
	}
	*stubbings = kept
}
// matchersEqual reports whether a and b have the same length and are
// pairwise deeply equal.
func matchersEqual(a, b Matchers) bool {
	if len(a) != len(b) {
		return false
	}
	for i, left := range a {
		if right := b[i]; !reflect.DeepEqual(left, right) {
			return false
		}
	}
	return true
}
// Stubbing couples a list of parameter matchers with the callbacks that
// answer matching calls.
type Stubbing struct {
	paramMatchers    Matchers                     // decide which calls this stubbing answers
	callbackSequence []func([]Param) ReturnValues // answers, used one after another
	sequencePointer  int                          // index of the callback used by the next Invoke
}
// Invoke answers a call with the current callback of the sequence. After the
// callback has run (even if it panicked) the sequence advances by one, except
// that the last callback keeps being reused.
func (stubbing *Stubbing) Invoke(params []Param) ReturnValues {
	current := stubbing.callbackSequence[stubbing.sequencePointer]
	defer func() {
		if last := len(stubbing.callbackSequence) - 1; stubbing.sequencePointer < last {
			stubbing.sequencePointer++
		}
	}()
	return current(params)
}
// Matchers is a list of matchers, applied position by position to the
// parameters of a call.
type Matchers []Matcher
// Matches reports whether there are exactly as many matchers as params and
// every matcher accepts the parameter at its position.
func (matchers Matchers) Matches(params []Param) bool {
	// Technically, a length mismatch is not an error. Variadic arguments can cause this
	if len(params) != len(matchers) {
		return false
	}
	for i, matcher := range matchers {
		if !matcher.Matches(params[i]) {
			return false
		}
	}
	return true
}
// append adds matcher to the end of the list in place.
func (matchers *Matchers) append(matcher Matcher) {
	*matchers = append(*matchers, matcher)
}
// ongoingStubbing is returned by When and carries what ThenReturn, ThenPanic
// and Then need to register an answer for the call being stubbed.
type ongoingStubbing struct {
	genericMock   *GenericMock   // mock the stubbed call was made on
	MethodName    string         // name of the stubbed method
	ParamMatchers []Matcher      // matchers selecting the calls to answer
	returnTypes   []reflect.Type // return types ThenReturn checks values against
}
// When starts stubbing the mock method call that was made last. The call is
// either made while evaluating When's argument, or, when the single argument
// is a func with no parameters and no results, by When calling that func.
//
// When removes the recorded invocation again (it was only made to describe the
// stubbing), derives the parameter matchers from the registered argument
// matchers or, without any, from the call's arguments, and discards a
// stubbing previously registered under the same matchers.
//
// The last invocation and the global argument matchers are always cleared
// when When returns, including when it panics because no mock call was
// recorded or the argument func has the wrong signature.
func When(invocation ...interface{}) *ongoingStubbing {
	// Registered before anything that can panic so that a misuse of When does
	// not leave stale matchers or a stale invocation behind for the next call.
	defer func() {
		lastInvocationMutex.Lock()
		lastInvocation = nil
		lastInvocationMutex.Unlock()
		globalArgMatchers = nil
	}()
	callIfIsFunc(invocation)

	// Read under the same mutex that guards the writes in Invoke.
	lastInvocationMutex.Lock()
	recorded := lastInvocation
	lastInvocationMutex.Unlock()
	verify.Argument(recorded != nil,
		"When() requires an argument which has to be 'a method call on a mock'.")

	recorded.genericMock.mockedMethods[recorded.MethodName].removeLastInvocation()
	paramMatchers := paramMatchersFromArgMatchersOrParams(globalArgMatchers, recorded.Params)
	recorded.genericMock.reset(recorded.MethodName, paramMatchers)
	return &ongoingStubbing{
		genericMock:   recorded.genericMock,
		MethodName:    recorded.MethodName,
		ParamMatchers: paramMatchers,
		returnTypes:   recorded.ReturnTypes,
	}
}
// callIfIsFunc calls the single element of invocation when it is a non-nil
// func. Anything else, or any other number of elements, is ignored. It panics
// when the func takes parameters or returns results.
func callIfIsFunc(invocation []interface{}) {
	if len(invocation) != 1 {
		return
	}
	candidate := invocation[0]
	candidateType := actualTypeOf(candidate)
	if candidateType == nil || candidateType.Kind() != reflect.Func {
		return
	}
	fn := reflect.ValueOf(candidate)
	if fn.IsNil() {
		return
	}
	if candidateType.NumIn() != 0 || candidateType.NumOut() != 0 {
		panic("When using 'When' with function that does not return a value, " +
			"it expects a function with no arguments and no return value.")
	}
	fn.Call([]reflect.Value{})
}
// actualTypeOf returns the dynamic type of iface. Any panic raised while
// determining it is swallowed, in which case the result is nil.
func actualTypeOf(iface interface{}) (actualType reflect.Type) {
	defer func() {
		_ = recover()
	}()
	actualType = reflect.TypeOf(iface)
	return actualType
}
// paramMatchersFromArgMatchersOrParams returns the matchers to use for a
// call: without argMatchers, equality matchers built from params; otherwise
// argMatchers themselves, after checking that their number fits params.
func paramMatchersFromArgMatchersOrParams(argMatchers []Matcher, params []Param) []Matcher {
	if len(argMatchers) == 0 {
		return transformParamsIntoEqMatchers(params)
	}
	verifyArgMatcherUse(argMatchers, params)
	return argMatchers
}
// verifyArgMatcherUse checks, via verify.Argument, that there is exactly one
// matcher per parameter. The failure message explains the usual cause:
// mixing matchers with raw values.
func verifyArgMatcherUse(argMatchers []Matcher, params []Param) {
	const message = "Invalid use of matchers!\n\n %v matchers expected, %v recorded.\n\n" +
		"This error may occur if matchers are combined with raw values:\n" +
		" //incorrect:\n" +
		" someFunc(AnyInt(), \"raw String\")\n" +
		"When using matchers, all arguments have to be provided by matchers.\n" +
		"For example:\n" +
		" //correct:\n" +
		" someFunc(AnyInt(), EqString(\"String by matcher\"))"

	expected, recorded := len(params), len(argMatchers)
	verify.Argument(recorded == expected, message, expected, recorded)
}
// transformParamsIntoEqMatchers wraps every param in an EqMatcher, keeping
// the order of params.
func transformParamsIntoEqMatchers(params []Param) []Matcher {
	eqMatchers := make([]Matcher, 0, len(params))
	for _, param := range params {
		eqMatchers = append(eqMatchers, &EqMatcher{Value: param})
	}
	return eqMatchers
}
var (
	// genericMocksMutex is held by GetGenericMockFrom while it accesses
	// genericMocks.
	genericMocksMutex sync.Mutex
	// genericMocks maps every mock to its GenericMock; entries are created on
	// first lookup by GetGenericMockFrom.
	genericMocks = make(map[Mock]*GenericMock)
)
// GetGenericMockFrom returns the GenericMock registered for mock, creating
// and registering an empty one on first use.
func GetGenericMockFrom(mock Mock) *GenericMock {
	genericMocksMutex.Lock()
	defer genericMocksMutex.Unlock()

	if existing := genericMocks[mock]; existing != nil {
		return existing
	}
	created := &GenericMock{mockedMethods: make(map[string]*mockedMethod)}
	genericMocks[mock] = created
	return created
}
// ThenReturn makes the stubbed call answer with values, after checking that
// they fit the method's return types. It returns the stubbing so that further
// answers can be chained.
func (stubbing *ongoingStubbing) ThenReturn(values ...ReturnValue) *ongoingStubbing {
	checkAssignabilityOf(values, stubbing.returnTypes)

	mock := stubbing.genericMock
	mock.stub(stubbing.MethodName, stubbing.ParamMatchers, values)
	return stubbing
}
// checkAssignabilityOf verifies that stubbedReturnValues can serve as return
// values of the expected types: the counts must agree, every non-nil value
// must be assignable to its type, and nil is rejected with a panic for
// boolean, numeric, array, string and struct kinds.
func checkAssignabilityOf(stubbedReturnValues []ReturnValue, expectedReturnTypes []reflect.Type) {
	verify.Argument(len(stubbedReturnValues) == len(expectedReturnTypes),
		"Different number of return values")

	for i, value := range stubbedReturnValues {
		expected := expectedReturnTypes[i]
		if value != nil {
			verify.Argument(reflect.TypeOf(value).AssignableTo(expected),
				"Return value of type %T not assignable to return type %v", value, expected)
			continue
		}
		// nil is only rejected for kinds that cannot hold it.
		switch expected.Kind() {
		case reflect.Bool,
			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
			reflect.Float32, reflect.Float64,
			reflect.Complex64, reflect.Complex128,
			reflect.Array, reflect.String, reflect.Struct:
			panic("Return value 'nil' not assignable to return type " + expected.Kind().String())
		}
	}
}
// ThenPanic makes the stubbed call panic with v. It returns the stubbing so
// that further answers can be chained.
func (stubbing *ongoingStubbing) ThenPanic(v interface{}) *ongoingStubbing {
	panicWithValue := func([]Param) ReturnValues {
		panic(v)
	}
	stubbing.genericMock.stubWithCallback(stubbing.MethodName, stubbing.ParamMatchers, panicWithValue)
	return stubbing
}
// Then makes the stubbed call answer with whatever callback returns for the
// call's parameters. It returns the stubbing so that further answers can be
// chained.
func (stubbing *ongoingStubbing) Then(callback func([]Param) ReturnValues) *ongoingStubbing {
	mock := stubbing.genericMock
	mock.stubWithCallback(stubbing.MethodName, stubbing.ParamMatchers, callback)
	return stubbing
}
// InOrderContext carries the state of an in-order verification across
// several Verify calls. Verify updates all three fields for every matching
// invocation.
type InOrderContext struct {
	invocationCounter       int     // ordering number of the last verified invocation
	lastInvokedMethodName   string  // method name of the last verified call, used in failure messages
	lastInvokedMethodParams []Param // params of the last verified call, used in failure messages
}
// Matcher decides whether a value is acceptable and can describe itself.
//
// It is guaranteed that FailureMessage will always be called after Matches,
// so an implementation can save state.
type Matcher interface {
	// Matches reports whether param is accepted.
	Matches(param Param) bool
	// FailureMessage returns the text embedded in verification failure
	// messages.
	FailureMessage() string
	// String is used when matchers are formatted into failure messages.
	fmt.Stringer
}
// DumpInvocationsFor prints the text produced by SDumpInvocationsFor for mock
// to standard output.
func DumpInvocationsFor(mock Mock) {
	dump := SDumpInvocationsFor(mock)
	fmt.Print(dump)
}
// SDumpInvocationsFor returns a human-readable listing of every invocation
// recorded for mock: one "Method invocation: name (" header per call,
// followed by one formatted parameter per line and a closing ")".
//
// Methods are listed in ascending order of their names so that the output is
// stable between runs; invocations of one method keep their recorded order.
func SDumpInvocationsFor(mock Mock) string {
	mockedMethods := GetGenericMockFrom(mock).mockedMethods

	// Map iteration order is random, so walk the method names sorted.
	methodNames := make([]string, 0, len(mockedMethods))
	for methodName := range mockedMethods {
		methodNames = append(methodNames, methodName)
	}
	sort.Strings(methodNames)

	result := &bytes.Buffer{}
	for _, methodName := range methodNames {
		method := mockedMethods[methodName]
		for _, invocation := range method.invocations {
			fmt.Fprintf(result, "Method invocation: %v (\n", method.name)
			for _, param := range invocation.params {
				fmt.Fprint(result, format.Object(param, 1), ",\n")
			}
			fmt.Fprintln(result, ")")
		}
	}
	return result.String()
}
// InterceptMockFailures runs a given callback and returns an array of
// failure messages generated by any Pegomock verifications within the callback.
//
// This is accomplished by temporarily replacing the *global* fail handler
// with a fail handler that simply annotates failures. The original fail handler
// is reset when InterceptMockFailures returns, and also when the callback
// panics, so a panicking callback cannot leave the intercepting handler
// installed.
//
// The result is never nil; it is empty when no failure was reported.
func InterceptMockFailures(f func()) []string {
	originalHandler := GlobalFailHandler
	// Deferred so the original handler is restored on panic as well.
	defer RegisterMockFailHandler(originalHandler)

	failures := []string{}
	RegisterMockFailHandler(func(message string, callerSkip ...int) {
		failures = append(failures, message)
	})
	f()
	return failures
}