commit 3ff1391a9dfaeee4102654a6a871ef843e13b639
parent 66f09a8d930d6768ef86449b46904bef0745e6ec
Author: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>
Date: Mon, 1 May 2023 11:36:46 +0100
[performance] replace domain block cache with an in-memory radix trie (#1714)
* replace domain block cache with an in-memory radix tree
Signed-off-by: kim <grufwub@gmail.com>
* fix domain block cache init
Signed-off-by: kim <grufwub@gmail.com>
---------
Signed-off-by: kim <grufwub@gmail.com>
Diffstat:
3 files changed, 164 insertions(+), 120 deletions(-)
diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go
@@ -19,151 +19,206 @@ package domain
import (
"fmt"
- "time"
+ "strings"
+ "sync/atomic"
+ "unsafe"
- "codeberg.org/gruf/go-cache/v3/ttl"
- "github.com/miekg/dns"
+ "golang.org/x/exp/slices"
)
// BlockCache provides a means of caching domain blocks in memory to reduce load
// on an underlying storage mechanism, e.g. a database.
//
-// It consists of a TTL primary cache that stores calculated domain string to block results,
-// that on cache miss is filled by calculating block status by iterating over a list of all of
-// the domain blocks stored in memory. This reduces CPU usage required by not need needing to
-// iterate through a possible 100-1000s long block list, while saving memory by having a primary
-// cache of limited size that evicts stale entries. The raw list of all domain blocks should in
-// most cases be negligible when it comes to memory usage.
-//
// The in-memory block list is kept up-to-date by means of a passed loader function during every
// call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to
-// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate
-// the cache, e.g. when a domain block is added / deleted from the database. It will drop the current
-// list of domain blocks and clear all entries from the primary cache.
+// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to
+// invalidate the cache, e.g. when a domain block is added / deleted from the database.
type BlockCache struct {
- pcache *ttl.Cache[string, bool] // primary cache of domains -> block results
- blocks []block // raw list of all domain blocks, nil => not loaded.
-}
-
-// New returns a new initialized BlockCache instance with given primary cache capacity and TTL.
-func New(pcap int, pttl time.Duration) *BlockCache {
- c := new(BlockCache)
- c.pcache = new(ttl.Cache[string, bool])
- c.pcache.Init(0, pcap, pttl)
- return c
-}
-
-// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started.
-func (b *BlockCache) Start(pfreq time.Duration) bool {
- return b.pcache.Start(pfreq)
-}
-
-// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped.
-func (b *BlockCache) Stop() bool {
- return b.pcache.Stop()
+ // Atomically updated pointer to the current domain block
+ // cache radix trie (*root); nil means the cache is not loaded.
+ rootptr unsafe.Pointer
}
// IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it.
-// NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock.
func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) {
- var blocked bool
-
- // Acquire cache lock
- b.pcache.Lock()
- defer b.pcache.Unlock()
-
- // Check primary cache for result
- entry, ok := b.pcache.Cache.Get(domain)
- if ok {
- return entry.Value, nil
- }
+ // Atomically load the current trie root; nil means not yet loaded (or cleared).
+ ptr := atomic.LoadPointer(&b.rootptr)
- if b.blocks == nil {
- // Cache is not hydrated
+ if ptr == nil {
+ // Cache is not hydrated.
//
- // Load domains from callback
+ // Load domains from callback.
domains, err := load()
if err != nil {
return false, fmt.Errorf("error reloading cache: %w", err)
}
- // Drop all domain blocks and recreate
- b.blocks = make([]block, len(domains))
+ // Allocate new radix trie
+ // node to store matches.
+ root := new(root)
- for i, domain := range domains {
- // Store pre-split labels for each domain block
- b.blocks[i].labels = dns.SplitDomainName(domain)
+ // Add each domain to the trie.
+ for _, domain := range domains {
+ root.Add(domain)
}
- }
- // Split domain into it separate labels
- labels := dns.SplitDomainName(domain)
+ // Sort the trie: Match looks children up with a binary search that assumes sorted order.
+ root.Sort()
- // Compare this to our stored blocks
- for _, block := range b.blocks {
- if block.Blocks(labels) {
- blocked = true
- break
- }
+ // Publish the fully built, sorted trie for subsequent lookups.
+ ptr = unsafe.Pointer(root)
+ atomic.StorePointer(&b.rootptr, ptr) // NOTE(review): no lock is held; a Clear() that lands while load() runs is overwritten by this store — confirm this is acceptable
}
- // Store block result in primary cache
- b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{
- Key: domain,
- Value: blocked,
- Expiry: time.Now().Add(b.pcache.TTL),
- })
-
- return blocked, nil
+ // Look for a match in the trie (either just loaded above, or the cached one).
+ return (*root)(ptr).Match(domain), nil
}
-// Clear will drop the currently loaded domain list, and clear the primary cache.
-// This will trigger a reload on next call to .IsBlocked().
+// Clear will drop the currently loaded domain list,
+// triggering a reload on next call to .IsBlocked().
func (b *BlockCache) Clear() {
- // Drop all blocks.
- b.pcache.Lock()
- b.blocks = nil
- b.pcache.Unlock()
-
- // Clear needs to be done _outside_ of
- // lock, as also acquires a mutex lock.
- b.pcache.Clear()
+ atomic.StorePointer(&b.rootptr, nil) // a nil root makes IsBlocked treat the cache as unloaded and call its load function
+}
+
+// root wraps the top-level node of the domain
+// block cache radix trie. This is the
+// singular access point to the trie.
+type root struct{ root node }
+
+// Add splits domain on "." and inserts it into the radix trie. Children are appended unsorted, so Sort must be called (again) before Match.
+func (r *root) Add(domain string) {
+ r.root.add(strings.Split(domain, "."))
+}
+
+// Match reports whether domain, split on ".", is matched by
+// a domain block stored in this trie. The trie MUST be sorted.
+func (r *root) Match(domain string) bool {
+ return r.root.match(strings.Split(domain, "."))
+}
+
+// Sort will sort the entire radix trie ensuring that
+// child nodes are stored in ascending string order. This
+// MUST be done to finalize the block cache, because the
+// binary search over node child parts assumes sorted input.
+func (r *root) Sort() {
+ r.root.sort()
}
-// block represents a domain block, and stores the
-// deconstructed labels of a singular domain block.
-// e.g. []string{"gts", "superseriousbusiness", "org"}.
-type block struct {
- labels []string
+type node struct { // node is one domain label within the domain block radix trie
+ part string // the domain label held by this node, e.g. "com"; empty on the trie's top-level node
+ child []*node // nodes for the next label to the left; a node with no children terminates a stored block
+}
+
+// add inserts the domain given as its "."-separated parts
+// (e.g. []string{"example", "com"}) into the trie below n.
+// Parts are consumed right-to-left, so the top-level label
+// sits nearest the root. A node without children marks the
+// end of a stored block, and so matches all its subdomains.
+//
+// Panics if parts is empty. Children are appended unsorted,
+// so sort() must be called again before any further lookup.
+func (n *node) add(parts []string) {
+ if len(parts) == 0 {
+ panic("invalid domain")
+ }
+
+ for {
+ // Pop next domain part.
+ i := len(parts) - 1
+ part := parts[i]
+ parts = parts[:i]
+
+ var nn *node
+
+ // Look for existing child node
+ // that matches next domain part.
+ for _, child := range n.child {
+ if child.part == part {
+ nn = child
+ break
+ }
+ }
+
+ if nn == nil {
+ // Alloc new child node.
+ nn = &node{part: part}
+ n.child = append(n.child, nn)
+ } else if len(nn.child) == 0 {
+ // Existing childless node: a block at this
+ // level is already stored and covers the
+ // domain being added. Hanging children off
+ // it would turn it into a mere intermediate
+ // node and silently drop that broader block,
+ // so leave the trie untouched.
+ return
+ }
+
+ if len(parts) == 0 {
+ // Drop all children here as
+ // this is a higher-level block
+ // than that we previously had.
+ nn.child = nil
+ return
+ }
+
+ // Re-iter with
+ // child node.
+ n = nn
+ }
}
-// Blocks checks whether the separated domain labels of an
-// incoming domain matches the stored (receiving struct) block.
-func (b block) Blocks(labels []string) bool {
- // Calculate length difference
- d := len(labels) - len(b.labels)
- if d < 0 {
+// match reports whether the domain given as its "."-separated
+// parts is covered by a block stored below n, i.e. whether the
+// domain equals a stored block or is a subdomain of one. Parts
+// are consumed right-to-left. Children MUST already be sorted,
+// as lookups go through the binary search in getChild.
+func (n *node) match(parts []string) bool {
+ if len(parts) == 0 {
+ // Invalid domain.
return false
}
- // Iterate backwards through domain block's
- // labels, omparing against the incoming domain's.
- //
- // So for the following input:
- // labels = []string{"mail", "google", "com"}
- // b.labels = []string{"google", "com"}
- //
- // These would be matched in reverse order along
- // the entirety of the block object's labels:
- // "com" => match
- // "google" => match
- //
- // And so would reach the end and return true.
- for i := len(b.labels) - 1; i >= 0; i-- {
- if b.labels[i] != labels[i+d] {
+ // Bound the loop by the remaining parts: a domain
+ // shorter than a stored block (e.g. "example.com"
+ // when only "sub.example.com" is blocked) runs out
+ // of parts on an intermediate node, and popping
+ // again would index parts[-1] and panic.
+ for len(parts) > 0 {
+ // Pop next domain part.
+ i := len(parts) - 1
+ part := parts[i]
+ parts = parts[:i]
+
+ // Look for existing child
+ // that matches next part.
+ nn := n.getChild(part)
+
+ if nn == nil {
+ // No match :(
return false
}
+
+ if len(nn.child) == 0 {
+ // It's a match!
+ return true
+ }
+
+ // Re-iter with
+ // child node.
+ n = nn
+ }
+
+ // Ran out of parts before reaching a
+ // terminal node: only more specific
+ // subdomains are blocked, so no match.
+ return false
+}
+
+// getChild returns the child node whose part equals the given string (nil if none),
+// using a binary search. THIS ASSUMES CHILDREN ARE SORTED.
+func (n *node) getChild(part string) *node {
+ i, j := 0, len(n.child)
+
+ for i < j {
+ // avoid overflow when computing h
+ h := int(uint(i+j) >> 1)
+ // i ≤ h < j
+
+ if n.child[h].part < part {
+ // preserves:
+ // n.child[i-1].part < part
+ i = h + 1
+ } else {
+ // preserves:
+ // n.child[j].part >= part
+ j = h
+ }
+ }
+
+ if i >= len(n.child) || n.child[i].part != part { // i is now the first index whose part is >= the searched part
+ return nil // no match
}
- return true
+ return n.child[i]
+}
+
+func (n *node) sort() { // sort recursively orders the child slice of n and of every node below it
+ // Sort this node's children by ascending part, the order getChild's binary search assumes.
+ slices.SortFunc(n.child, func(i, j *node) bool {
+ return i.part < j.part
+ })
+
+ // Recurse so that every deeper level is sorted too.
+ for _, child := range n.child {
+ child.sort()
+ }
}
diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go
@@ -20,13 +20,12 @@ package domain_test
import (
"errors"
"testing"
- "time"
"github.com/superseriousbusiness/gotosocial/internal/cache/domain"
)
func TestBlockCache(t *testing.T) {
- c := domain.New(100, time.Second)
+ c := new(domain.BlockCache)
blocks := []string{
"google.com",
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
@@ -72,12 +72,6 @@ func (c *GTSCaches) Init() {
func (c *GTSCaches) Start() {
tryStart(c.account, config.GetCacheGTSAccountSweepFreq())
tryStart(c.block, config.GetCacheGTSBlockSweepFreq())
- tryUntil("starting domain block cache", 5, func() bool {
- if sweep := config.GetCacheGTSDomainBlockSweepFreq(); sweep > 0 {
- return c.domainBlock.Start(sweep)
- }
- return true
- })
tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
@@ -102,7 +96,6 @@ func (c *GTSCaches) Start() {
func (c *GTSCaches) Stop() {
tryStop(c.account, config.GetCacheGTSAccountSweepFreq())
tryStop(c.block, config.GetCacheGTSBlockSweepFreq())
- tryUntil("stopping domain block cache", 5, c.domainBlock.Stop)
tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
@@ -233,10 +226,7 @@ func (c *GTSCaches) initBlock() {
}
func (c *GTSCaches) initDomainBlock() {
- c.domainBlock = domain.New(
- config.GetCacheGTSDomainBlockMaxSize(),
- config.GetCacheGTSDomainBlockTTL(),
- )
+ c.domainBlock = new(domain.BlockCache) // zero value is usable as-is: the trie is loaded lazily on the first IsBlocked call
}
func (c *GTSCaches) initEmoji() {