refactor(util): make LogOnError get the log from a Context
ThinkChaos committed Mar 19, 2024
1 parent b335887 commit 3fcf379
Showing 9 changed files with 42 additions and 30 deletions.
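The change threads a context.Context into util.LogOnError so messages are written through log.FromCtx(ctx) instead of the global logger. This commit only touches LogOnError and its call sites; the FromCtx helper is assumed to already exist in blocky's log package. A minimal sketch of such a context-carried logger, using logrus (which blocky already uses) and a hypothetical WrapCtx helper, could look roughly like this:

```go
// Hypothetical sketch (not part of this commit): a context-carried logger.
// A *logrus.Entry is stored in the context under a private key; FromCtx
// falls back to a default entry when the context carries no logger.
package log

import (
	"context"

	"github.com/sirupsen/logrus"
)

type loggerCtxKey struct{}

// defaultEntry is used when a context has no logger attached.
var defaultEntry = logrus.NewEntry(logrus.StandardLogger())

// WrapCtx returns a copy of ctx carrying the given logger (hypothetical helper).
func WrapCtx(ctx context.Context, logger *logrus.Entry) context.Context {
	return context.WithValue(ctx, loggerCtxKey{}, logger)
}

// FromCtx returns the logger attached to ctx, or the default entry if none is set.
func FromCtx(ctx context.Context) *logrus.Entry {
	if logger, ok := ctx.Value(loggerCtxKey{}).(*logrus.Entry); ok {
		return logger
	}

	return defaultEntry
}
```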
2 changes: 1 addition & 1 deletion cmd/serve.go
@@ -63,7 +63,7 @@ func startServer(_ *cobra.Command, _ []string) error {
select {
case <-signals:
log.Log().Infof("Terminating...")
util.LogOnError("can't stop server: ", srv.Stop(ctx))
util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx))
done <- true

case err := <-errChan:
2 changes: 1 addition & 1 deletion querylog/database_writer.go
@@ -128,7 +128,7 @@ func (d *DatabaseWriter) periodicFlush(ctx context.Context) {
case <-ticker.C:
err := d.doDBWrite()

util.LogOnError("can't write entries to the database: ", err)
util.LogOnError(ctx, "can't write entries to the database: ", err)

case <-ctx.Done():
return
12 changes: 6 additions & 6 deletions resolver/caching_resolver.go
@@ -125,7 +125,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
return &packed, r.adjustTTLs(response.Res.Answer)
}
} else {
-util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err)
+util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err)
}

return nil, 0
@@ -140,7 +140,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) {
if rc != nil {
logger.Debug("Received key from redis: ", rc.Key)
ttl := r.adjustTTLs(rc.Response.Res.Answer)
-r.putInCache(rc.Key, rc.Response, ttl, false)
+r.putInCache(ctx, rc.Key, rc.Response, ttl, false)
}

case <-ctx.Done():
@@ -194,7 +194,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (

if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer)
-r.putInCache(cacheKey, response, cacheTTL, true)
+r.putInCache(ctx, cacheKey, response, cacheTTL, true)
}
}

@@ -250,16 +250,16 @@ func isResponseCacheable(msg *dns.Msg) bool {
return !msg.Truncated && !msg.CheckingDisabled
}

-func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, ttl time.Duration,
-publish bool,
+func (r *CachingResolver) putInCache(
+ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool,
) {
respCopy := response.Res.Copy()

// don't cache any EDNS OPT records
util.RemoveEdns0Record(respCopy)

packed, err := respCopy.Pack()
util.LogOnError("error on packing", err)
util.LogOnError(ctx, "error on packing", err)

if err == nil {
if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {
20 changes: 14 additions & 6 deletions resolver/parallel_best_resolver.go
@@ -5,11 +5,13 @@ import (
"errors"
"fmt"
"math"
"math/rand"
"strings"
"sync/atomic"
"time"

"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"

@@ -161,7 +163,7 @@ func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Reque
ctx, cancel := context.WithCancel(ctx)
defer cancel() // abort requests to resolvers that lost the race

-resolvers := pickRandom(allResolvers, r.resolverCount)
+resolvers := pickRandom(ctx, allResolvers, r.resolverCount)
ch := make(chan requestResponse, len(resolvers))

for _, resolver := range resolvers {
@@ -210,7 +212,7 @@ func (r *ParallelBestResolver) retryWithDifferent(
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
) (*model.Response, error) {
// second try (if retryWithDifferentResolver == true)
-resolver := weightedRandom(*r.resolvers.Load(), resolvers)
+resolver := weightedRandom(ctx, *r.resolvers.Load(), resolvers)
logger.Debugf("using %s as second resolver", resolver.resolver)

resp, err := resolver.resolve(ctx, request)
@@ -227,17 +229,17 @@ func (r *ParallelBestResolver) retryWithDifferent(
}

// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
-func pickRandom(resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
+func pickRandom(ctx context.Context, resolvers []*upstreamResolverStatus, resolverCount int) []*upstreamResolverStatus {
chosenResolvers := make([]*upstreamResolverStatus, 0, resolverCount)

for i := 0; i < resolverCount; i++ {
-chosenResolvers = append(chosenResolvers, weightedRandom(resolvers, chosenResolvers))
+chosenResolvers = append(chosenResolvers, weightedRandom(ctx, resolvers, chosenResolvers))
}

return chosenResolvers
}

-func weightedRandom(in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
+func weightedRandom(ctx context.Context, in, excludedResolvers []*upstreamResolverStatus) *upstreamResolverStatus {
const errorWindowInSec = 60

choices := make([]weightedrand.Choice[*upstreamResolverStatus, uint], 0, len(in))
@@ -262,7 +264,13 @@ outer:
}

c, err := weightedrand.NewChooser(choices...)
util.LogOnError("can't choose random weighted resolver: ", err)
if err != nil {
log.FromCtx(ctx).WithError(err).Error("can't choose random weighted resolver, falling back to uniform random")

val := rand.Int() //nolint:gosec // pseudo-randomness is good enough

return choices[val%len(choices)].Item
}

return c.Pick()
}
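A note on the weightedRandom change above: previously an error from weightedrand.NewChooser was only logged and c.Pick() was still called on what could be a nil chooser; the new code logs through the context logger and falls back to a uniform random pick instead. A standalone sketch of that pick-weighted-or-fall-back shape (hypothetical names, assuming the mroth/weightedrand/v2 package the surrounding code appears to use):

```go
// Standalone illustration (hypothetical, not blocky code): prefer a weighted
// pick, fall back to a uniform pick if the chooser can't be constructed.
package main

import (
	"fmt"
	"math/rand"

	"github.com/mroth/weightedrand/v2"
)

func pickWeighted(items []string, weights []uint) string {
	choices := make([]weightedrand.Choice[string, uint], 0, len(items))
	for i, item := range items {
		choices = append(choices, weightedrand.NewChoice(item, weights[i]))
	}

	c, err := weightedrand.NewChooser(choices...)
	if err != nil {
		//nolint:gosec // pseudo-randomness is good enough here
		return items[rand.Intn(len(items))]
	}

	return c.Pick()
}

func main() {
	// Degenerate weights (all zero) should make chooser construction fail
	// and exercise the uniform fallback.
	fmt.Println(pickWeighted([]string{"a", "b", "c"}, []uint{0, 0, 0}))
}
```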
12 changes: 6 additions & 6 deletions resolver/parallel_best_resolver_test.go
@@ -301,12 +301,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
})

It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)

for i := 0; i < 1000; i++ {
-resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
+resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
res1 := resolvers[0].resolver
res2 := resolvers[1].resolver
Expect(res1).ShouldNot(Equal(res2))
@@ -330,7 +330,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)

for i := 0; i < 100; i++ {
-resolvers := pickRandom(*sut.resolvers.Load(), parallelBestResolverCount)
+resolvers := pickRandom(ctx, *sut.resolvers.Load(), parallelBestResolverCount)
res1 := resolvers[0].resolver.(*UpstreamResolver)
res2 := resolvers[1].resolver.(*UpstreamResolver)
Expect(res1).ShouldNot(Equal(res2))
@@ -493,12 +493,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}
})

It("should use 2 random peeked resolvers, weighted with last error timestamp", func() {
It("should use 2 random peeked resolvers, weighted with last error timestamp", func(ctx context.Context) {
By("all resolvers have same weight for random -> equal distribution", func() {
resolverCount := make(map[Resolver]int)

for i := 0; i < 2000; i++ {
-r := weightedRandom(*sut.resolvers.Load(), nil)
+r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
resolverCount[r.resolver]++
}
for _, v := range resolverCount {
@@ -517,7 +517,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
resolverCount := make(map[*UpstreamResolver]int)

for i := 0; i < 200; i++ {
-r := weightedRandom(*sut.resolvers.Load(), nil)
+r := weightedRandom(ctx, *sut.resolvers.Load(), nil)
res := r.resolver.(*UpstreamResolver)

resolverCount[res]++
2 changes: 1 addition & 1 deletion resolver/upstream_resolver.go
@@ -162,7 +162,7 @@ func (r *httpUpstreamClient) callExternal(
}

defer func() {
util.LogOnError("can't close response body ", httpResponse.Body.Close())
util.LogOnError(ctx, "can't close response body ", httpResponse.Body.Close())
}()

if httpResponse.StatusCode != http.StatusOK {
12 changes: 7 additions & 5 deletions server/server.go
@@ -426,7 +426,9 @@ func (s *Server) registerDNSHandlers(ctx context.Context) {
for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)
handler.HandleFunc(".", wrappedOnRequest)
handler.HandleFunc("healthcheck.blocky", s.OnHealthCheck)
handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) {
s.OnHealthCheck(ctx, w, m)
})
}
}
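(The wrapper above is needed because dns.ServeMux.HandleFunc expects a plain func(dns.ResponseWriter, *dns.Msg); capturing ctx in a closure lets OnHealthCheck receive the context without changing how handlers are registered.)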

@@ -606,10 +608,10 @@ func (s *Server) OnRequest(
m := new(dns.Msg)
m.SetRcode(request, dns.RcodeServerFailure)
err := w.WriteMsg(m)
util.LogOnError("can't write message: ", err)
util.LogOnError(ctx, "can't write message: ", err)
} else {
err := w.WriteMsg(response.Res)
util.LogOnError("can't write message: ", err)
util.LogOnError(ctx, "can't write message: ", err)
}
}

@@ -672,13 +674,13 @@ func getMaxResponseSize(req *model.Request) int {
}

// OnHealthCheck Handler for docker health check. Just returns OK code without delegating to resolver chain
-func (s *Server) OnHealthCheck(w dns.ResponseWriter, request *dns.Msg) {
+func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
resp := new(dns.Msg)
resp.SetReply(request)
resp.Rcode = dns.RcodeSuccess

err := w.WriteMsg(resp)
util.LogOnError("can't write message: ", err)
util.LogOnError(ctx, "can't write message: ", err)
}

func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) {
5 changes: 3 additions & 2 deletions util/common.go
@@ -1,6 +1,7 @@
package util

import (
"context"
"encoding/binary"
"fmt"
"io"
@@ -154,9 +155,9 @@ func IterateValueSorted(in map[string]int, fn func(string, int)) {
}

// LogOnError logs the message only if error is not nil
-func LogOnError(message string, err error) {
+func LogOnError(ctx context.Context, message string, err error) {
if err != nil {
-log.Log().Error(message, err)
+log.FromCtx(ctx).Error(message, err)
}
}

5 changes: 3 additions & 2 deletions util/common_test.go
@@ -1,6 +1,7 @@
package util

import (
"context"
"errors"
"fmt"
"net"
@@ -153,11 +154,11 @@ var _ = Describe("Common function tests", func() {
Describe("Logging functions", func() {
When("LogOnError is called with error", func() {
err := errors.New("test")
It("should log", func() {
It("should log", func(ctx context.Context) {
hook := test.NewGlobal()
Log().AddHook(hook)
defer hook.Reset()
LogOnError("message ", err)
LogOnError(ctx, "message ", err)
Expect(hook.LastEntry().Message).Should(Equal("message test"))
})
})
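(The updated spec bodies in both test files now declare a context.Context parameter; Ginkgo v2 supplies a per-spec context to bodies that accept one, and that context is what the specs forward to pickRandom, weightedRandom and LogOnError.)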
