Skip to content

Commit 50cca88

Browse files
authored
Add tests to detect race conditions and fix two potential issues (#214)
1 parent ac55a19 commit 50cca88

4 files changed

Lines changed: 265 additions & 3 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ $(LOCALBIN):
2929
mkdir -p $(LOCALBIN)
3030

3131
test: generate fmt vet manifests setup-envtest
32-
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./... -coverprofile cover.out
32+
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test -race ./... -coverprofile cover.out
3333

3434
clean:
3535
rm -rf bin/*

pkg/dns/dns_proxy_handler.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"net"
66
"strings"
7+
"sync"
78
"time"
89

910
"github.com/go-logr/logr"
@@ -22,6 +23,7 @@ const (
2223
)
2324

2425
type DNSProxyHandler struct {
26+
sync.RWMutex
2527
log logr.Logger
2628
udpClient *dnsgo.Client
2729
tcpClient *dnsgo.Client
@@ -111,7 +113,9 @@ func (h *DNSProxyHandler) UpdateDNSServerAddr(addr string) error {
111113
return fmt.Errorf("new DNS server address not valid: %w", err)
112114
}
113115

116+
h.Lock()
114117
h.dnsServerAddr = addr
118+
h.Unlock()
115119
return nil
116120
}
117121

@@ -128,7 +132,11 @@ func (h *DNSProxyHandler) getDataFromDNS(addr net.Addr, request *dnsgo.Msg) (*dn
128132
return nil, fmt.Errorf("failed to determine transport protocol: %s", protocol)
129133
}
130134

131-
response, _, err := client.Exchange(request, h.dnsServerAddr)
135+
h.RLock()
136+
dnsAddr := h.dnsServerAddr
137+
h.RUnlock()
138+
139+
response, _, err := client.Exchange(request, dnsAddr)
132140
if err != nil {
133141
return nil, fmt.Errorf("failed to call target DNS: %w", err)
134142
}

pkg/dns/dnscache.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ func (c *DNSCache) getSetsForRendering(fqdns []firewallv1.FQDNSelector) (result
297297
}
298298

299299
func (c *DNSCache) updateDNSServerAddr(addr string) {
300+
c.Lock()
300301
c.dnsServerAddr = addr
302+
c.Unlock()
301303
}
302304

303305
// getSetNameForFQDN returns FQDN set data
@@ -339,13 +341,16 @@ func (c *DNSCache) loadDataFromDNSServer(fqdns []string) error {
339341
return fmt.Errorf("too many hops, fqdn chain: %s", strings.Join(fqdns, ","))
340342
}
341343
qname := fqdns[len(fqdns)-1]
344+
c.RLock()
345+
dnsAddr := c.dnsServerAddr
346+
c.RUnlock()
342347
cl := new(dnsgo.Client)
343348
for _, t := range []uint16{dnsgo.TypeA, dnsgo.TypeAAAA} {
344349
m := new(dnsgo.Msg)
345350
m.Id = dnsgo.Id()
346351
m.SetQuestion(qname, t)
347352
c.log.V(4).Info("DEBUG dnscache loadDataFromDNSServer function querying DNS", "message", m)
348-
in, _, err := cl.Exchange(m, c.dnsServerAddr)
353+
in, _, err := cl.Exchange(m, dnsAddr)
349354
if err != nil {
350355
return fmt.Errorf("failed to get DNS data about fqdn %s: %w", fqdns[0], err)
351356
}

pkg/dns/dnscache_test.go

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package dns
22

33
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"sync"
48
"testing"
59
"time"
610

711
"github.com/go-logr/logr"
812
"github.com/google/go-cmp/cmp"
13+
"github.com/google/nftables"
914
firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1"
15+
dnsgo "github.com/miekg/dns"
1016
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
17+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1118
)
1219

1320
func Test_GetSetsForFQDN(t *testing.T) {
@@ -226,3 +233,245 @@ func Test_createIPSetFromIPEntry(t *testing.T) {
226233
})
227234
}
228235
}
236+
237+
const (
238+
raceNumGoroutines = 10
239+
raceNumIterations = 100
240+
)
241+
242+
func newTestDNSCache(entries map[string]cacheEntry) *DNSCache {
243+
return &DNSCache{
244+
log: logr.Discard(),
245+
fqdnToEntry: entries,
246+
setNames: make(map[string]struct{}),
247+
dnsServerAddr: "127.0.0.1:53",
248+
ctx: context.Background(),
249+
shootClient: fake.NewClientBuilder().Build(),
250+
ipv4Enabled: true,
251+
ipv6Enabled: true,
252+
}
253+
}
254+
255+
func makeTestRRs(fqdn string, ip string) []dnsgo.RR {
256+
return []dnsgo.RR{
257+
&dnsgo.A{
258+
Hdr: dnsgo.RR_Header{Name: fqdn, Rrtype: dnsgo.TypeA, Ttl: 300},
259+
A: net.ParseIP(ip),
260+
},
261+
}
262+
}
263+
264+
func seedEntries(n int) map[string]cacheEntry {
265+
entries := make(map[string]cacheEntry, n)
266+
for i := range n {
267+
fqdn := fmt.Sprintf("host%d.example.com.", i)
268+
entries[fqdn] = cacheEntry{
269+
IPv4: &iPEntry{
270+
SetName: fmt.Sprintf("set%d", i),
271+
IPs: map[string]time.Time{fmt.Sprintf("10.0.0.%d", i%256): time.Now().Add(5 * time.Minute)},
272+
},
273+
}
274+
}
275+
return entries
276+
}
277+
278+
func TestRace_UpdateAndGetSetsForRendering(t *testing.T) {
279+
cache := newTestDNSCache(seedEntries(5))
280+
fqdns := []firewallv1.FQDNSelector{{MatchPattern: "*.example.com"}}
281+
282+
var wg sync.WaitGroup
283+
start := make(chan struct{})
284+
285+
for i := range raceNumGoroutines {
286+
wg.Add(1)
287+
go func(id int) {
288+
defer wg.Done()
289+
<-start
290+
fqdn := fmt.Sprintf("writer%d.example.com.", id)
291+
for j := range raceNumIterations {
292+
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.1.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
293+
}
294+
}(i)
295+
}
296+
297+
for range raceNumGoroutines {
298+
wg.Go(func() {
299+
<-start
300+
for range raceNumIterations {
301+
cache.getSetsForRendering(fqdns)
302+
}
303+
})
304+
}
305+
306+
close(start)
307+
wg.Wait()
308+
}
309+
310+
func TestRace_UpdateAndGetSetNameForRegex(t *testing.T) {
311+
cache := newTestDNSCache(seedEntries(5))
312+
313+
var wg sync.WaitGroup
314+
start := make(chan struct{})
315+
316+
for i := range raceNumGoroutines {
317+
wg.Add(1)
318+
go func(id int) {
319+
defer wg.Done()
320+
<-start
321+
fqdn := fmt.Sprintf("writer%d.example.com.", id)
322+
for j := range raceNumIterations {
323+
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.2.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
324+
}
325+
}(i)
326+
}
327+
328+
for range raceNumGoroutines {
329+
wg.Go(func() {
330+
<-start
331+
for range raceNumIterations {
332+
cache.getSetNameForRegex(`.*\.example\.com\.`)
333+
}
334+
})
335+
}
336+
337+
close(start)
338+
wg.Wait()
339+
}
340+
341+
func TestRace_UpdateAndGetSetNameForFQDN(t *testing.T) {
342+
cache := newTestDNSCache(seedEntries(5))
343+
344+
var wg sync.WaitGroup
345+
start := make(chan struct{})
346+
347+
for i := range raceNumGoroutines {
348+
wg.Add(1)
349+
go func(id int) {
350+
defer wg.Done()
351+
<-start
352+
fqdn := fmt.Sprintf("host%d.example.com.", id%5)
353+
for j := range raceNumIterations {
354+
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.3.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
355+
}
356+
}(i)
357+
}
358+
359+
for i := range raceNumGoroutines {
360+
wg.Add(1)
361+
go func(id int) {
362+
defer wg.Done()
363+
<-start
364+
fqdn := fmt.Sprintf("host%d.example.com.", id%5)
365+
for range raceNumIterations {
366+
cache.getSetNameForFQDN(fqdn)
367+
}
368+
}(i)
369+
}
370+
371+
close(start)
372+
wg.Wait()
373+
}
374+
375+
func TestRace_UpdateAndWriteStateToConfigmap(t *testing.T) {
376+
cache := newTestDNSCache(seedEntries(5))
377+
378+
var wg sync.WaitGroup
379+
start := make(chan struct{})
380+
381+
for i := range raceNumGoroutines {
382+
wg.Add(1)
383+
go func(id int) {
384+
defer wg.Done()
385+
<-start
386+
fqdn := fmt.Sprintf("writer%d.example.com.", id)
387+
for j := range raceNumIterations {
388+
_ = cache.updateIPEntry(fqdn, makeTestRRs(fqdn, fmt.Sprintf("10.4.%d.%d", id, j%256)), time.Now(), nftables.TypeIPAddr)
389+
}
390+
}(i)
391+
}
392+
393+
for range raceNumGoroutines {
394+
wg.Go(func() {
395+
<-start
396+
for range raceNumIterations {
397+
_ = cache.writeStateToConfigmap()
398+
}
399+
})
400+
}
401+
402+
close(start)
403+
wg.Wait()
404+
}
405+
406+
func TestRace_UpdateDNSServerAddr(t *testing.T) {
407+
cache := newTestDNSCache(seedEntries(1))
408+
409+
var wg sync.WaitGroup
410+
start := make(chan struct{})
411+
412+
for i := range raceNumGoroutines {
413+
wg.Add(1)
414+
go func(id int) {
415+
defer wg.Done()
416+
<-start
417+
for j := range raceNumIterations {
418+
cache.updateDNSServerAddr(fmt.Sprintf("10.0.%d.%d:53", id, j%256))
419+
}
420+
}(i)
421+
}
422+
423+
for range raceNumGoroutines {
424+
wg.Go(func() {
425+
<-start
426+
for range raceNumIterations {
427+
cache.RLock()
428+
_ = cache.dnsServerAddr
429+
cache.RUnlock()
430+
}
431+
})
432+
}
433+
434+
close(start)
435+
wg.Wait()
436+
}
437+
438+
func TestRace_ConcurrentMultipleReaders(t *testing.T) {
439+
cache := newTestDNSCache(seedEntries(10))
440+
fqdns := []firewallv1.FQDNSelector{{MatchPattern: "*.example.com"}}
441+
442+
var wg sync.WaitGroup
443+
start := make(chan struct{})
444+
445+
for range raceNumGoroutines {
446+
wg.Go(func() {
447+
<-start
448+
for range raceNumIterations {
449+
cache.getSetsForRendering(fqdns)
450+
}
451+
})
452+
}
453+
454+
for range raceNumGoroutines {
455+
wg.Go(func() {
456+
<-start
457+
for range raceNumIterations {
458+
cache.getSetNameForRegex(`.*\.example\.com\.`)
459+
}
460+
})
461+
}
462+
463+
for i := range raceNumGoroutines {
464+
wg.Add(1)
465+
go func(id int) {
466+
defer wg.Done()
467+
<-start
468+
fqdn := fmt.Sprintf("host%d.example.com.", id%10)
469+
for range raceNumIterations {
470+
cache.getSetNameForFQDN(fqdn)
471+
}
472+
}(i)
473+
}
474+
475+
close(start)
476+
wg.Wait()
477+
}

0 commit comments

Comments
 (0)