diff --git a/cmd/yggdrasil/chuser_openbsd.go b/cmd/yggdrasil/chuser_openbsd.go deleted file mode 100644 index c1cccf7b..00000000 --- a/cmd/yggdrasil/chuser_openbsd.go +++ /dev/null @@ -1,57 +0,0 @@ -//go:build openbsd -// +build openbsd - -package main - -import ( - "fmt" - "os/user" - "strconv" - "strings" - - "golang.org/x/sys/unix" -) - -func chuser(input string) error { - givenUser, givenGroup, _ := strings.Cut(input, ":") - - var ( - err error - usr *user.User - grp *user.Group - uid, gid int - ) - - if usr, err = user.Lookup(givenUser); err != nil { - if usr, err = user.LookupId(givenUser); err != nil { - return err - } - } - if uid, err = strconv.Atoi(usr.Uid); err != nil { - return err - } - - if givenGroup != "" { - if grp, err = user.LookupGroup(givenGroup); err != nil { - if grp, err = user.LookupGroupId(givenGroup); err != nil { - return err - } - } - - gid, _ = strconv.Atoi(grp.Gid) - } else { - gid, _ = strconv.Atoi(usr.Gid) - } - - if err := unix.Setgroups([]int{gid}); err != nil { - return fmt.Errorf("setgroups: %d: %v", gid, err) - } - if err := unix.Setresgid(gid, gid, gid); err != nil { - return fmt.Errorf("setresgid: %d: %v", gid, err) - } - if err := unix.Setresuid(uid, uid, uid); err != nil { - return fmt.Errorf("setresuid: %d: %v", uid, err) - } - - return nil -} diff --git a/cmd/yggdrasil/chuser_unix.go b/cmd/yggdrasil/chuser_unix.go index 978d14b7..20e28356 100644 --- a/cmd/yggdrasil/chuser_unix.go +++ b/cmd/yggdrasil/chuser_unix.go @@ -1,92 +1,56 @@ -//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || solaris -// +build aix darwin dragonfly freebsd linux netbsd solaris +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris package main import ( - "errors" "fmt" - "math" - osuser "os/user" + "os/user" "strconv" "strings" - "syscall" + + "golang.org/x/sys/unix" ) -func chuser(user string) error { - group := "" - if i := strings.IndexByte(user, ':'); i >= 0 { - user, group = user[:i], user[i+1:] - } +func chuser(input string) error { + givenUser, givenGroup, _ := strings.Cut(input, ":") - u := (*osuser.User)(nil) - g := (*osuser.Group)(nil) + var ( + err error + usr *user.User + grp *user.Group + uid, gid int + ) - if user != "" { - if _, err := strconv.ParseUint(user, 10, 32); err == nil { - u, err = osuser.LookupId(user) - if err != nil { - return fmt.Errorf("failed to lookup user by id %q: %v", user, err) - } - } else { - u, err = osuser.Lookup(user) - if err != nil { - return fmt.Errorf("failed to lookup user by name %q: %v", user, err) - } + if usr, err = user.Lookup(givenUser); err != nil { + if usr, err = user.LookupId(givenUser); err != nil { + return err } } - if group != "" { - if _, err := strconv.ParseUint(group, 10, 32); err == nil { - g, err = osuser.LookupGroupId(group) - if err != nil { - return fmt.Errorf("failed to lookup group by id %q: %v", user, err) - } - } else { - g, err = osuser.LookupGroup(group) - if err != nil { - return fmt.Errorf("failed to lookup group by name %q: %v", user, err) - } - } + if uid, err = strconv.Atoi(usr.Uid); err != nil { + return err } - if g != nil { - gid, _ := strconv.ParseUint(g.Gid, 10, 32) - var err error - if gid < math.MaxInt { - if err := syscall.Setgroups([]int{int(gid)}); err != nil { - return fmt.Errorf("failed to setgroups %d: %v", gid, err) + if givenGroup != "" { + if grp, err = user.LookupGroup(givenGroup); err != nil { + if grp, err = user.LookupGroupId(givenGroup); err != nil { + return err } - err = syscall.Setgid(int(gid)) - } else { - err = errors.New("gid too big") } - if err != nil { - return fmt.Errorf("failed to setgid %d: %v", gid, err) - } - } else if u != nil { - gid, _ := strconv.ParseUint(u.Gid, 10, 32) - if err := syscall.Setgroups([]int{int(uint32(gid))}); err != nil { - return fmt.Errorf("failed to setgroups %d: %v", gid, err) - } - err := syscall.Setgid(int(uint32(gid))) - if err != nil { - return fmt.Errorf("failed to setgid %d: %v", gid, err) - } + gid, _ = strconv.Atoi(grp.Gid) + } else { + gid, _ = strconv.Atoi(usr.Gid) } - if u != nil { - uid, _ := strconv.ParseUint(u.Uid, 10, 32) - var err error - if uid < math.MaxInt { - err = syscall.Setuid(int(uid)) - } else { - err = errors.New("uid too big") - } - - if err != nil { - return fmt.Errorf("failed to setuid %d: %v", uid, err) - } + if err := unix.Setgroups([]int{gid}); err != nil { + return fmt.Errorf("setgroups: %d: %v", gid, err) + } + if err := unix.Setresgid(gid, gid, gid); err != nil { + return fmt.Errorf("setresgid: %d: %v", gid, err) + } + if err := unix.Setresuid(uid, uid, uid); err != nil { + return fmt.Errorf("setresuid: %d: %v", uid, err) } return nil