trie.go (3510B)
1 package utilities 2 3 import ( 4 "sort" 5 ) 6 7 // DoubleArray is a Double Array implementation of trie on sequences of strings. 8 type DoubleArray struct { 9 // Encoding keeps an encoding from string to int 10 Encoding map[string]int 11 // Base is the base array of Double Array 12 Base []int 13 // Check is the check array of Double Array 14 Check []int 15 } 16 17 // NewDoubleArray builds a DoubleArray from a set of sequences of strings. 18 func NewDoubleArray(seqs [][]string) *DoubleArray { 19 da := &DoubleArray{Encoding: make(map[string]int)} 20 if len(seqs) == 0 { 21 return da 22 } 23 24 encoded := registerTokens(da, seqs) 25 sort.Sort(byLex(encoded)) 26 27 root := node{row: -1, col: -1, left: 0, right: len(encoded)} 28 addSeqs(da, encoded, 0, root) 29 30 for i := len(da.Base); i > 0; i-- { 31 if da.Check[i-1] != 0 { 32 da.Base = da.Base[:i] 33 da.Check = da.Check[:i] 34 break 35 } 36 } 37 return da 38 } 39 40 func registerTokens(da *DoubleArray, seqs [][]string) [][]int { 41 var result [][]int 42 for _, seq := range seqs { 43 var encoded []int 44 for _, token := range seq { 45 if _, ok := da.Encoding[token]; !ok { 46 da.Encoding[token] = len(da.Encoding) 47 } 48 encoded = append(encoded, da.Encoding[token]) 49 } 50 result = append(result, encoded) 51 } 52 for i := range result { 53 result[i] = append(result[i], len(da.Encoding)) 54 } 55 return result 56 } 57 58 type node struct { 59 row, col int 60 left, right int 61 } 62 63 func (n node) value(seqs [][]int) int { 64 return seqs[n.row][n.col] 65 } 66 67 func (n node) children(seqs [][]int) []*node { 68 var result []*node 69 lastVal := int(-1) 70 last := new(node) 71 for i := n.left; i < n.right; i++ { 72 if lastVal == seqs[i][n.col+1] { 73 continue 74 } 75 last.right = i 76 last = &node{ 77 row: i, 78 col: n.col + 1, 79 left: i, 80 } 81 result = append(result, last) 82 } 83 last.right = n.right 84 return result 85 } 86 87 func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) { 88 ensureSize(da, pos) 89 90 children := n.children(seqs) 91 var i int 92 for i = 1; ; i++ { 93 ok := func() bool { 94 for _, child := range children { 95 code := child.value(seqs) 96 j := i + code 97 ensureSize(da, j) 98 if da.Check[j] != 0 { 99 return false 100 } 101 } 102 return true 103 }() 104 if ok { 105 break 106 } 107 } 108 da.Base[pos] = i 109 for _, child := range children { 110 code := child.value(seqs) 111 j := i + code 112 da.Check[j] = pos + 1 113 } 114 terminator := len(da.Encoding) 115 for _, child := range children { 116 code := child.value(seqs) 117 if code == terminator { 118 continue 119 } 120 j := i + code 121 addSeqs(da, seqs, j, *child) 122 } 123 } 124 125 func ensureSize(da *DoubleArray, i int) { 126 for i >= len(da.Base) { 127 da.Base = append(da.Base, make([]int, len(da.Base)+1)...) 128 da.Check = append(da.Check, make([]int, len(da.Check)+1)...) 129 } 130 } 131 132 type byLex [][]int 133 134 func (l byLex) Len() int { return len(l) } 135 func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] } 136 func (l byLex) Less(i, j int) bool { 137 si := l[i] 138 sj := l[j] 139 var k int 140 for k = 0; k < len(si) && k < len(sj); k++ { 141 if si[k] < sj[k] { 142 return true 143 } 144 if si[k] > sj[k] { 145 return false 146 } 147 } 148 return k < len(sj) 149 } 150 151 // HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence. 152 func (da *DoubleArray) HasCommonPrefix(seq []string) bool { 153 if len(da.Base) == 0 { 154 return false 155 } 156 157 var i int 158 for _, t := range seq { 159 code, ok := da.Encoding[t] 160 if !ok { 161 break 162 } 163 j := da.Base[i] + code 164 if len(da.Check) <= j || da.Check[j] != i+1 { 165 break 166 } 167 i = j 168 } 169 j := da.Base[i] + len(da.Encoding) 170 if len(da.Check) <= j || da.Check[j] != i+1 { 171 return false 172 } 173 return true 174 }