213
package flags
import (
	"fmt"
	"os"
	"sort"

	"git.sophuwu.com/gophuwu/parsers"
)
// flag is the registry entry for a single command-line option.
type flag struct {
	// Name is the long form (--Name), Short the optional one-character form
	// (-Short, empty when absent), HelpMsg the text PrintHelp shows, and Type
	// one of "string", "int", "bool" or "float64".
	Name, Short, HelpMsg, Type string
	// Default is the value reported when the option was not supplied; Value
	// is what ParseArgs stored, and stays nil until it stores something.
	Default, Value interface{}
}
// Package-level registry populated by NewFlag.
var (
	// flags maps each long flag name to its definition.
	flags = make(map[string]flag)
	// shortFlags maps a one-character short form to the long name it aliases.
	shortFlags = make(map[byte]string)
)
// NewFlag registers a flag with the given long name, optional short form,
// help message, and default value.
//
// Supported types for defaultValue are string, int, bool, and float64; the
// type of defaultValue fixes the flag's type. short must be empty (no short
// form) or a single character that no other flag uses.
//
// It returns an error, and leaves the registry untouched, if:
//   - name is empty, already registered, or the reserved name "help";
//   - defaultValue is nil or of an unsupported type;
//   - short is longer than one character, is the reserved "h", or is
//     already taken.
func NewFlag(name, short, helpMsg string, defaultValue interface{}) error {
	if len(name) == 0 {
		return fmt.Errorf("flag name cannot be empty")
	}
	// ParseArgs always treats --help as a request for the help text, so a
	// user flag with this name could never be reached.
	if name == "help" {
		return fmt.Errorf("flag name %s is reserved", name)
	}
	if _, exists := flags[name]; exists {
		return fmt.Errorf("flag %s already exists", name)
	}
	if defaultValue == nil {
		return fmt.Errorf("default value for flag %s cannot be nil", name)
	}
	if len(short) > 1 {
		return fmt.Errorf("short flag must be a single character")
	}
	if len(short) == 1 {
		// -h is likewise hard-wired to the help text by ParseArgs.
		if short[0] == 'h' {
			return fmt.Errorf("short flag %s is reserved", short)
		}
		if _, exists := shortFlags[short[0]]; exists {
			return fmt.Errorf("short flag %s already exists", short)
		}
	}

	var typ string
	switch defaultValue.(type) {
	case string:
		typ = "string"
	case int:
		typ = "int"
	case bool:
		typ = "bool"
	case float64:
		typ = "float64"
	default:
		return fmt.Errorf("unsupported flag type for %s", name)
	}

	// Mutate the registries only after every check has passed, so a rejected
	// call cannot leave a dangling short alias behind.
	if len(short) == 1 {
		shortFlags[short[0]] = name
	}
	flags[name] = flag{
		Name:    name,
		Short:   short,
		HelpMsg: helpMsg,
		Type:    typ,
		Default: defaultValue,
	}
	return nil
}
// getFlag returns the effective value of the flag called name, which must
// have been registered with type t ("string", "int", "bool" or "float64").
// The value stored by ParseArgs is returned when present, otherwise the
// flag's default.
func getFlag(name, t string) (interface{}, error) {
	f, exists := flags[name]
	if !exists {
		return nil, fmt.Errorf("flag %s does not exist", name)
	}
	if f.Type != t {
		// Report the type that was actually requested, not a fixed one.
		return nil, fmt.Errorf("flag %s is not of type %s", name, t)
	}
	if f.Value != nil {
		return f.Value, nil
	}
	return f.Default, nil
}
// GetBoolFlag returns the value of the bool flag called name: the value set
// by ParseArgs if any, otherwise the registered default. It fails if the flag
// does not exist or is not a bool flag.
func GetBoolFlag(name string) (bool, error) {
	raw, err := getFlag(name, "bool")
	if err == nil {
		return raw.(bool), nil
	}
	return false, err
}
// GetIntFlag returns the value of the int flag called name: the value set by
// ParseArgs if any, otherwise the registered default. It fails if the flag
// does not exist, is not an int flag, or holds a value that is not an int.
func GetIntFlag(name string) (int, error) {
	raw, err := getFlag(name, "int")
	if err != nil {
		return 0, err
	}
	// Parsed values come from parsers.ParseInt, whose result type is not
	// visible from here; report a mismatch instead of panicking.
	n, ok := raw.(int)
	if !ok {
		return 0, fmt.Errorf("flag %s holds %T, not int", name, raw)
	}
	return n, nil
}
// GetStringFlag returns the value of the string flag called name: the value
// set by ParseArgs if any, otherwise the registered default. It fails if the
// flag does not exist or is not a string flag.
func GetStringFlag(name string) (string, error) {
	raw, err := getFlag(name, "string")
	if err == nil {
		return raw.(string), nil
	}
	return "", err
}
// GetFloat64Flag returns the value of the float64 flag called name: the value
// set by ParseArgs if any, otherwise the registered default. It fails if the
// flag does not exist, is not a float64 flag, or holds a non-float64 value.
func GetFloat64Flag(name string) (float64, error) {
	raw, err := getFlag(name, "float64")
	if err != nil {
		return 0, err
	}
	// Parsed values come from parsers.ParseFloat, whose result type is not
	// visible from here; report a mismatch instead of panicking.
	x, ok := raw.(float64)
	if !ok {
		return 0, fmt.Errorf("flag %s holds %T, not float64", name, raw)
	}
	return x, nil
}
func ParseArgs() error {
if len(os.Args) < 2 {
return nil
}
var v string
var vv byte
var i, j int
var f flag
var ok bool
var err error
shortFlags['h'] = "help"
var args []string
for i = 1; i < len(os.Args); i++ {
v = os.Args[i]
if (len(v) > 2 && v[0] == '-' && v[1] != '-') || (len(v) > 1 && v[0] == '-') {
for j = 1; j < len(os.Args[i]); j++ {
vv = os.Args[i][j]
v, ok = shortFlags[vv]
if !ok {
return fmt.Errorf("unknown short flag: %c", vv)
}
args = append(args, "--"+v)
}
continue
}
args = append(args, v)
}
for i = 0; i < len(args); i++ {
v = args[i]
if len(v) > 2 && v[0] == '-' && v[1] == '-' {
v = v[2:]
if v == "help" {
PrintHelp()
os.Exit(0)
}
f, ok = flags[v]
if !ok {
return fmt.Errorf("unknown flag: %s", v)
}
if f.Type == "bool" {
f.Value = !f.Default.(bool)
flags[f.Name] = f
continue
}
i++
v = args[i]
if i >= len(args) {
return fmt.Errorf("flag %s requires a value", v)
}
switch f.Type {
case "string":
f.Value = v
break
case "int":
f.Value, err = parsers.ParseInt(v)
break
case "float64":
f.Value, err = parsers.ParseFloat(v)
break
default:
return fmt.Errorf("unsupported flag type for %s", f.Name)
}
if err != nil {
return fmt.Errorf("error parsing flag %s: %v", f.Name, err)
}
flags[f.Name] = f
}
}
return nil
}
// PrintHelp writes usage information to standard output.
//
// The built-in -h/--help entry comes first. It is followed by every
// registered flag, ordered by long name so the output is the same on every
// run. Each entry shows the short form (if any), the long form, a
// "<type>" placeholder for non-bool flags, the help message, and the
// default value.
func PrintHelp() {
	fmt.Printf("Usage: %s [options]\n", os.Args[0])
	fmt.Println("Options:")
	fmt.Println(" -h --help\n\tShow this help message")

	// Go map iteration order is randomised; sort the names for stable output.
	names := make([]string, 0, len(flags))
	for name := range flags {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		f := flags[name]
		fmt.Printf(" ")
		if len(f.Short) == 1 {
			fmt.Printf("-%s, ", f.Short)
		} else {
			fmt.Printf(" ")
		}
		fmt.Printf("--%s ", f.Name)
		if f.Type == "bool" {
			fmt.Printf("\n")
		} else {
			fmt.Printf("<%s>\n", f.Type)
		}
		fmt.Printf("\t%s ", f.HelpMsg)
		if f.Default != nil {
			fmt.Printf("(default: %v)\n", f.Default)
		} else {
			fmt.Printf("\n")
		}
	}
}