From f22e310501b8894a6f8a0b9eee434ad0cf8c05a7 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Wed, 19 Apr 2023 21:14:29 -0400 Subject: [PATCH] fix: duration checks to take into account values can be negative Replace `IsZero` with `IsAboveZero` to help us avoid this mistake again. --- config/caching.go | 4 ++-- config/config.go | 2 +- config/config_test.go | 4 ++-- config/duration.go | 4 ++-- config/duration_test.go | 23 +++++++++++++++++------ resolver/bootstrap.go | 4 ++-- resolver/caching_resolver.go | 6 +++--- 7 files changed, 29 insertions(+), 18 deletions(-) diff --git a/config/caching.go b/config/caching.go index 0393d34f8..c942398cd 100644 --- a/config/caching.go +++ b/config/caching.go @@ -20,7 +20,7 @@ type CachingConfig struct { // IsEnabled implements `config.Configurable`. func (c *CachingConfig) IsEnabled() bool { - return c.MaxCachingTime > 0 + return c.MaxCachingTime.IsAboveZero() } // LogConfig implements `config.Configurable`. @@ -42,7 +42,7 @@ func (c *CachingConfig) LogConfig(logger *logrus.Entry) { func (c *CachingConfig) EnablePrefetch() { const day = Duration(24 * time.Hour) - if c.MaxCachingTime.IsZero() { + if !c.IsEnabled() { // make sure resolver gets enabled c.MaxCachingTime = day } diff --git a/config/config.go b/config/config.go index 1580352fc..d6d5af725 100644 --- a/config/config.go +++ b/config/config.go @@ -298,7 +298,7 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) { logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource) logger.Debugf("strategy = %s", c.Strategy) - if c.RefreshPeriod > 0 { + if c.RefreshPeriod.IsAboveZero() { logger.Infof("refresh = every %s", c.RefreshPeriod) } else { logger.Debug("refresh = disabled") diff --git a/config/config_test.go b/config/config_test.go index 5e2039c72..15851f963 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -785,8 +785,8 @@ func defaultTestFileConfig() { Expect(config.Filtering.QueryTypes).Should(HaveLen(2)) Expect(config.FqdnOnly.Enable).Should(BeTrue()) - Expect(config.Caching.MaxCachingTime.IsZero()).Should(BeTrue()) - Expect(config.Caching.MinCachingTime.IsZero()).Should(BeTrue()) + Expect(config.Caching.MaxCachingTime).Should(BeZero()) + Expect(config.Caching.MinCachingTime).Should(BeZero()) Expect(config.DoHUserAgent).Should(Equal("testBlocky")) Expect(config.MinTLSServeVer).Should(Equal("1.3")) diff --git a/config/duration.go b/config/duration.go index 5ca316acc..c8ef09a76 100644 --- a/config/duration.go +++ b/config/duration.go @@ -14,8 +14,8 @@ func (c Duration) ToDuration() time.Duration { return time.Duration(c) } -func (c Duration) IsZero() bool { - return c.ToDuration() == 0 +func (c Duration) IsAboveZero() bool { + return c.ToDuration() > 0 } func (c Duration) Seconds() float64 { diff --git a/config/duration_test.go b/config/duration_test.go index 882f038f1..047337ac9 100644 --- a/config/duration_test.go +++ b/config/duration_test.go @@ -31,14 +31,25 @@ var _ = Describe("Duration", func() { }) }) - Describe("IsZero", func() { - It("should be true for zero", func() { - Expect(d.IsZero()).Should(BeTrue()) - Expect(Duration(0).IsZero()).Should(BeTrue()) + Describe("IsAboveZero", func() { + It("should be false for zero", func() { + Expect(d.IsAboveZero()).Should(BeFalse()) + Expect(Duration(0).IsAboveZero()).Should(BeFalse()) }) - It("should be false for non-zero", func() { - Expect(Duration(time.Second).IsZero()).Should(BeFalse()) + It("should be false for negative", func() { + Expect(Duration(-1).IsAboveZero()).Should(BeFalse()) + }) + + It("should be true for positive", func() { + Expect(Duration(1).IsAboveZero()).Should(BeTrue()) + }) + }) + + Describe("SecondsU32", func() { + It("should return the seconds", func() { + Expect(Duration(time.Minute).SecondsU32()).Should(Equal(uint32(60))) + Expect(Duration(time.Hour).SecondsU32()).Should(Equal(uint32(3600))) }) }) }) diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index d8d956920..27e2fa8ca 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -78,7 +78,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) { cachingCfg := cfg.Caching cachingCfg.EnablePrefetch() - if cachingCfg.MinCachingTime.IsZero() { + if !cachingCfg.MinCachingTime.IsAboveZero() { // Set a min time in case the user didn't to avoid prefetching too often cachingCfg.MinCachingTime = config.Duration(time.Hour) } @@ -116,7 +116,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) { ctx := context.Background() timeout := cfg.UpstreamTimeout - if timeout.IsZero() { + if timeout.IsAboveZero() { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration()) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 08091b9da..d0e2a2496 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -217,7 +217,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, // put value into cache r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer)) } else if response.Res.Rcode == dns.RcodeNameError { - if r.cfg.CacheTimeNegative > 0 { + if r.cfg.CacheTimeNegative.IsAboveZero() { // put negative cache if result code is NXDOMAIN r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration()) } @@ -244,13 +244,13 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) { for _, a := range answer { // if TTL < mitTTL -> adjust the value, set minTTL - if r.cfg.MinCachingTime > 0 { + if r.cfg.MinCachingTime.IsAboveZero() { if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() { atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32()) } } - if r.cfg.MaxCachingTime > 0 { + if r.cfg.MaxCachingTime.IsAboveZero() { if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() { atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32()) }