flag_groups.go (7197B)
1 // Copyright 2013-2023 The Cobra Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package cobra 16 17 import ( 18 "fmt" 19 "sort" 20 "strings" 21 22 flag "github.com/spf13/pflag" 23 ) 24 25 const ( 26 requiredAsGroup = "cobra_annotation_required_if_others_set" 27 mutuallyExclusive = "cobra_annotation_mutually_exclusive" 28 ) 29 30 // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors 31 // if the command is invoked with a subset (but not all) of the given flags. 32 func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { 33 c.mergePersistentFlags() 34 for _, v := range flagNames { 35 f := c.Flags().Lookup(v) 36 if f == nil { 37 panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) 38 } 39 if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { 40 // Only errs if the flag isn't found. 41 panic(err) 42 } 43 } 44 } 45 46 // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors 47 // if the command is invoked with more than one flag from the given set of flags. 48 func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { 49 c.mergePersistentFlags() 50 for _, v := range flagNames { 51 f := c.Flags().Lookup(v) 52 if f == nil { 53 panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) 54 } 55 // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. 56 if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { 57 panic(err) 58 } 59 } 60 } 61 62 // ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the 63 // first error encountered. 64 func (c *Command) ValidateFlagGroups() error { 65 if c.DisableFlagParsing { 66 return nil 67 } 68 69 flags := c.Flags() 70 71 // groupStatus format is the list of flags as a unique ID, 72 // then a map of each flag name and whether it is set or not. 73 groupStatus := map[string]map[string]bool{} 74 mutuallyExclusiveGroupStatus := map[string]map[string]bool{} 75 flags.VisitAll(func(pflag *flag.Flag) { 76 processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) 77 processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) 78 }) 79 80 if err := validateRequiredFlagGroups(groupStatus); err != nil { 81 return err 82 } 83 if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { 84 return err 85 } 86 return nil 87 } 88 89 func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { 90 for _, fname := range flagnames { 91 f := fs.Lookup(fname) 92 if f == nil { 93 return false 94 } 95 } 96 return true 97 } 98 99 func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { 100 groupInfo, found := pflag.Annotations[annotation] 101 if found { 102 for _, group := range groupInfo { 103 if groupStatus[group] == nil { 104 flagnames := strings.Split(group, " ") 105 106 // Only consider this flag group at all if all the flags are defined. 107 if !hasAllFlags(flags, flagnames...) { 108 continue 109 } 110 111 groupStatus[group] = map[string]bool{} 112 for _, name := range flagnames { 113 groupStatus[group][name] = false 114 } 115 } 116 117 groupStatus[group][pflag.Name] = pflag.Changed 118 } 119 } 120 } 121 122 func validateRequiredFlagGroups(data map[string]map[string]bool) error { 123 keys := sortedKeys(data) 124 for _, flagList := range keys { 125 flagnameAndStatus := data[flagList] 126 127 unset := []string{} 128 for flagname, isSet := range flagnameAndStatus { 129 if !isSet { 130 unset = append(unset, flagname) 131 } 132 } 133 if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { 134 continue 135 } 136 137 // Sort values, so they can be tested/scripted against consistently. 138 sort.Strings(unset) 139 return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) 140 } 141 142 return nil 143 } 144 145 func validateExclusiveFlagGroups(data map[string]map[string]bool) error { 146 keys := sortedKeys(data) 147 for _, flagList := range keys { 148 flagnameAndStatus := data[flagList] 149 var set []string 150 for flagname, isSet := range flagnameAndStatus { 151 if isSet { 152 set = append(set, flagname) 153 } 154 } 155 if len(set) == 0 || len(set) == 1 { 156 continue 157 } 158 159 // Sort values, so they can be tested/scripted against consistently. 160 sort.Strings(set) 161 return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) 162 } 163 return nil 164 } 165 166 func sortedKeys(m map[string]map[string]bool) []string { 167 keys := make([]string, len(m)) 168 i := 0 169 for k := range m { 170 keys[i] = k 171 i++ 172 } 173 sort.Strings(keys) 174 return keys 175 } 176 177 // enforceFlagGroupsForCompletion will do the following: 178 // - when a flag in a group is present, other flags in the group will be marked required 179 // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden 180 // This allows the standard completion logic to behave appropriately for flag groups 181 func (c *Command) enforceFlagGroupsForCompletion() { 182 if c.DisableFlagParsing { 183 return 184 } 185 186 flags := c.Flags() 187 groupStatus := map[string]map[string]bool{} 188 mutuallyExclusiveGroupStatus := map[string]map[string]bool{} 189 c.Flags().VisitAll(func(pflag *flag.Flag) { 190 processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) 191 processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) 192 }) 193 194 // If a flag that is part of a group is present, we make all the other flags 195 // of that group required so that the shell completion suggests them automatically 196 for flagList, flagnameAndStatus := range groupStatus { 197 for _, isSet := range flagnameAndStatus { 198 if isSet { 199 // One of the flags of the group is set, mark the other ones as required 200 for _, fName := range strings.Split(flagList, " ") { 201 _ = c.MarkFlagRequired(fName) 202 } 203 } 204 } 205 } 206 207 // If a flag that is mutually exclusive to others is present, we hide the other 208 // flags of that group so the shell completion does not suggest them 209 for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { 210 for flagName, isSet := range flagnameAndStatus { 211 if isSet { 212 // One of the flags of the mutually exclusive group is set, mark the other ones as hidden 213 // Don't mark the flag that is already set as hidden because it may be an 214 // array or slice flag and therefore must continue being suggested 215 for _, fName := range strings.Split(flagList, " ") { 216 if fName != flagName { 217 flag := c.Flags().Lookup(fName) 218 flag.Hidden = true 219 } 220 } 221 } 222 } 223 } 224 }