Skip to content

Commit

Permalink
refactor(log): store log in context so it"s automatically propagated
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Mar 19, 2024
1 parent d83b743 commit b335887
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 102 deletions.
46 changes: 46 additions & 0 deletions log/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package log

import (
"context"

"github.com/sirupsen/logrus"
)

type ctxKey struct{}

func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) {
ctx = context.WithValue(ctx, ctxKey{}, logger)

return ctx, entryWithCtx(ctx, logger)
}

func FromCtx(ctx context.Context) *logrus.Entry {
logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry)
if !ok {
// Fallback to the global logger
return logrus.NewEntry(Log())
}

// Ensure `logger.Context == ctx`, not always the case since `ctx` could be a child of `logger.Context`
return entryWithCtx(ctx, logger)
}

func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry {
loggerCopy := *logger
loggerCopy.Context = ctx

return &loggerCopy
}

func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
logger := FromCtx(ctx)
logger = wrap(logger)

return NewCtx(ctx, logger)
}

func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry {
return e.WithFields(fields)
})
}
23 changes: 16 additions & 7 deletions resolver/blocking_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,20 @@ func NewBlockingResolver(ctx context.Context,
}

func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
ctx, logger := r.log(ctx)

for {
select {
case em := <-r.redisClient.EnabledChannel:
if em != nil {
r.log().Debug("Received state from redis: ", em)
logger.Debug("Received state from redis: ", em)

if em.State {
r.internalEnableBlocking()
} else {
err := r.internalDisableBlocking(ctx, em.Duration, em.Groups)
if err != nil {
r.log().Warn("Blocking couldn"t be disabled:", err)
logger.Warn("Blocking couldn"t be disabled:", err)
}
}
}
Expand Down Expand Up @@ -394,7 +396,7 @@ func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []

// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "blacklist_resolver")
ctx, logger := r.log(ctx)
groupsToCheck := r.groupsToCheckForClient(request)

if len(groupsToCheck) > 0 {
Expand Down Expand Up @@ -575,7 +577,9 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
}

func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
ctx, logger := r.logWith(ctx, func(logger *logrus.Entry) *logrus.Entry {
return log.WithPrefix(logger, "client_id_cache")
})

var result []net.IP

Expand All @@ -584,7 +588,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
Log: prefixedLog,
Log: logger,
})

if err == nil && resp.Res.Rcode == dns.RcodeSuccess {
Expand All @@ -598,11 +602,16 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifi
result = append(result, v.AAAA)
}
}

prefixedLog.Debugf("resolved IPs '%v' for fq identifier '%s'", result, identifier)
}
}

if len(result) != 0 {
logger.WithFields(logrus.Fields{
"ips": result,
"client_id": identifier,
}).Debug("resolved client IPs")
}

return &result, ttl
}

Expand Down
15 changes: 8 additions & 7 deletions resolver/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/hashicorp/go-multierror"
Expand Down Expand Up @@ -70,13 +69,15 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
dialer: new(net.Dialer),
}

ctx, logger := b.log(ctx)

bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
if err != nil {
return nil, err
}

if len(bootstraped) == 0 {
b.log().Info("bootstrapDns is not configured, will use system resolver")
logger.Info("bootstrapDns is not configured, will use system resolver")

return b, nil
}
Expand Down Expand Up @@ -116,10 +117,9 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model
}

// Add bootstrap prefix to all inner resolver logs
req := *request
req.Log = log.WithPrefix(req.Log, b.Type())
ctx, _ = b.log(ctx)

return b.resolver.Resolve(ctx, &req)
return b.resolver.Resolve(ctx, request)
}

func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
Expand Down Expand Up @@ -168,7 +168,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
}

func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logger := b.log().WithFields(logrus.Fields{"network": network, "addr": addr})
ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})

host, port, err := net.SplitHostPort(addr)
if err != nil {
Expand Down Expand Up @@ -234,9 +234,10 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.
return []net.IP{ip}, nil
}

ctx, _ = b.log(ctx)

