diff --git a/cache/expirationcache/cache_interface.go b/cache/expirationcache/cache_interface.go new file mode 100644 index 000000000..c423df896 --- /dev/null +++ b/cache/expirationcache/cache_interface.go @@ -0,0 +1,17 @@ +package expirationcache + +import "time" + +type ExpiringCache interface { + // Put adds the value to the cache under the passed key with expiration. If expiration <=0, entry will NOT be cached + Put(key string, val interface{}, expiration time.Duration) + + // Get returns the value of cached entry with remaining TTL. If entry is not cached, returns nil + Get(key string) (val interface{}, expiration time.Duration) + + // TotalCount returns the total count of valid (not expired) elements + TotalCount() int + + // Clear removes all cache entries + Clear() +} diff --git a/cache/expirationcache/expiration_cache.go b/cache/expirationcache/expiration_cache.go new file mode 100644 index 000000000..0b46f6373 --- /dev/null +++ b/cache/expirationcache/expiration_cache.go @@ -0,0 +1,171 @@ +package expirationcache + +import ( + "time" + + lru "github.com/hashicorp/golang-lru" +) + +const ( + defaultCleanUpInterval = 10 * time.Second + defaultSize = 10_000 +) + +type element struct { + val interface{} + expiresEpochMs int64 +} + +type ExpiringLRUCache struct { + cleanUpInterval time.Duration + preExpirationFn OnExpirationCallback + lru *lru.Cache +} + +type CacheOption func(c *ExpiringLRUCache) + +func WithCleanUpInterval(d time.Duration) CacheOption { + return func(e *ExpiringLRUCache) { + e.cleanUpInterval = d + } +} + +// OnExpirationCallback will be called just before an element gets expired and will +// be removed from cache. 
This function can return new value and TTL to leave the +// element in the cache or nil to remove it +type OnExpirationCallback func(key string) (val interface{}, ttl time.Duration) + +func WithOnExpiredFn(fn OnExpirationCallback) CacheOption { + return func(c *ExpiringLRUCache) { + c.preExpirationFn = fn + } +} + +func WithMaxSize(size uint) CacheOption { + return func(c *ExpiringLRUCache) { + if size > 0 { + l, _ := lru.New(int(size)) + c.lru = l + } + } +} + +func NewCache(options ...CacheOption) *ExpiringLRUCache { + l, _ := lru.New(defaultSize) + c := &ExpiringLRUCache{ + cleanUpInterval: defaultCleanUpInterval, + preExpirationFn: func(key string) (val interface{}, ttl time.Duration) { + return nil, 0 + }, + lru: l, + } + + for _, opt := range options { + opt(c) + } + + go periodicCleanup(c) + + return c +} + +func periodicCleanup(c *ExpiringLRUCache) { + ticker := time.NewTicker(c.cleanUpInterval) + defer ticker.Stop() + + for { + <-ticker.C + c.cleanUp() + } +} + +func (e *ExpiringLRUCache) cleanUp() { + var expiredKeys []string + + // check for expired items and collect expired keys + for _, k := range e.lru.Keys() { + if v, ok := e.lru.Get(k); ok { + if isExpired(v.(*element)) { + expiredKeys = append(expiredKeys, k.(string)) + } + } + } + + if len(expiredKeys) > 0 { + var keysToDelete []string + + for _, key := range expiredKeys { + newVal, newTTL := e.preExpirationFn(key) + if newVal != nil { + e.Put(key, newVal, newTTL) + } else { + keysToDelete = append(keysToDelete, key) + } + } + + for _, key := range keysToDelete { + e.lru.Remove(key) + } + } +} + +func (e *ExpiringLRUCache) Put(key string, val interface{}, ttl time.Duration) { + if ttl <= 0 { + // entry should be considered as already expired + return + } + + expiresEpochMs := time.Now().UnixMilli() + ttl.Milliseconds() + + el, found := e.lru.Get(key) + if found { + // update existing item + el.(*element).val = val + el.(*element).expiresEpochMs = expiresEpochMs + } else { + // add new item + 
e.lru.Add(key, &element{ + val: val, + expiresEpochMs: expiresEpochMs, + }) + } +} + +func (e *ExpiringLRUCache) Get(key string) (val interface{}, ttl time.Duration) { + el, found := e.lru.Get(key) + + if found { + return el.(*element).val, calculateRemainTTL(el.(*element).expiresEpochMs) + } + + return nil, 0 +} + +func isExpired(el *element) bool { + return el.expiresEpochMs > 0 && time.Now().UnixMilli() > el.expiresEpochMs +} + +func calculateRemainTTL(expiresEpoch int64) time.Duration { + now := time.Now().UnixMilli() + if now < expiresEpoch { + return time.Duration(expiresEpoch-now) * time.Millisecond + } + + return 0 +} + +func (e *ExpiringLRUCache) TotalCount() (count int) { + for _, k := range e.lru.Keys() { + if v, ok := e.lru.Get(k); ok { + if !isExpired(v.(*element)) { + count++ + } + } + } + + return count +} + +func (e *ExpiringLRUCache) Clear() { + e.lru.Purge() +} diff --git a/cache/expirationcache/expiration_cache_suite_test.go b/cache/expirationcache/expiration_cache_suite_test.go new file mode 100644 index 000000000..515a3de74 --- /dev/null +++ b/cache/expirationcache/expiration_cache_suite_test.go @@ -0,0 +1,16 @@ +package expirationcache_test + +import ( + "testing" + + . "github.com/0xERR0R/blocky/log" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestCache(t *testing.T) { + ConfigureLogger(LevelFatal, FormatTypeText, true) + RegisterFailHandler(Fail) + RunSpecs(t, "Expiration cache suite") +} diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go new file mode 100644 index 000000000..9ce20030c --- /dev/null +++ b/cache/expirationcache/expiration_cache_test.go @@ -0,0 +1,178 @@ +package expirationcache + +import ( + "time" + + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("Expiration cache", func() { + Describe("Basic operations", func() { + When("string cache was created", func() { + + It("Initial cache should be empty", func() { + cache := NewCache() + Expect(cache.TotalCount()).Should(Equal(0)) + }) + It("Initial cache should not contain any elements", func() { + cache := NewCache() + val, expiration := cache.Get("key1") + Expect(val).Should(BeNil()) + Expect(expiration).Should(Equal(time.Duration(0))) + }) + }) + When("Put new value with positive TTL", func() { + It("Should return the value before element expires", func() { + cache := NewCache(WithCleanUpInterval(100 * time.Millisecond)) + cache.Put("key1", "val1", 50*time.Millisecond) + val, expiration := cache.Get("key1") + Expect(val).Should(Equal("val1")) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 50)) + + Expect(cache.TotalCount()).Should(Equal(1)) + }) + It("Should return nil after expiration", func() { + cache := NewCache(WithCleanUpInterval(100 * time.Millisecond)) + cache.Put("key1", "val1", 50*time.Millisecond) + + // wait for expiration + Eventually(func(g Gomega) { + val, ttl := cache.Get("key1") + g.Expect(val).Should(Equal("val1")) + g.Expect(ttl.Milliseconds()).Should(BeNumerically("==", 0)) + }, "60ms").Should(Succeed()) + + Expect(cache.TotalCount()).Should(Equal(0)) + // internal map has still the expired item + Expect(cache.lru.Len()).Should(Equal(1)) + + // wait for cleanup run + Eventually(func() int { + return cache.lru.Len() + }, "100ms").Should(Equal(0)) + }) + }) + When("Put new value without expiration", func() { + It("Should not cache the value", func() { + cache := NewCache(WithCleanUpInterval(50 * time.Millisecond)) + cache.Put("key1", "val1", 0) + val, expiration := cache.Get("key1") + Expect(val).Should(BeNil()) + Expect(expiration.Milliseconds()).Should(BeNumerically("==", 0)) + Expect(cache.TotalCount()).Should(Equal(0)) + }) + }) + When("Put updated value", func() { + It("Should 
return updated value", func() { + cache := NewCache() + cache.Put("key1", "val1", 50*time.Millisecond) + cache.Put("key1", "val2", 200*time.Millisecond) + + val, expiration := cache.Get("key1") + + Expect(val).Should(Equal("val2")) + Expect(expiration.Milliseconds()).Should(BeNumerically(">", 100)) + Expect(expiration.Milliseconds()).Should(BeNumerically("<=", 200)) + Expect(cache.TotalCount()).Should(Equal(1)) + }) + }) + When("Purging after usage", func() { + It("Should be empty after purge", func() { + cache := NewCache() + cache.Put("key1", "val1", time.Second) + + Expect(cache.TotalCount()).Should(Equal(1)) + + cache.Clear() + + Expect(cache.TotalCount()).Should(Equal(0)) + }) + }) + }) + Describe("preExpiration function", func() { + When(" function is defined", func() { + It("should update the value and TTL if function returns values", func() { + fn := func(key string) (val interface{}, ttl time.Duration) { + return "val2", time.Second + } + cache := NewCache(WithOnExpiredFn(fn)) + cache.Put("key1", "val1", 50*time.Millisecond) + + // wait for expiration + Eventually(func(g Gomega) { + val, ttl := cache.Get("key1") + g.Expect(val).Should(Equal("val1")) + g.Expect(ttl.Milliseconds()).Should( + BeNumerically("==", 0)) + }, "150ms").Should(Succeed()) + }) + + It("should update the value and TTL if function returns values on cleanup if element is expired", func() { + fn := func(key string) (val interface{}, ttl time.Duration) { + return "val2", time.Second + } + cache := NewCache(WithOnExpiredFn(fn)) + cache.Put("key1", "val1", time.Millisecond) + + time.Sleep(2 * time.Millisecond) + + // trigger cleanUp manually -> onExpiredFn will be executed, because element is expired + cache.cleanUp() + + // wait for expiration + val, ttl := cache.Get("key1") + Expect(val).Should(Equal("val2")) + Expect(ttl.Milliseconds()).Should(And( + BeNumerically(">", 900)), + BeNumerically("<=", 1000)) + }) + + It("should delete the key if function returns nil", func() { + fn := 
func(key string) (val interface{}, ttl time.Duration) { + return nil, 0 + } + cache := NewCache(WithCleanUpInterval(100*time.Millisecond), WithOnExpiredFn(fn)) + cache.Put("key1", "val1", 50*time.Millisecond) + + Eventually(func() (interface{}, time.Duration) { + return cache.Get("key1") + }, "200ms").Should(BeNil()) + }) + + }) + }) + Describe("LRU behaviour", func() { + When("Defined max size is reached", func() { + It("should remove old elements", func() { + cache := NewCache(WithMaxSize(3)) + + cache.Put("key1", "val1", time.Second) + cache.Put("key2", "val2", time.Second) + cache.Put("key3", "val3", time.Second) + cache.Put("key4", "val4", time.Second) + + Expect(cache.TotalCount()).Should(Equal(3)) + + // key1 was removed + Expect(cache.Get("key1")).Should(BeNil()) + // key2,3,4 still in the cache + Expect(cache.lru.Contains("key2")).Should(BeTrue()) + Expect(cache.lru.Contains("key3")).Should(BeTrue()) + Expect(cache.lru.Contains("key4")).Should(BeTrue()) + + // now get key2 to increase usage count + _, _ = cache.Get("key2") + + // put key5 + cache.Put("key5", "val5", time.Second) + + // now key3 should be removed + Expect(cache.lru.Contains("key2")).Should(BeTrue()) + Expect(cache.lru.Contains("key3")).Should(BeFalse()) + Expect(cache.lru.Contains("key4")).Should(BeTrue()) + Expect(cache.lru.Contains("key5")).Should(BeTrue()) + }) + }) + }) +}) diff --git a/cache/stringcache/string_cache_suite_test.go b/cache/stringcache/string_cache_suite_test.go new file mode 100644 index 000000000..80c7b4ad7 --- /dev/null +++ b/cache/stringcache/string_cache_suite_test.go @@ -0,0 +1,16 @@ +package stringcache_test + +import ( + "testing" + + . "github.com/0xERR0R/blocky/log" + + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" +) + +func TestCache(t *testing.T) { + ConfigureLogger(LevelFatal, FormatTypeText, true) + RegisterFailHandler(Fail) + RunSpecs(t, "String cache suite") +} diff --git a/lists/caches.go b/cache/stringcache/string_caches.go similarity index 65% rename from lists/caches.go rename to cache/stringcache/string_caches.go index 581382885..2f579b8be 100644 --- a/lists/caches.go +++ b/cache/stringcache/string_caches.go @@ -1,4 +1,4 @@ -package lists +package stringcache import ( "regexp" @@ -10,19 +10,19 @@ import ( "github.com/0xERR0R/blocky/util" ) -type cache interface { - elementCount() int - contains(searchString string) bool +type StringCache interface { + ElementCount() int + Contains(searchString string) bool } -type cacheFactory interface { - addEntry(entry string) - create() cache +type CacheFactory interface { + AddEntry(entry string) + Create() StringCache } type stringCache map[int]string -func (cache stringCache) elementCount() int { +func (cache stringCache) ElementCount() int { count := 0 for k, v := range cache { @@ -32,7 +32,7 @@ func (cache stringCache) elementCount() int { return count } -func (cache stringCache) contains(searchString string) bool { +func (cache stringCache) Contains(searchString string) bool { searchLen := len(searchString) if searchLen == 0 { return false @@ -56,7 +56,7 @@ type stringCacheFactory struct { tmp map[int]*strings.Builder } -func newStringCacheFactory() cacheFactory { +func newStringCacheFactory() CacheFactory { return &stringCacheFactory{ cache: make(stringCache), // temporary map to remove duplicates @@ -65,7 +65,7 @@ func newStringCacheFactory() cacheFactory { } } -func (s *stringCacheFactory) addEntry(entry string) { +func (s *stringCacheFactory) AddEntry(entry string) { if _, value := s.keys[entry]; !value { s.keys[entry] = struct{}{} if s.tmp[len(entry)] == nil { @@ -76,7 +76,7 @@ func (s *stringCacheFactory) addEntry(entry string) { } } -func (s *stringCacheFactory) create() cache { +func (s 
*stringCacheFactory) Create() StringCache { for k, v := range s.tmp { chunks := util.Chunks(v.String(), k) sort.Strings(chunks) @@ -91,11 +91,11 @@ func (s *stringCacheFactory) create() cache { type regexCache []*regexp.Regexp -func (cache regexCache) elementCount() int { +func (cache regexCache) ElementCount() int { return len(cache) } -func (cache regexCache) contains(searchString string) bool { +func (cache regexCache) Contains(searchString string) bool { for _, regex := range cache { if regex.MatchString(searchString) { log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString) @@ -110,7 +110,7 @@ type regexCacheFactory struct { cache regexCache } -func (r *regexCacheFactory) addEntry(entry string) { +func (r *regexCacheFactory) AddEntry(entry string) { compile, err := regexp.Compile(entry) if err != nil { log.Log().Warnf("invalid regex '%s'", entry) @@ -119,32 +119,32 @@ func (r *regexCacheFactory) addEntry(entry string) { } } -func (r *regexCacheFactory) create() cache { +func (r *regexCacheFactory) Create() StringCache { return r.cache } -func newRegexCacheFactory() cacheFactory { +func newRegexCacheFactory() CacheFactory { return ®exCacheFactory{ cache: make(regexCache, 0), } } type chainedCache struct { - caches []cache + caches []StringCache } -func (cache chainedCache) elementCount() int { +func (cache chainedCache) ElementCount() int { sum := 0 for _, c := range cache.caches { - sum += c.elementCount() + sum += c.ElementCount() } return sum } -func (cache chainedCache) contains(searchString string) bool { +func (cache chainedCache) Contains(searchString string) bool { for _, c := range cache.caches { - if c.contains(searchString) { + if c.Contains(searchString) { return true } } @@ -153,28 +153,28 @@ func (cache chainedCache) contains(searchString string) bool { } type chainedCacheFactory struct { - stringCacheFactory cacheFactory - regexCacheFactory cacheFactory + stringCacheFactory CacheFactory + regexCacheFactory 
CacheFactory } var regexPattern = regexp.MustCompile("^/.*/$") -func (r *chainedCacheFactory) addEntry(entry string) { +func (r *chainedCacheFactory) AddEntry(entry string) { if regexPattern.MatchString(entry) { entry = strings.TrimSpace(strings.Trim(entry, "/")) - r.regexCacheFactory.addEntry(entry) + r.regexCacheFactory.AddEntry(entry) } else { - r.stringCacheFactory.addEntry(entry) + r.stringCacheFactory.AddEntry(entry) } } -func (r *chainedCacheFactory) create() cache { +func (r *chainedCacheFactory) Create() StringCache { return &chainedCache{ - caches: []cache{r.stringCacheFactory.create(), r.regexCacheFactory.create()}, + caches: []StringCache{r.stringCacheFactory.Create(), r.regexCacheFactory.Create()}, } } -func newChainedCacheFactory() cacheFactory { +func NewChainedCacheFactory() CacheFactory { return &chainedCacheFactory{ stringCacheFactory: newStringCacheFactory(), regexCacheFactory: newRegexCacheFactory(), diff --git a/cache/stringcache/string_caches_test.go b/cache/stringcache/string_caches_test.go new file mode 100644 index 000000000..ec6ac3cd1 --- /dev/null +++ b/cache/stringcache/string_caches_test.go @@ -0,0 +1,80 @@ +package stringcache + +import ( + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("Caches", func() { + Describe("String StringCache", func() { + When("string StringCache was created", func() { + factory := newStringCacheFactory() + factory.AddEntry("google.com") + factory.AddEntry("apple.com") + cache := factory.Create() + It("should match if StringCache Contains string", func() { + Expect(cache.Contains("apple.com")).Should(BeTrue()) + Expect(cache.Contains("google.com")).Should(BeTrue()) + Expect(cache.Contains("www.google.com")).Should(BeFalse()) + Expect(cache.Contains("")).Should(BeFalse()) + }) + It("should return correct element count", func() { + Expect(cache.ElementCount()).Should(Equal(2)) + }) + }) + }) + + Describe("Regex StringCache", func() { + When("regex StringCache was created", func() { + factory := newRegexCacheFactory() + factory.AddEntry(".*google.com") + factory.AddEntry("^apple\\.(de|com)$") + factory.AddEntry("amazon") + // this is not a regex, will be ignored + factory.AddEntry("(wrongRegex") + cache := factory.Create() + It("should match if one regex in StringCache matches string", func() { + Expect(cache.Contains("google.com")).Should(BeTrue()) + Expect(cache.Contains("google.coma")).Should(BeTrue()) + Expect(cache.Contains("agoogle.com")).Should(BeTrue()) + Expect(cache.Contains("www.google.com")).Should(BeTrue()) + Expect(cache.Contains("apple.com")).Should(BeTrue()) + Expect(cache.Contains("apple.de")).Should(BeTrue()) + Expect(cache.Contains("apple.it")).Should(BeFalse()) + Expect(cache.Contains("www.apple.com")).Should(BeFalse()) + Expect(cache.Contains("applecom")).Should(BeFalse()) + Expect(cache.Contains("www.amazon.com")).Should(BeTrue()) + Expect(cache.Contains("amazon.com")).Should(BeTrue()) + Expect(cache.Contains("myamazon.com")).Should(BeTrue()) + }) + It("should return correct element count", func() { + Expect(cache.ElementCount()).Should(Equal(3)) + }) + }) + }) + + Describe("Chained StringCache", func() { + When("chained StringCache was created", func() { 
+ factory := NewChainedCacheFactory() + factory.AddEntry("/.*google.com/") + factory.AddEntry("/^apple\\.(de|com)$/") + factory.AddEntry("amazon.com") + cache := factory.Create() + It("should match if one regex in StringCache matches string", func() { + Expect(cache.Contains("google.com")).Should(BeTrue()) + Expect(cache.Contains("google.coma")).Should(BeTrue()) + Expect(cache.Contains("agoogle.com")).Should(BeTrue()) + Expect(cache.Contains("www.google.com")).Should(BeTrue()) + Expect(cache.Contains("apple.com")).Should(BeTrue()) + Expect(cache.Contains("amazon.com")).Should(BeTrue()) + Expect(cache.Contains("apple.de")).Should(BeTrue()) + Expect(cache.Contains("www.apple.com")).Should(BeFalse()) + Expect(cache.Contains("applecom")).Should(BeFalse()) + }) + It("should return correct element count", func() { + Expect(cache.ElementCount()).Should(Equal(3)) + }) + }) + }) + +}) diff --git a/go.mod b/go.mod index 86cd31c96..7d6d119c6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/0xERR0R/blocky go 1.17 require ( - github.com/0xERR0R/go-cache v1.6.0 + github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 github.com/alicebob/miniredis/v2 v2.18.0 github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef github.com/avast/retry-go/v4 v4.0.2 @@ -14,6 +14,7 @@ require ( github.com/google/uuid v1.3.0 github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b github.com/hashicorp/go-multierror v1.1.1 + github.com/hashicorp/golang-lru v0.5.4 github.com/miekg/dns v1.1.45 github.com/mroth/weightedrand v0.4.1 github.com/onsi/ginkgo v1.16.5 diff --git a/go.sum b/go.sum index cd4315cba..01c80cd01 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,6 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl 
v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/0xERR0R/go-cache v1.6.0 h1:/WFYV/4v0XuOT3MxHNp8zL6MWHHzZt+x2Px0JLtLn2Y= -github.com/0xERR0R/go-cache v1.6.0/go.mod h1:ngsTQjvR6Aq/zvQI88HsfVJwIz9M78r0xNZi1QYY3jE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= @@ -263,6 +261,7 @@ github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= diff --git a/lists/caches_test.go b/lists/caches_test.go deleted file mode 100644 index 9aa22d46f..000000000 --- a/lists/caches_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package lists - -import ( - . "github.com/onsi/ginkgo" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("Caches", func() { - Describe("String cache", func() { - When("string cache was created", func() { - factory := newStringCacheFactory() - factory.addEntry("google.com") - factory.addEntry("apple.com") - cache := factory.create() - It("should match if cache contains string", func() { - Expect(cache.contains("apple.com")).Should(BeTrue()) - Expect(cache.contains("google.com")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeFalse()) - }) - It("should return correct element count", func() { - Expect(cache.elementCount()).Should(Equal(2)) - }) - }) - }) - - Describe("Regex cache", func() { - When("regex cache was created", func() { - factory := newRegexCacheFactory() - factory.addEntry(".*google.com") - factory.addEntry("^apple\\.(de|com)$") - factory.addEntry("amazon") - // this is not a regex, will be ignored - factory.addEntry("(wrongRegex") - cache := factory.create() - It("should match if one regex in cache matches string", func() { - Expect(cache.contains("google.com")).Should(BeTrue()) - Expect(cache.contains("google.coma")).Should(BeTrue()) - Expect(cache.contains("agoogle.com")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeTrue()) - Expect(cache.contains("apple.com")).Should(BeTrue()) - Expect(cache.contains("apple.de")).Should(BeTrue()) - Expect(cache.contains("apple.it")).Should(BeFalse()) - Expect(cache.contains("www.apple.com")).Should(BeFalse()) - Expect(cache.contains("applecom")).Should(BeFalse()) - Expect(cache.contains("www.amazon.com")).Should(BeTrue()) - Expect(cache.contains("amazon.com")).Should(BeTrue()) - Expect(cache.contains("myamazon.com")).Should(BeTrue()) - }) - It("should return correct element count", func() { - Expect(cache.elementCount()).Should(Equal(3)) - }) - }) - }) - - Describe("Chained cache", func() { - When("chained cache was created", func() { - factory := newChainedCacheFactory() - factory.addEntry("/.*google.com/") - 
factory.addEntry("/^apple\\.(de|com)$/") - factory.addEntry("amazon.com") - cache := factory.create() - It("should match if one regex in cache matches string", func() { - Expect(cache.contains("google.com")).Should(BeTrue()) - Expect(cache.contains("google.coma")).Should(BeTrue()) - Expect(cache.contains("agoogle.com")).Should(BeTrue()) - Expect(cache.contains("www.google.com")).Should(BeTrue()) - Expect(cache.contains("apple.com")).Should(BeTrue()) - Expect(cache.contains("amazon.com")).Should(BeTrue()) - Expect(cache.contains("apple.de")).Should(BeTrue()) - Expect(cache.contains("www.apple.com")).Should(BeFalse()) - Expect(cache.contains("applecom")).Should(BeFalse()) - }) - It("should return correct element count", func() { - Expect(cache.elementCount()).Should(Equal(3)) - }) - }) - }) - -}) diff --git a/lists/list_cache.go b/lists/list_cache.go index 36569b0d1..fd5683c11 100644 --- a/lists/list_cache.go +++ b/lists/list_cache.go @@ -13,6 +13,8 @@ import ( "sync" "time" + "github.com/0xERR0R/blocky/cache/stringcache" + "github.com/avast/retry-go/v4" "github.com/hako/durafmt" @@ -41,7 +43,7 @@ type Matcher interface { // ListCache generic cache of strings divided in groups type ListCache struct { - groupCaches map[string]cache + groupCaches map[string]stringcache.StringCache lock sync.RWMutex groupToLinks map[string][]string @@ -78,8 +80,8 @@ func (b *ListCache) Configuration() (result []string) { var total int for group, cache := range b.groupCaches { - result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.elementCount())) - total += cache.elementCount() + result = append(result, fmt.Sprintf(" %s: %d entries", group, cache.ElementCount())) + total += cache.ElementCount() } result = append(result, fmt.Sprintf(" TOTAL: %d entries", total)) @@ -90,7 +92,7 @@ func (b *ListCache) Configuration() (result []string) { // NewListCache creates new list instance func NewListCache(t ListCacheType, groupToLinks map[string][]string, refreshPeriod time.Duration, 
downloadTimeout time.Duration, downloadAttempts int, downloadCooldown time.Duration) (*ListCache, error) { - groupCaches := make(map[string]cache) + groupCaches := make(map[string]stringcache.StringCache) b := &ListCache{ groupToLinks: groupToLinks, @@ -133,7 +135,7 @@ type groupCache struct { } // downloads and reads files with domain names and creates cache for them -func (b *ListCache) createCacheForGroup(links []string) (cache, error) { +func (b *ListCache) createCacheForGroup(links []string) (stringcache.StringCache, error) { var wg sync.WaitGroup var err error @@ -148,7 +150,7 @@ func (b *ListCache) createCacheForGroup(links []string) (cache, error) { wg.Wait() - factory := newChainedCacheFactory() + factory := stringcache.NewChainedCacheFactory() Loop: for { @@ -161,7 +163,7 @@ Loop: return nil, err } for _, entry := range res.cache { - factory.addEntry(entry) + factory.AddEntry(entry) } default: close(c) @@ -169,7 +171,7 @@ Loop: } } - return factory.create(), err + return factory.Create(), err } // Match matches passed domain name against cached list entries @@ -178,7 +180,7 @@ func (b *ListCache) Match(domain string, groupsToCheck []string) (found bool, gr defer b.lock.RUnlock() for _, g := range groupsToCheck { - if c, ok := b.groupCaches[g]; ok && c.contains(domain) { + if c, ok := b.groupCaches[g]; ok && c.Contains(domain) { return true, g } } @@ -213,11 +215,11 @@ func (b *ListCache) refresh(init bool) error { } if b.groupCaches[group] != nil { - evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, b.groupCaches[group].elementCount()) + evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, b.groupCaches[group].ElementCount()) logger().WithFields(logrus.Fields{ "group": group, - "total_count": b.groupCaches[group].elementCount(), + "total_count": b.groupCaches[group].ElementCount(), }).Info("group import finished") } } diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 48b8ce3c9..7adbf3fa5 100644 
--- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -6,13 +6,13 @@ import ( "github.com/hako/durafmt" + "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/evt" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/util" - "github.com/0xERR0R/go-cache" "github.com/miekg/dns" "github.com/sirupsen/logrus" ) @@ -23,10 +23,10 @@ type CachingResolver struct { NextResolver minCacheTimeSec, maxCacheTimeSec int cacheTimeNegative time.Duration - resultCache *cache.Cache + resultCache expirationcache.ExpiringCache prefetchExpires time.Duration prefetchThreshold int - prefetchingNameCache *cache.Cache + prefetchingNameCache expirationcache.ExpiringCache redisClient *redis.Client redisEnabled bool } @@ -43,14 +43,11 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedRe minCacheTimeSec: int(time.Duration(cfg.MinCachingTime).Seconds()), maxCacheTimeSec: int(time.Duration(cfg.MaxCachingTime).Seconds()), cacheTimeNegative: time.Duration(cfg.CacheTimeNegative), - resultCache: createQueryResultCache(&cfg), redisClient: redis, redisEnabled: (redis != nil), } - if cfg.Prefetching { - configurePrefetching(c, &cfg) - } + configureCaches(c, &cfg) if c.redisEnabled { setupRedisCacheSubscriber(c) @@ -60,20 +57,22 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedRe return c } -func createQueryResultCache(cfg *config.CachingConfig) *cache.Cache { - return cache.NewWithLRU(15*time.Minute, 15*time.Second, cfg.MaxItemsCount) -} - -func configurePrefetching(c *CachingResolver, cfg *config.CachingConfig) { - c.prefetchExpires = time.Duration(cfg.PrefetchExpires) +func configureCaches(c *CachingResolver, cfg *config.CachingConfig) { + cleanupOption := expirationcache.WithCleanUpInterval(5 * time.Second) + maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount)) - c.prefetchThreshold = 
cfg.PrefetchThreshold + if cfg.Prefetching { + c.prefetchExpires = time.Duration(cfg.PrefetchExpires) - c.prefetchingNameCache = cache.NewWithLRU(c.prefetchExpires, time.Minute, cfg.PrefetchMaxItemsCount) + c.prefetchThreshold = cfg.PrefetchThreshold - c.resultCache.OnEvicted(func(key string, i interface{}) { - c.onEvicted(key) - }) + c.prefetchingNameCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(time.Minute), + expirationcache.WithMaxSize(uint(cfg.PrefetchMaxItemsCount))) + c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption, + expirationcache.WithOnExpiredFn(c.onExpired)) + } else { + c.resultCache = expirationcache.NewCache(cleanupOption, maxSizeOption) + } } func setupRedisCacheSubscriber(c *CachingResolver) { @@ -91,13 +90,13 @@ func setupRedisCacheSubscriber(c *CachingResolver) { // check if domain was queried > threshold in the time window func (r *CachingResolver) isPrefetchingDomain(cacheKey string) bool { - cnt, found := r.prefetchingNameCache.Get(cacheKey) - return found && cnt.(int) > r.prefetchThreshold + cnt, _ := r.prefetchingNameCache.Get(cacheKey) + return cnt != nil && cnt.(int) > r.prefetchThreshold } -// onEvicted is called if a DNS response in the cache is expired and was removed from cache -func (r *CachingResolver) onEvicted(cacheKey string) { +func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.Duration) { qType, domainName := util.ExtractCacheKey(cacheKey) + logger := logger("caching_resolver") if r.isPrefetchingDomain(cacheKey) { @@ -107,13 +106,16 @@ func (r *CachingResolver) onEvicted(cacheKey string) { response, err := r.next.Resolve(req) if err == nil { - r.putInCache(cacheKey, response, true, r.redisEnabled) - - evt.Bus().Publish(evt.CachingDomainPrefetched, domainName) + if response.Res.Rcode == dns.RcodeSuccess { + evt.Bus().Publish(evt.CachingDomainPrefetched, domainName) + return cacheValue{response.Res.Answer, true}, 
time.Duration(r.adjustTTLs(response.Res.Answer)) * time.Second + } + } else { + util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) } - - util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) } + + return nil, 0 } // Configuration returns a current resolver configuration @@ -137,19 +139,11 @@ func (r *CachingResolver) Configuration() (result []string) { result = append(result, fmt.Sprintf("prefetchThreshold = %d", r.prefetchThreshold)) } - result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.ItemCount())) + result = append(result, fmt.Sprintf("cache items count = %d", r.resultCache.TotalCount())) return } -func calculateRemainingTTL(expiresAt time.Time) uint32 { - if expiresAt.IsZero() { - return 0 - } - - return uint32(time.Until(expiresAt).Seconds()) -} - // Resolve checks if the current query result is already in the cache and returns it // or delegates to the next resolver //nolint:gocognit,funlen @@ -171,17 +165,13 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo r.trackQueryDomainNameCount(domain, cacheKey, logger) - // can return expired items (if cache cleanup is not executed yet) - val, expiresAt, found := r.resultCache.GetRaw(cacheKey) + val, ttl := r.resultCache.Get(cacheKey) - if found { + if val != nil { logger.Debug("domain is cached") evt.Bus().Publish(evt.CachingResultCacheHit, domain) - // calculate remaining TTL - remainingTTL := calculateRemainingTTL(expiresAt) - v, ok := val.(cacheValue) if ok { if v.prefetch { @@ -192,7 +182,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo // Answer from successful request resp.Answer = v.answer for _, rr := range resp.Answer { - rr.Header().Ttl = remainingTTL + rr.Header().Ttl = uint32(ttl.Seconds()) } return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil @@ -219,14 +209,14 @@ func (r *CachingResolver) Resolve(request *model.Request) (response 
*model.Respo func (r *CachingResolver) trackQueryDomainNameCount(domain string, cacheKey string, logger *logrus.Entry) { if r.prefetchingNameCache != nil { var domainCount int - if x, found := r.prefetchingNameCache.Get(cacheKey); found { + if x, _ := r.prefetchingNameCache.Get(cacheKey); x != nil { domainCount = x.(int) } domainCount++ - r.prefetchingNameCache.SetDefault(cacheKey, domainCount) + r.prefetchingNameCache.Put(cacheKey, domainCount, r.prefetchExpires) logger.Debugf("domain '%s' was requested %d times, "+ - "total cache size: %d", util.Obfuscate(domain), domainCount, r.prefetchingNameCache.ItemCount()) - evt.Bus().Publish(evt.CachingDomainsToPrefetchCountChanged, r.prefetchingNameCache.ItemCount()) + "total cache size: %d", util.Obfuscate(domain), domainCount, r.prefetchingNameCache.TotalCount()) + evt.Bus().Publish(evt.CachingDomainsToPrefetchCountChanged, r.prefetchingNameCache.TotalCount()) } } @@ -235,15 +225,15 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, if response.Res.Rcode == dns.RcodeSuccess { // put value into cache - r.resultCache.Set(cacheKey, cacheValue{answer, prefetch}, time.Duration(r.adjustTTLs(answer))*time.Second) + r.resultCache.Put(cacheKey, cacheValue{answer, prefetch}, time.Duration(r.adjustTTLs(answer))*time.Second) } else if response.Res.Rcode == dns.RcodeNameError { if r.cacheTimeNegative > 0 { // put return code if NXDOMAIN - r.resultCache.Set(cacheKey, response.Res.Rcode, r.cacheTimeNegative) + r.resultCache.Put(cacheKey, response.Res.Rcode, r.cacheTimeNegative) } } - evt.Bus().Publish(evt.CachingResultCacheChanged, r.resultCache.ItemCount()) + evt.Bus().Publish(evt.CachingResultCacheChanged, r.resultCache.TotalCount()) if publish && r.redisClient != nil { res := *response.Res diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index e9a6aebbc..53acc2405 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -3,6 +3,7 @@ 
package resolver import ( "time" + "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/evt" . "github.com/0xERR0R/blocky/helpertest" @@ -12,7 +13,6 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/creasty/defaults" - "github.com/0xERR0R/go-cache" "github.com/miekg/dns" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -58,13 +58,16 @@ var _ = Describe("CachingResolver", func() { PrefetchExpires: config.Duration(time.Minute * 120), PrefetchThreshold: 5, } + mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 1, dns.TypeA, "123.122.121.120") }) It("should prefetch domain if query count > threshold", func() { // prepare resolver, set smaller caching times for testing prefetchThreshold := 5 - sut.(*CachingResolver).resultCache = cache.New(25*time.Millisecond, 15*time.Millisecond) - configurePrefetching(sut.(*CachingResolver), &sutConfig) + configureCaches(sut.(*CachingResolver), &sutConfig) + sut.(*CachingResolver).resultCache = expirationcache.NewCache( + expirationcache.WithCleanUpInterval(100*time.Millisecond), + expirationcache.WithOnExpiredFn(sut.(*CachingResolver).onExpired)) prefetchedCnt := 0 _ = Bus().SubscribeOnce(CachingDomainsToPrefetchCountChanged, func(cnt int) { @@ -93,8 +96,8 @@ var _ = Describe("CachingResolver", func() { // now query again > threshold for i := 0; i < prefetchThreshold; i++ { _, _ = sut.Resolve(newRequest("example.com.", dns.TypeA)) - } + Eventually(func(g Gomega) { // now is this domain prefetched g.Expect(domainPrefetched).Should(Equal("example.com")) @@ -102,7 +105,7 @@ var _ = Describe("CachingResolver", func() { // and it should hit from prefetch cache _, _ = sut.Resolve(newRequest("example.com.", dns.TypeA)) g.Expect(prefetchHitDomain).Should(Equal("example.com")) - }, "50ms").Should(Succeed()) + }, "2s").Should(Succeed()) }) }) @@ -420,12 +423,15 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - resp, err = 
sut.Resolve(newRequest("google.de.", dns.TypeMX)) - Expect(err).Should(Succeed()) - Expect(resp.RType).Should(Equal(ResponseTypeCACHED)) - Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess)) - Expect(m.Calls).Should(HaveLen(1)) - Expect(resp.Res.Answer).Should(BeDNSRecord("google.de.", dns.TypeMX, 179, "alt1.aspmx.l.google.com.")) + Eventually(func(g Gomega) { + resp, err = sut.Resolve(newRequest("google.de.", dns.TypeMX)) + g.Expect(err).Should(Succeed()) + g.Expect(resp.RType).Should(Equal(ResponseTypeCACHED)) + g.Expect(resp.Res.Rcode).Should(Equal(dns.RcodeSuccess)) + g.Expect(m.Calls).Should(HaveLen(1)) + g.Expect(resp.Res.Answer).Should(BeDNSRecord("google.de.", dns.TypeMX, 179, "alt1.aspmx.l.google.com.")) + }, "1s").Should(Succeed()) + }) }) }) diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index 597d93249..96444ffef 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -6,18 +6,18 @@ import ( "strings" "time" + "github.com/0xERR0R/blocky/cache/expirationcache" "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" - "github.com/0xERR0R/go-cache" "github.com/miekg/dns" "github.com/sirupsen/logrus" ) // ClientNamesResolver tries to determine client name by asking responsible DNS server via rDNS (reverse lookup) type ClientNamesResolver struct { - cache *cache.Cache + cache expirationcache.ExpiringCache externalResolver Resolver singleNameOrder []uint clientIPMapping map[string][]net.IP @@ -32,7 +32,7 @@ func NewClientNamesResolver(cfg config.ClientLookupConfig) ChainedResolver { } return &ClientNamesResolver{ - cache: cache.New(1*time.Hour, 1*time.Hour), + cache: expirationcache.NewCache(expirationcache.WithCleanUpInterval(time.Hour)), externalResolver: r, singleNameOrder: cfg.SingleNameOrder, clientIPMapping: cfg.ClientnameIPMapping, @@ -48,7 +48,7 @@ func (r *ClientNamesResolver) Configuration() (result []string) { result = 
append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver)) } - result = append(result, fmt.Sprintf("cache item count = %d", r.cache.ItemCount())) + result = append(result, fmt.Sprintf("cache item count = %d", r.cache.TotalCount())) if len(r.clientIPMapping) > 0 { result = append(result, "client IP mapping:") @@ -86,16 +86,16 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string { return []string{} } - c, found := r.cache.Get(ip.String()) + c, _ := r.cache.Get(ip.String()) - if found { + if c != nil { if t, ok := c.([]string); ok { return t } } names := r.resolveClientNames(ip, withPrefix(request.Log, "client_names_resolver")) - r.cache.Set(ip.String(), names, cache.DefaultExpiration) + r.cache.Put(ip.String(), names, time.Hour) return names } @@ -169,5 +169,5 @@ func (r *ClientNamesResolver) getNameFromIPMapping(ip net.IP, result []string) [ // FlushCache reset client name cache func (r *ClientNamesResolver) FlushCache() { - r.cache.Flush() + r.cache.Clear() }