gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

domain.go (6117B)


      1 // GoToSocial
      2 // Copyright (C) GoToSocial Authors admin@gotosocial.org
      3 // SPDX-License-Identifier: AGPL-3.0-or-later
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 
     18 package domain
     19 
     20 import (
     21 	"fmt"
     22 	"strings"
     23 	"sync/atomic"
     24 	"unsafe"
     25 
     26 	"golang.org/x/exp/slices"
     27 )
     28 
     29 // BlockCache provides a means of caching domain blocks in memory to reduce load
     30 // on an underlying storage mechanism, e.g. a database.
     31 //
     32 // The in-memory block list is kept up-to-date by means of a passed loader function during every
     33 // call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to
     34 // hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to
     35 // invalidate the cache, e.g. when a domain block is added / deleted from the database.
     36 type BlockCache struct {
     37 	// atomically updated ptr value to the
     38 	// current domain block cache radix trie.
     39 	rootptr unsafe.Pointer
     40 }
     41 
     42 // IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it.
     43 func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) {
     44 	// Load the current root pointer value.
     45 	ptr := atomic.LoadPointer(&b.rootptr)
     46 
     47 	if ptr == nil {
     48 		// Cache is not hydrated.
     49 		//
     50 		// Load domains from callback.
     51 		domains, err := load()
     52 		if err != nil {
     53 			return false, fmt.Errorf("error reloading cache: %w", err)
     54 		}
     55 
     56 		// Allocate new radix trie
     57 		// node to store matches.
     58 		root := new(root)
     59 
     60 		// Add each domain to the trie.
     61 		for _, domain := range domains {
     62 			root.Add(domain)
     63 		}
     64 
     65 		// Sort the trie.
     66 		root.Sort()
     67 
     68 		// Store the new node ptr.
     69 		ptr = unsafe.Pointer(root)
     70 		atomic.StorePointer(&b.rootptr, ptr)
     71 	}
     72 
     73 	// Look for a match in the trie node.
     74 	return (*root)(ptr).Match(domain), nil
     75 }
     76 
     77 // Clear will drop the currently loaded domain list,
     78 // triggering a reload on next call to .IsBlocked().
     79 func (b *BlockCache) Clear() {
     80 	atomic.StorePointer(&b.rootptr, nil)
     81 }
     82 
     83 // String returns a string representation of stored domains in block cache.
     84 func (b *BlockCache) String() string {
     85 	if ptr := atomic.LoadPointer(&b.rootptr); ptr != nil {
     86 		return (*root)(ptr).String()
     87 	}
     88 	return "<empty>"
     89 }
     90 
     91 // root is the root node in the domain
     92 // block cache radix trie. this is the
     93 // singular access point to the trie.
     94 type root struct{ root node }
     95 
     96 // Add will add the given domain to the radix trie.
     97 func (r *root) Add(domain string) {
     98 	r.root.add(strings.Split(domain, "."))
     99 }
    100 
    101 // Match will return whether the given domain matches
    102 // an existing stored domain block in this radix trie.
    103 func (r *root) Match(domain string) bool {
    104 	return r.root.match(strings.Split(domain, "."))
    105 }
    106 
    107 // Sort will sort the entire radix trie ensuring that
    108 // child nodes are stored in alphabetical order. This
    109 // MUST be done to finalize the block cache in order
    110 // to speed up the binary search of node child parts.
    111 func (r *root) Sort() {
    112 	r.root.sort()
    113 }
    114 
    115 // String returns a string representation of node (and its descendants).
    116 func (r *root) String() string {
    117 	buf := new(strings.Builder)
    118 	r.root.writestr(buf, "")
    119 	return buf.String()
    120 }
    121 
    122 type node struct {
    123 	part  string
    124 	child []*node
    125 }
    126 
    127 func (n *node) add(parts []string) {
    128 	if len(parts) == 0 {
    129 		panic("invalid domain")
    130 	}
    131 
    132 	for {
    133 		// Pop next domain part.
    134 		i := len(parts) - 1
    135 		part := parts[i]
    136 		parts = parts[:i]
    137 
    138 		var nn *node
    139 
    140 		// Look for existing child node
    141 		// that matches next domain part.
    142 		for _, child := range n.child {
    143 			if child.part == part {
    144 				nn = child
    145 				break
    146 			}
    147 		}
    148 
    149 		if nn == nil {
    150 			// Alloc new child node.
    151 			nn = &node{part: part}
    152 			n.child = append(n.child, nn)
    153 		}
    154 
    155 		if len(parts) == 0 {
    156 			// Drop all children here as
    157 			// this is a higher-level block
    158 			// than that we previously had.
    159 			nn.child = nil
    160 			return
    161 		}
    162 
    163 		// Re-iter with
    164 		// child node.
    165 		n = nn
    166 	}
    167 }
    168 
    169 func (n *node) match(parts []string) bool {
    170 	for len(parts) > 0 {
    171 		// Pop next domain part.
    172 		i := len(parts) - 1
    173 		part := parts[i]
    174 		parts = parts[:i]
    175 
    176 		// Look for existing child
    177 		// that matches next part.
    178 		nn := n.getChild(part)
    179 
    180 		if nn == nil {
    181 			// No match :(
    182 			return false
    183 		}
    184 
    185 		if len(nn.child) == 0 {
    186 			// It's a match!
    187 			return true
    188 		}
    189 
    190 		// Re-iter with
    191 		// child node.
    192 		n = nn
    193 	}
    194 
    195 	// Ran out of parts
    196 	// without a match.
    197 	return false
    198 }
    199 
    200 // getChild fetches child node with given domain part string
    201 // using a binary search. THIS ASSUMES CHILDREN ARE SORTED.
    202 func (n *node) getChild(part string) *node {
    203 	i, j := 0, len(n.child)
    204 
    205 	for i < j {
    206 		// avoid overflow when computing h
    207 		h := int(uint(i+j) >> 1)
    208 		// i ≤ h < j
    209 
    210 		if n.child[h].part < part {
    211 			// preserves:
    212 			// n.child[i-1].part != part
    213 			i = h + 1
    214 		} else {
    215 			// preserves:
    216 			// n.child[h].part == part
    217 			j = h
    218 		}
    219 	}
    220 
    221 	if i >= len(n.child) || n.child[i].part != part {
    222 		return nil // no match
    223 	}
    224 
    225 	return n.child[i]
    226 }
    227 
    228 func (n *node) sort() {
    229 	// Sort this node's slice of child nodes.
    230 	slices.SortFunc(n.child, func(i, j *node) bool {
    231 		return i.part < j.part
    232 	})
    233 
    234 	// Sort each child node's children.
    235 	for _, child := range n.child {
    236 		child.sort()
    237 	}
    238 }
    239 
    240 func (n *node) writestr(buf *strings.Builder, prefix string) {
    241 	if prefix != "" {
    242 		// Suffix joining '.'
    243 		prefix += "."
    244 	}
    245 
    246 	// Append current part.
    247 	prefix += n.part
    248 
    249 	// Dump current prefix state.
    250 	buf.WriteString(prefix)
    251 	buf.WriteByte('\n')
    252 
    253 	// Iterate through node children.
    254 	for _, child := range n.child {
    255 		child.writestr(buf, prefix)
    256 	}
    257 }