domain.go (6117B)
1 // GoToSocial 2 // Copyright (C) GoToSocial Authors admin@gotosocial.org 3 // SPDX-License-Identifier: AGPL-3.0-or-later 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package domain 19 20 import ( 21 "fmt" 22 "strings" 23 "sync/atomic" 24 "unsafe" 25 26 "golang.org/x/exp/slices" 27 ) 28 29 // BlockCache provides a means of caching domain blocks in memory to reduce load 30 // on an underlying storage mechanism, e.g. a database. 31 // 32 // The in-memory block list is kept up-to-date by means of a passed loader function during every 33 // call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to 34 // hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to 35 // invalidate the cache, e.g. when a domain block is added / deleted from the database. 36 type BlockCache struct { 37 // atomically updated ptr value to the 38 // current domain block cache radix trie. 39 rootptr unsafe.Pointer 40 } 41 42 // IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it. 43 func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) { 44 // Load the current root pointer value. 45 ptr := atomic.LoadPointer(&b.rootptr) 46 47 if ptr == nil { 48 // Cache is not hydrated. 49 // 50 // Load domains from callback. 51 domains, err := load() 52 if err != nil { 53 return false, fmt.Errorf("error reloading cache: %w", err) 54 } 55 56 // Allocate new radix trie 57 // node to store matches. 58 root := new(root) 59 60 // Add each domain to the trie. 61 for _, domain := range domains { 62 root.Add(domain) 63 } 64 65 // Sort the trie. 66 root.Sort() 67 68 // Store the new node ptr. 69 ptr = unsafe.Pointer(root) 70 atomic.StorePointer(&b.rootptr, ptr) 71 } 72 73 // Look for a match in the trie node. 74 return (*root)(ptr).Match(domain), nil 75 } 76 77 // Clear will drop the currently loaded domain list, 78 // triggering a reload on next call to .IsBlocked(). 79 func (b *BlockCache) Clear() { 80 atomic.StorePointer(&b.rootptr, nil) 81 } 82 83 // String returns a string representation of stored domains in block cache. 84 func (b *BlockCache) String() string { 85 if ptr := atomic.LoadPointer(&b.rootptr); ptr != nil { 86 return (*root)(ptr).String() 87 } 88 return "<empty>" 89 } 90 91 // root is the root node in the domain 92 // block cache radix trie. this is the 93 // singular access point to the trie. 94 type root struct{ root node } 95 96 // Add will add the given domain to the radix trie. 97 func (r *root) Add(domain string) { 98 r.root.add(strings.Split(domain, ".")) 99 } 100 101 // Match will return whether the given domain matches 102 // an existing stored domain block in this radix trie. 103 func (r *root) Match(domain string) bool { 104 return r.root.match(strings.Split(domain, ".")) 105 } 106 107 // Sort will sort the entire radix trie ensuring that 108 // child nodes are stored in alphabetical order. This 109 // MUST be done to finalize the block cache in order 110 // to speed up the binary search of node child parts. 111 func (r *root) Sort() { 112 r.root.sort() 113 } 114 115 // String returns a string representation of node (and its descendants). 116 func (r *root) String() string { 117 buf := new(strings.Builder) 118 r.root.writestr(buf, "") 119 return buf.String() 120 } 121 122 type node struct { 123 part string 124 child []*node 125 } 126 127 func (n *node) add(parts []string) { 128 if len(parts) == 0 { 129 panic("invalid domain") 130 } 131 132 for { 133 // Pop next domain part. 134 i := len(parts) - 1 135 part := parts[i] 136 parts = parts[:i] 137 138 var nn *node 139 140 // Look for existing child node 141 // that matches next domain part. 142 for _, child := range n.child { 143 if child.part == part { 144 nn = child 145 break 146 } 147 } 148 149 if nn == nil { 150 // Alloc new child node. 151 nn = &node{part: part} 152 n.child = append(n.child, nn) 153 } 154 155 if len(parts) == 0 { 156 // Drop all children here as 157 // this is a higher-level block 158 // than that we previously had. 159 nn.child = nil 160 return 161 } 162 163 // Re-iter with 164 // child node. 165 n = nn 166 } 167 } 168 169 func (n *node) match(parts []string) bool { 170 for len(parts) > 0 { 171 // Pop next domain part. 172 i := len(parts) - 1 173 part := parts[i] 174 parts = parts[:i] 175 176 // Look for existing child 177 // that matches next part. 178 nn := n.getChild(part) 179 180 if nn == nil { 181 // No match :( 182 return false 183 } 184 185 if len(nn.child) == 0 { 186 // It's a match! 187 return true 188 } 189 190 // Re-iter with 191 // child node. 192 n = nn 193 } 194 195 // Ran out of parts 196 // without a match. 197 return false 198 } 199 200 // getChild fetches child node with given domain part string 201 // using a binary search. THIS ASSUMES CHILDREN ARE SORTED. 202 func (n *node) getChild(part string) *node { 203 i, j := 0, len(n.child) 204 205 for i < j { 206 // avoid overflow when computing h 207 h := int(uint(i+j) >> 1) 208 // i ≤ h < j 209 210 if n.child[h].part < part { 211 // preserves: 212 // n.child[i-1].part != part 213 i = h + 1 214 } else { 215 // preserves: 216 // n.child[h].part == part 217 j = h 218 } 219 } 220 221 if i >= len(n.child) || n.child[i].part != part { 222 return nil // no match 223 } 224 225 return n.child[i] 226 } 227 228 func (n *node) sort() { 229 // Sort this node's slice of child nodes. 230 slices.SortFunc(n.child, func(i, j *node) bool { 231 return i.part < j.part 232 }) 233 234 // Sort each child node's children. 235 for _, child := range n.child { 236 child.sort() 237 } 238 } 239 240 func (n *node) writestr(buf *strings.Builder, prefix string) { 241 if prefix != "" { 242 // Suffix joining '.' 243 prefix += "." 244 } 245 246 // Append current part. 247 prefix += n.part 248 249 // Dump current prefix state. 250 buf.WriteString(prefix) 251 buf.WriteByte('\n') 252 253 // Iterate through node children. 254 for _, child := range n.child { 255 child.writestr(buf, prefix) 256 } 257 }