@ -9,6 +9,7 @@ package singleflight
import (
import (
"bytes"
"bytes"
"context"
"errors"
"errors"
"fmt"
"fmt"
"os"
"os"
@ -321,3 +322,155 @@ func TestPanicDoSharedByDoChan(t *testing.T) {
t . Errorf ( "Test subprocess failed, but the crash isn't caused by panicking in Do" )
t . Errorf ( "Test subprocess failed, but the crash isn't caused by panicking in Do" )
}
}
}
}
func TestDoChanContext ( t * testing . T ) {
t . Run ( "Basic" , func ( t * testing . T ) {
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
var g Group [ string , int ]
ch := g . DoChanContext ( ctx , "key" , func ( _ context . Context ) ( int , error ) {
return 1 , nil
} )
ret := <- ch
assertOKResult ( t , ret , 1 )
} )
t . Run ( "DoesNotPropagateValues" , func ( t * testing . T ) {
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
key := new ( int )
const value = "hello world"
ctx = context . WithValue ( ctx , key , value )
var g Group [ string , int ]
ch := g . DoChanContext ( ctx , "foobar" , func ( ctx context . Context ) ( int , error ) {
if _ , ok := ctx . Value ( key ) . ( string ) ; ok {
t . Error ( "expected no value, but was present in context" )
}
return 1 , nil
} )
ret := <- ch
assertOKResult ( t , ret , 1 )
} )
t . Run ( "NoCancelWhenWaiters" , func ( t * testing . T ) {
testCtx , testCancel := context . WithTimeout ( context . Background ( ) , 10 * time . Second )
defer testCancel ( )
trigger := make ( chan struct { } )
ctx1 , cancel1 := context . WithCancel ( context . Background ( ) )
defer cancel1 ( )
ctx2 , cancel2 := context . WithCancel ( context . Background ( ) )
defer cancel2 ( )
fn := func ( ctx context . Context ) ( int , error ) {
select {
case <- ctx . Done ( ) :
return 0 , ctx . Err ( )
case <- trigger :
return 1234 , nil
}
}
// Create two waiters, then cancel the first before we trigger
// the function to return a value. This shouldn't result in a
// context canceled error.
var g Group [ string , int ]
ch1 := g . DoChanContext ( ctx1 , "key" , fn )
ch2 := g . DoChanContext ( ctx2 , "key" , fn )
cancel1 ( )
// The first channel, now that it's canceled, should return a
// context canceled error.
select {
case res := <- ch1 :
if ! errors . Is ( res . Err , context . Canceled ) {
t . Errorf ( "unexpected error; got %v, want context.Canceled" , res . Err )
}
case <- testCtx . Done ( ) :
t . Fatal ( "test timed out" )
}
// Actually return
close ( trigger )
res := <- ch2
assertOKResult ( t , res , 1234 )
} )
t . Run ( "AllCancel" , func ( t * testing . T ) {
for _ , n := range [ ] int { 1 , 2 , 10 , 20 } {
t . Run ( fmt . Sprintf ( "NumWaiters=%d" , n ) , func ( t * testing . T ) {
testCtx , testCancel := context . WithTimeout ( context . Background ( ) , 10 * time . Second )
defer testCancel ( )
trigger := make ( chan struct { } )
defer close ( trigger )
fn := func ( ctx context . Context ) ( int , error ) {
select {
case <- ctx . Done ( ) :
return 0 , ctx . Err ( )
case <- trigger :
t . Error ( "unexpected trigger; want all callers to cancel" )
return 0 , errors . New ( "unexpected trigger" )
}
}
// Launch N goroutines that all wait on the same key.
var (
g Group [ string , int ]
chs [ ] <- chan Result [ int ]
cancels [ ] context . CancelFunc
)
for i := range n {
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
cancels = append ( cancels , cancel )
ch := g . DoChanContext ( ctx , "key" , fn )
chs = append ( chs , ch )
// Every third goroutine should cancel
// immediately, which better tests the
// cancel logic.
if i % 3 == 0 {
cancel ( )
}
}
// Now that everything is waiting, cancel all the contexts.
for _ , cancel := range cancels {
cancel ( )
}
// Wait for a result from each channel. They
// should all return an error showing a context
// cancel.
for _ , ch := range chs {
select {
case res := <- ch :
if ! errors . Is ( res . Err , context . Canceled ) {
t . Errorf ( "unexpected error; got %v, want context.Canceled" , res . Err )
}
case <- testCtx . Done ( ) :
t . Fatal ( "test timed out" )
}
}
} )
}
} )
}
func assertOKResult [ V comparable ] ( t testing . TB , res Result [ V ] , want V ) {
if res . Err != nil {
t . Fatalf ( "unexpected error: %v" , res . Err )
}
if res . Val != want {
t . Fatalf ( "unexpected value; got %v, want %v" , res . Val , want )
}
}