gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

flag_groups.go (7197B)


      1 // Copyright 2013-2023 The Cobra Authors
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 package cobra
     16 
     17 import (
     18 	"fmt"
     19 	"sort"
     20 	"strings"
     21 
     22 	flag "github.com/spf13/pflag"
     23 )
     24 
     25 const (
     26 	requiredAsGroup   = "cobra_annotation_required_if_others_set"
     27 	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
     28 )
     29 
     30 // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
     31 // if the command is invoked with a subset (but not all) of the given flags.
     32 func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
     33 	c.mergePersistentFlags()
     34 	for _, v := range flagNames {
     35 		f := c.Flags().Lookup(v)
     36 		if f == nil {
     37 			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
     38 		}
     39 		if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
     40 			// Only errs if the flag isn't found.
     41 			panic(err)
     42 		}
     43 	}
     44 }
     45 
     46 // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
     47 // if the command is invoked with more than one flag from the given set of flags.
     48 func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
     49 	c.mergePersistentFlags()
     50 	for _, v := range flagNames {
     51 		f := c.Flags().Lookup(v)
     52 		if f == nil {
     53 			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
     54 		}
     55 		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
     56 		if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
     57 			panic(err)
     58 		}
     59 	}
     60 }
     61 
     62 // ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
     63 // first error encountered.
     64 func (c *Command) ValidateFlagGroups() error {
     65 	if c.DisableFlagParsing {
     66 		return nil
     67 	}
     68 
     69 	flags := c.Flags()
     70 
     71 	// groupStatus format is the list of flags as a unique ID,
     72 	// then a map of each flag name and whether it is set or not.
     73 	groupStatus := map[string]map[string]bool{}
     74 	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
     75 	flags.VisitAll(func(pflag *flag.Flag) {
     76 		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
     77 		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
     78 	})
     79 
     80 	if err := validateRequiredFlagGroups(groupStatus); err != nil {
     81 		return err
     82 	}
     83 	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
     84 		return err
     85 	}
     86 	return nil
     87 }
     88 
     89 func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
     90 	for _, fname := range flagnames {
     91 		f := fs.Lookup(fname)
     92 		if f == nil {
     93 			return false
     94 		}
     95 	}
     96 	return true
     97 }
     98 
     99 func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
    100 	groupInfo, found := pflag.Annotations[annotation]
    101 	if found {
    102 		for _, group := range groupInfo {
    103 			if groupStatus[group] == nil {
    104 				flagnames := strings.Split(group, " ")
    105 
    106 				// Only consider this flag group at all if all the flags are defined.
    107 				if !hasAllFlags(flags, flagnames...) {
    108 					continue
    109 				}
    110 
    111 				groupStatus[group] = map[string]bool{}
    112 				for _, name := range flagnames {
    113 					groupStatus[group][name] = false
    114 				}
    115 			}
    116 
    117 			groupStatus[group][pflag.Name] = pflag.Changed
    118 		}
    119 	}
    120 }
    121 
    122 func validateRequiredFlagGroups(data map[string]map[string]bool) error {
    123 	keys := sortedKeys(data)
    124 	for _, flagList := range keys {
    125 		flagnameAndStatus := data[flagList]
    126 
    127 		unset := []string{}
    128 		for flagname, isSet := range flagnameAndStatus {
    129 			if !isSet {
    130 				unset = append(unset, flagname)
    131 			}
    132 		}
    133 		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
    134 			continue
    135 		}
    136 
    137 		// Sort values, so they can be tested/scripted against consistently.
    138 		sort.Strings(unset)
    139 		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
    140 	}
    141 
    142 	return nil
    143 }
    144 
    145 func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
    146 	keys := sortedKeys(data)
    147 	for _, flagList := range keys {
    148 		flagnameAndStatus := data[flagList]
    149 		var set []string
    150 		for flagname, isSet := range flagnameAndStatus {
    151 			if isSet {
    152 				set = append(set, flagname)
    153 			}
    154 		}
    155 		if len(set) == 0 || len(set) == 1 {
    156 			continue
    157 		}
    158 
    159 		// Sort values, so they can be tested/scripted against consistently.
    160 		sort.Strings(set)
    161 		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
    162 	}
    163 	return nil
    164 }
    165 
    166 func sortedKeys(m map[string]map[string]bool) []string {
    167 	keys := make([]string, len(m))
    168 	i := 0
    169 	for k := range m {
    170 		keys[i] = k
    171 		i++
    172 	}
    173 	sort.Strings(keys)
    174 	return keys
    175 }
    176 
    177 // enforceFlagGroupsForCompletion will do the following:
    178 // - when a flag in a group is present, other flags in the group will be marked required
    179 // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
    180 // This allows the standard completion logic to behave appropriately for flag groups
    181 func (c *Command) enforceFlagGroupsForCompletion() {
    182 	if c.DisableFlagParsing {
    183 		return
    184 	}
    185 
    186 	flags := c.Flags()
    187 	groupStatus := map[string]map[string]bool{}
    188 	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
    189 	c.Flags().VisitAll(func(pflag *flag.Flag) {
    190 		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
    191 		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
    192 	})
    193 
    194 	// If a flag that is part of a group is present, we make all the other flags
    195 	// of that group required so that the shell completion suggests them automatically
    196 	for flagList, flagnameAndStatus := range groupStatus {
    197 		for _, isSet := range flagnameAndStatus {
    198 			if isSet {
    199 				// One of the flags of the group is set, mark the other ones as required
    200 				for _, fName := range strings.Split(flagList, " ") {
    201 					_ = c.MarkFlagRequired(fName)
    202 				}
    203 			}
    204 		}
    205 	}
    206 
    207 	// If a flag that is mutually exclusive to others is present, we hide the other
    208 	// flags of that group so the shell completion does not suggest them
    209 	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
    210 		for flagName, isSet := range flagnameAndStatus {
    211 			if isSet {
    212 				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
    213 				// Don't mark the flag that is already set as hidden because it may be an
    214 				// array or slice flag and therefore must continue being suggested
    215 				for _, fName := range strings.Split(flagList, " ") {
    216 					if fName != flagName {
    217 						flag := c.Flags().Lookup(fName)
    218 						flag.Hidden = true
    219 					}
    220 				}
    221 			}
    222 		}
    223 	}
    224 }