req := model.Request{
Req: util.NewMsgWithQuestion(hostname, qType),
Log: b.log(),
}

rsp, err := b.resolver.Resolve(ctx, &req)
Expand Down
21 changes: 12 additions & 9 deletions resolver/caching_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/0xERR0R/blocky/cache/expirationcache"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/util"
Expand Down Expand Up @@ -107,7 +106,7 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin

func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log()
ctx, logger := r.log(ctx)

logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)

Expand All @@ -133,11 +132,13 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string)
}

func (r *CachingResolver) redisSubscriber(ctx context.Context) {
ctx, logger := r.log(ctx)

for {
select {
case rc := <-r.redisClient.CacheChannel:
if rc != nil {
r.log().Debug("Received key from redis: ", rc.Key)
logger.Debug("Received key from redis: ", rc.Key)
ttl := r.adjustTTLs(rc.Response.Res.Answer)
r.putInCache(rc.Key, rc.Response, ttl, false)
}
Expand All @@ -158,7 +159,7 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
// Resolve checks if the current query should use the cache and if the result is already in
// the cache and returns it or delegates to the next resolver
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")
ctx, logger := r.log(ctx)

if !r.IsEnabled() || !isRequestCacheable(request) {
logger.Debug("skip cache")
Expand All @@ -171,7 +172,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain)
logger := logger.WithField("domain", util.Obfuscate(domain))

val, ttl := r.getFromCache(cacheKey)
val, ttl := r.getFromCache(logger, cacheKey)

if val != nil {
logger.Debug("domain is cached")
Expand Down Expand Up @@ -200,7 +201,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (
return response, err
}

func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {
func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) {
val, ttl := r.resultCache.Get(key)
if val == nil {
return nil, 0
Expand All @@ -210,7 +211,7 @@ func (r *CachingResolver) getFromCache(key string) (*dns.Msg, time.Duration) {

err := res.Unpack(*val)
if err != nil {
r.log().Error("can"t unpack cached entry. Cache malformed?", err)
logger.Error("can"t unpack cached entry. Cache malformed?", err)

return nil, 0
}
Expand Down Expand Up @@ -317,7 +318,9 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
}
}

func (r *CachingResolver) FlushCaches(context.Context) {
r.log().Debug("flush caches")
func (r *CachingResolver) FlushCaches(ctx context.Context) {
_, logger := r.log(ctx)

logger.Debug("flush caches")
r.resultCache.Clear()
}
10 changes: 5 additions & 5 deletions resolver/client_names_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Reques
clientNames := r.getClientNames(ctx, request)

request.ClientNames = clientNames
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
ctx, request.Log = log.CtxWithFields(ctx, logrus.Fields{"client_names": strings.Join(clientNames, "; ")})

return r.next.Resolve(ctx, request)
}
Expand All @@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model
return cpy
}

names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
names := r.resolveClientNames(ctx, ip)

r.cache.Put(ip.String(), &names, time.Hour)

Expand All @@ -111,9 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
}

// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(
ctx context.Context, ip net.IP, logger *logrus.Entry,
) (result []string) {
func (r *ClientNamesResolver) resolveClientNames(ctx context.Context, ip net.IP) (result []string) {
ctx, logger := r.log(ctx)

// try client mapping first
result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 {
Expand Down
5 changes: 2 additions & 3 deletions resolver/conditional_upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"

"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"

Expand Down Expand Up @@ -83,7 +82,7 @@ func (r *ConditionalUpstreamResolver) processRequest(

// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "conditional_resolver")
ctx, logger := r.log(ctx)

if len(r.mapping) > 0 {
resolved, resp, err := r.processRequest(ctx, request)
Expand All @@ -101,7 +100,7 @@ func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso
req *model.Request,
) (*model.Response, error) {
// internal request resolution
logger := log.WithPrefix(req.Log, "conditional_resolver")
ctx, logger := r.log(ctx)

req.Req.Question[0].Name = dns.Fqdn(doFQ)
response, err := reso.Resolve(ctx, req)
Expand Down
Loading

0 comments on commit b335887

Please sign in to comment.