package cmd import ( "errors" "fmt" "regexp" "sort" "strconv" "strings" ) var ( ErrArgumentDefinition = errors.New("invalid createArgument definition") ErrOptionDefinition = errors.New("invalid createOption definition") ErrVersionDefinition = errors.New("invalid version definition") optionArgumentSplitRE *regexp.Regexp = regexp.MustCompile(`[, ] *`) tokenRE *regexp.Regexp = regexp.MustCompile("^[a-z]\\w*$") ) type Option struct { Name string // 参数名称 Type string // 数据类型 Short string // 段名称,e.g. -w Long string // 长名称,e.g. --watch Description string // 简短描述 LongDescription string // 详细描述 Variadic bool // 是否支持多个值 Required bool // 是否必须 DefaultValue any // 默认值 RuntimeValue []string // 运行时的值 } func (o *Option) toArgument() *Argument { return &Argument{ Name: o.Name, Type: o.Type, Description: o.Description, LongDescription: o.LongDescription, Variadic: o.Variadic, Required: o.Required, DefaultValue: o.DefaultValue, } } type Argument struct { Name string // 参数名称 Type string // 数据类型 Description string // 简短描述 LongDescription string // 详细描述 Variadic bool // 是否支持多个值 Required bool // 是否必须 DefaultValue any // 默认值 } // createArgument 解析参数 // // // [Name]、[Name:string]、[Name...:string] func createArgument(expr string) (*Argument, error) { if len(expr) == 0 { return nil, ErrArgumentDefinition } var a Argument if strings.HasPrefix(expr, "[") { if !strings.HasSuffix(expr, "]") { return nil, ErrArgumentDefinition } a.Required = true } else if strings.HasPrefix(expr, "<") { if !strings.HasSuffix(expr, ">") { return nil, ErrArgumentDefinition } a.Required = false } else { return nil, ErrArgumentDefinition } // 移除 [] <> expr = expr[1 : len(expr)-1] // 没有参数名称 if len(expr) == 0 { return nil, ErrArgumentDefinition } parts := strings.Split(expr, ":") switch len(parts) { case 1: a.Name = parts[0] case 2: a.Name = parts[0] a.Type = parts[1] // 支持多个值 if strings.HasPrefix(a.Type, "...") { a.Type = a.Type[3:] a.Variadic = true } default: return nil, ErrArgumentDefinition } // 验证名称是否合法 if !tokenRE.MatchString(a.Name) { return nil, ErrArgumentDefinition } // 类型名称错误 if len(a.Type) > 0 && !tokenRE.MatchString(a.Type) { fmt.Println(a.Type) return nil, ErrArgumentDefinition } return &a, nil } func createOption(expr string) (*Option, error) { if len(expr) == 0 { return nil, ErrArgumentDefinition } var o Option shortIndex := -1 longIndex := -1 nameIndex := -1 parts := optionArgumentSplitRE.Split(expr, -1) for i, part := range parts { barLen := computeBarLen(part) // 解析的时参数 if barLen == 0 { // 存在多个参数 if nameIndex > -1 { return nil, ErrOptionDefinition } nameIndex = i a, err := createArgument(part) // 参数格式错误 if errors.Is(err, ErrArgumentDefinition) { return nil, ErrOptionDefinition } o.Name = a.Name o.Type = a.Type o.Variadic = a.Variadic o.Required = a.Required } else if barLen == 1 { // 短名称重复 if shortIndex > -1 { return nil, ErrOptionDefinition } shortIndex = i o.Short = part[barLen:] } else if barLen == 2 { if longIndex > -1 { return nil, ErrOptionDefinition } longIndex = i o.Long = part[barLen:] } else { // 3个或超过3个横杠 return nil, ErrOptionDefinition } // 严格检查顺序 var positionError bool if nameIndex > -1 { positionError = nameIndex < longIndex || nameIndex < shortIndex } else if longIndex > -1 { positionError = longIndex < shortIndex } if positionError { return nil, ErrOptionDefinition } } // 验证短名称是否合法 if shortIndex > -1 && (len(o.Short) != 1 || !tokenRE.MatchString(o.Short)) { return nil, ErrOptionDefinition } // 未解析到长名称或者长名称不合法 if longIndex == -1 || !tokenRE.MatchString(o.Long) { return nil, ErrOptionDefinition } // 未指定名称,则使用 long flag if nameIndex == -1 { o.Name = o.Long } return &o, nil } func computeBarLen(str string) int { for i, s := range str { if s != '-' { return i } } return 0 } type TypeParser func(...string) any func parseValue(cmd *Command, typ string, values []string, defaultValue any, variadic bool) any { if parser, ok := getTypeParser(cmd, typ); ok { return parser(values...) } else if l := len(values); l > 0 { switch typ { case "bool": if l == 0 { if variadic { return []bool{} } return true } var parsed []bool for _, value := range values { v, e := strconv.ParseBool(value) if e != nil { panic(errors.New("invalid value of type " + typ)) } if variadic { parsed = append(parsed, v) } else if l == 1 { return v } else { panic(errors.New("found multi values")) } } return parsed case "string": if variadic { return values[:] } else if l == 0 { return "" } else if l == 1 { return values[0] } else { panic(errors.New("found multi values")) } case "float": if l == 0 { if variadic { return []float64{} } return float64(0) } var parsed []float64 for _, value := range values { v, e := strconv.ParseFloat(value, 64) if e != nil { panic(errors.New("invalid value of type " + typ)) } if variadic { parsed = append(parsed, v) } else if l == 1 { return v } else { panic(errors.New("found multi values")) } } return parsed case "int": if l == 0 { if variadic { return []int64{} } return int64(0) } var parsed []int64 for _, value := range values { v, e := strconv.ParseInt(value, 10, 64) if e != nil { panic(errors.New("invalid value of type " + typ)) } if variadic { parsed = append(parsed, v) } else if l == 1 { return v } else { panic(errors.New("found multi values")) } } return parsed case "uint": if l == 0 { if variadic { return []uint64{} } return int64(0) } var parsed []uint64 for _, value := range values { v, e := strconv.ParseUint(value, 10, 64) if e != nil { panic(errors.New("invalid value of type " + typ)) } if variadic { parsed = append(parsed, v) } else if l == 1 { return v } else { panic(errors.New("found multi values")) } } return parsed default: if variadic { return values[:] } else if l == 0 { return nil } else if l == 1 { return values[0] } else { panic(errors.New("found multi values")) } } } else if defaultValue != nil { return defaultValue } else { return nil } } func getTypeParser(cmd *Command, typ string) (TypeParser, bool) { if cmd.types != nil { if parser, ok := cmd.types[typ]; ok { return parser, true } } if cmd.globalTypes != nil { if parser, ok := cmd.globalTypes[typ]; ok { return parser, true } } if cmd.program != nil && cmd.program.globalTypes != nil { parser, ok := cmd.program.globalTypes[typ] return parser, ok } return nil, false } func parseArgs(args []string, cmd *Command) (help bool, version bool, options map[string]any, arguments []any) { var removes []int var foundOption *Option var variadicValues []string // 可变参数值 var variadicOption *Option // 可变参数选项 variadicStart := -1 // 可变参数开始位置 variadicStop := -1 // 可变参数结束位置(不包含) foundPosition := -1 findOption := func(key string, isLong, isShort bool) *Option { for _, o := range cmd.options { if (isLong && o.Long == key) || (isShort && o.Short == key) { return o } } return nil } saveParsedVariadicValues := func() { if variadicOption == nil { return } for i := variadicStart; i < variadicStop; i++ { removes = append(removes, i) } variadicOption.RuntimeValue = variadicValues[:] variadicOption = nil variadicValues = []string{} variadicStart = -1 variadicStop = -1 } saveFound := func(value string) { if foundOption == nil { return } if len(value) > 0 { foundOption.RuntimeValue = []string{value} } removes = append(removes, foundPosition, foundPosition+1) // foundPosition+1 可能会越界 foundOption = nil foundPosition = -1 } for i := 0; i < len(args); i++ { v := args[i] barLen := computeBarLen(v) var isShort bool var isLong bool switch barLen { case 0: if variadicOption != nil { variadicValues = append(variadicValues, v) } else /*if foundOption != nil*/ { saveFound(v) } continue case 1: isShort = true v = v[1:] case 2: isLong = true v = v[2:] default: panic("invalid argument") } // 保存之前找到的参数 variadicStop = i saveParsedVariadicValues() saveFound("") // 解析参数 var val string var key string var equalsSign bool if j := strings.IndexByte(v, '='); j > -1 { equalsSign = true key = v[:j] if j < len(v)-1 { val = v[j+1:] } } else { key = v } // 是不是帮助命令 if cmd.help != nil && ((isShort && cmd.help.Short == key) || (isLong && cmd.help.Long == key)) { help = true removes = append(removes, i) continue } // 是不是版本命令 if cmd.version != nil && ((isShort && cmd.version.Short == key) || (isLong && cmd.version.Long == key)) { version = true removes = append(removes, i) continue } opt := findOption(key, isLong, isShort) if opt == nil { // TODO show help information panic("unsupported flag \"" + v + "\"") } foundOption = opt foundPosition = i // 后面的不解析 if equalsSign { if opt.Variadic { // TODO show help information panic("invalid values for flag \"" + v + "\"") } saveFound(val) continue } if opt.Variadic { foundOption = nil foundPosition = -1 variadicOption = opt variadicStart = i } } if variadicOption != nil { variadicStop = len(args) saveParsedVariadicValues() } else { saveFound("") } // 移除上面解析 flags 标记的参数 temp := args[:] sort.Sort(sort.Reverse(sort.IntSlice(removes))) for _, remove := range removes { l := len(temp) if remove >= l { continue } if l == 0 { break } if remove != len(temp)-1 { temp = append(temp[:remove], temp[remove+1:]...) } else { temp = temp[:remove] } } // 参数赋值 arguments = make([]any, len(cmd.arguments)) l := len(temp) - 1 for i, arg := range cmd.arguments { // 可变参数 if arg.Variadic { if i <= l { arguments[i] = parseValue(cmd, arg.Type, temp[i:], arg.DefaultValue, true) } else if arg.DefaultValue == nil { if arg.Required { panic(errors.New("missing argument with " + arg.Name)) } arguments[i] = []any{} } else if list, ok := arg.DefaultValue.([]any); ok { arguments[i] = list } else { panic(errors.New("invalid default value, except a(n) `[]any`")) } break } // 找到输入的参数 if i <= l { arguments[i] = parseValue(cmd, arg.Type, []string{temp[i]}, arg.DefaultValue, false) } else if arg.DefaultValue != nil { // 启用默认值 arguments[i] = arg.DefaultValue } else if arg.Required { panic(errors.New("missing argument with " + arg.Name)) } else { arguments[i] = nil } } // 解析出需要的可选参数 options = make(map[string]any) for _, o := range cmd.options { options[o.Name] = parseValue(cmd, o.Type, o.RuntimeValue, o.DefaultValue, o.Variadic) o.RuntimeValue = nil } return }