Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func TestTrap(t *testing.T) {
count++
})
call := trap.MustWait(ctx)
call.Release()
call.MustRelease(ctx)
if call.Duration != time.Hour {
t.Fatal("wrong duration")
}
Expand Down Expand Up @@ -268,15 +268,15 @@ func TestTrap2(t *testing.T) {
}(mClock)

// start
trap.MustWait(ctx).Release()
trap.MustWait(ctx).MustRelease(ctx)
// phase 1
call := trap.MustWait(ctx)
mClock.Advance(3*time.Second).MustWait(ctx)
call.Release()
call.MustRelease(ctx)
// phase 2
call = trap.MustWait(ctx)
mClock.Advance(5*time.Second).MustWait(ctx)
call.Release()
call.MustRelease(ctx)

<-done
// Now logs contains []string{"Phase 1 took 3s", "Phase 2 took 5s"}
Expand All @@ -302,7 +302,7 @@ go func(){
}()
call := trap.MustWait(ctx)
mClock.Advance(time.Second).MustWait(ctx)
call.Release()
call.MustRelease(ctx)
// call.Tags contains []string{"foo", "bar"}

gotFoo := <-foo // 1s after start
Expand Down Expand Up @@ -478,8 +478,8 @@ func TestTicker(t *testing.T) {
trap := mClock.Trap().TickerFunc()
defer trap.Close() // stop trapping at end
go runMyTicker(mClock) // async calls TickerFunc()
call := trap.Wait(context.Background()) // waits for a call and blocks its return
call.Release() // allow the TickerFunc() call to return
call := trap.MustWait(context.Background()) // waits for a call and blocks its return
call.MustRelease(ctx) // allow the TickerFunc() call to return
// optionally check the duration using call.Duration
// Move the clock forward 1 tick
mClock.Advance(time.Second).MustWait(context.Background())
Expand Down Expand Up @@ -527,9 +527,9 @@ go func(clock quartz.Clock) {
measurement = clock.Since(start)
}(mClock)

c := trap.Wait(ctx)
c := trap.MustWait(ctx)
mClock.Advance(5*time.Second)
c.Release()
c.MustRelease(ctx)
```

We wait until we trap the `clock.Since()` call, which implies that `clock.Now()` has completed, then
Expand Down Expand Up @@ -617,10 +617,10 @@ func TestInactivityTimer_Late(t *testing.T) {

// Trigger the AfterFunc
w := mClock.Advance(10*time.Minute)
c := trap.Wait(ctx)
c := trap.MustWait(ctx)
// Advance the clock a few ms to simulate a busy system
mClock.Advance(3*time.Millisecond)
c.Release() // Until() returns
c.MustRelease(ctx) // Until() returns
w.MustWait(ctx) // Wait for the AfterFunc to wrap up

// Assert that the timeoutLocked() function was called
Expand Down
6 changes: 3 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestExampleTickerFunc(t *testing.T) {
// it's good practice to release calls before any possible t.Fatal() calls
// so that we don't leave dangling goroutines waiting for the call to be
// released.
call.Release()
call.MustRelease(ctx)
if call.Duration != time.Hour {
t.Fatal("unexpected duration")
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestExampleLatencyMeasurer(t *testing.T) {
w := mClock.Advance(10 * time.Second) // triggers first tick
c := trap.MustWait(ctx) // call to Since()
mClock.Advance(33 * time.Millisecond)
c.Release()
c.MustRelease(ctx)
w.MustWait(ctx)

if l := lm.LastLatency(); l != 33*time.Millisecond {
Expand All @@ -133,7 +133,7 @@ func TestExampleLatencyMeasurer(t *testing.T) {
d, w2 := mClock.AdvanceNext()
c = trap.MustWait(ctx)
mClock.Advance(17 * time.Millisecond)
c.Release()
c.MustRelease(ctx)
w2.MustWait(ctx)

expectedD := 10*time.Second - 33*time.Millisecond
Expand Down
33 changes: 29 additions & 4 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ func (c clockFunction) String() string {
case clockFunctionUntil:
return "Until"
default:
return "?????"
return fmt.Sprintf("Unknown clockFunction(%d)", c)
}
}

Expand Down Expand Up @@ -633,7 +633,7 @@ func (a *apiCall) String() string {
case clockFunctionUntil:
return fmt.Sprintf("Until(%s, %v)", a.Time, a.Tags)
default:
return "?????"
return fmt.Sprintf("Unknown clockFunction(%d)", a.fn)
}
}

Expand All @@ -643,14 +643,38 @@ type Call struct {
Duration time.Duration
Tags []string

tb testing.TB
apiCall *apiCall
trap *Trap
}

func (c *Call) Release() {
// Release the call and wait for it to complete. If the provided context expires before the call completes, it returns
// an error.
//
// IMPORTANT: If a call is trapped by more than one trap, they all must release the call before it can complete, and
// they must do so from different goroutines.
func (c *Call) Release(ctx context.Context) error {
c.apiCall.releases.Done()
<-c.apiCall.complete
select {
case <-ctx.Done():
return fmt.Errorf("timed out waiting for release; did more than one trap capture the call?: %w", ctx.Err())
case <-c.apiCall.complete:
// OK
}
c.trap.callReleased()
return nil
}

// MustRelease releases the call and waits for it to complete. If the provided context expires before the call
// completes, it fails the test.
//
// IMPORTANT: If a call is trapped by more than one trap, they all must release the call before it can complete, and
// they must do so from different goroutines.
func (c *Call) MustRelease(ctx context.Context) {
if err := c.Release(ctx); err != nil {
c.tb.Helper()
c.tb.Fatal(err.Error())
}
}

func withTime(t time.Time) callArg {
Expand Down Expand Up @@ -745,6 +769,7 @@ func (t *Trap) Wait(ctx context.Context) (*Call, error) {
Tags: a.Tags,
apiCall: a,
trap: t,
tb: t.mock.tb,
}
t.mu.Lock()
defer t.mu.Unlock()
Expand Down
38 changes: 30 additions & 8 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestTimer_NegativeDuration(t *testing.T) {
timers <- mClock.NewTimer(-time.Second)
}()
c := trap.MustWait(ctx)
c.Release()
c.MustRelease(ctx)
// trap returns the actual passed value
if c.Duration != -time.Second {
t.Fatalf("expected -time.Second, got: %v", c.Duration)
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestAfterFunc_NegativeDuration(t *testing.T) {
})
}()
c := trap.MustWait(ctx)
c.Release()
c.MustRelease(ctx)
// trap returns the actual passed value
if c.Duration != -time.Second {
t.Fatalf("expected -time.Second, got: %v", c.Duration)
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestNewTicker(t *testing.T) {
tickers <- mClock.NewTicker(time.Hour, "new")
}()
c := trapNT.MustWait(ctx)
c.Release()
c.MustRelease(ctx)
if c.Duration != time.Hour {
t.Fatalf("expected time.Hour, got: %v", c.Duration)
}
Expand All @@ -123,7 +123,7 @@ func TestNewTicker(t *testing.T) {
go tkr.Reset(time.Minute, "reset")
c = trapReset.MustWait(ctx)
mClock.Advance(time.Second).MustWait(ctx)
c.Release()
c.MustRelease(ctx)
if c.Duration != time.Minute {
t.Fatalf("expected time.Minute, got: %v", c.Duration)
}
Expand All @@ -142,7 +142,7 @@ func TestNewTicker(t *testing.T) {
}

go tkr.Stop("stop")
trapStop.MustWait(ctx).Release()
trapStop.MustWait(ctx).MustRelease(ctx)
mClock.Advance(time.Hour).MustWait(ctx)
select {
case <-tkr.C:
Expand All @@ -153,7 +153,7 @@ func TestNewTicker(t *testing.T) {

// Resetting after stop
go tkr.Reset(time.Minute, "reset")
trapReset.MustWait(ctx).Release()
trapReset.MustWait(ctx).MustRelease(ctx)
mClock.Advance(time.Minute).MustWait(ctx)
tTime = mClock.Now()
select {
Expand Down Expand Up @@ -344,11 +344,11 @@ func Test_MultipleTraps(t *testing.T) {
done := make(chan struct{})
go func() {
defer close(done)
c0.Release()
c0.MustRelease(testCtx)
}()
c1 := trap1.MustWait(testCtx)
mClock.Advance(time.Second)
c1.Release()
c1.MustRelease(testCtx)

select {
case <-done:
Expand All @@ -367,6 +367,28 @@ func Test_MultipleTraps(t *testing.T) {
}
}

func Test_MultipleTrapsDeadlock(t *testing.T) {
t.Parallel()
tRunFail(t, func(t testing.TB) {
testCtx, testCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer testCancel()
mClock := quartz.NewMock(t)

trap0 := mClock.Trap().Now("0")
defer trap0.Close()
trap1 := mClock.Trap().Now("1")
defer trap1.Close()

timeCh := make(chan time.Time)
go func() {
timeCh <- mClock.Now("0", "1")
}()

c0 := trap0.MustWait(testCtx)
c0.MustRelease(testCtx) // deadlocks, test failure
})
}

func Test_UnreleasedCalls(t *testing.T) {
t.Parallel()
tRunFail(t, func(t testing.TB) {
Expand Down