one chuser() to rule them all

This commit is contained in:
Klemens Nanni 2024-11-11 23:14:29 +03:00
parent a0bfd9da44
commit 3fede90ae1
No known key found for this signature in database
2 changed files with 33 additions and 126 deletions

View file

@ -1,57 +0,0 @@
//go:build openbsd
// +build openbsd
package main
import (
"fmt"
"os/user"
"strconv"
"strings"
"golang.org/x/sys/unix"
)
func chuser(input string) error {
givenUser, givenGroup, _ := strings.Cut(input, ":")
var (
err error
usr *user.User
grp *user.Group
uid, gid int
)
if usr, err = user.Lookup(givenUser); err != nil {
if usr, err = user.LookupId(givenUser); err != nil {
return err
}
}
if uid, err = strconv.Atoi(usr.Uid); err != nil {
return err
}
if givenGroup != "" {
if grp, err = user.LookupGroup(givenGroup); err != nil {
if grp, err = user.LookupGroupId(givenGroup); err != nil {
return err
}
}
gid, _ = strconv.Atoi(grp.Gid)
} else {
gid, _ = strconv.Atoi(usr.Gid)
}
if err := unix.Setgroups([]int{gid}); err != nil {
return fmt.Errorf("setgroups: %d: %v", gid, err)
}
if err := unix.Setresgid(gid, gid, gid); err != nil {
return fmt.Errorf("setresgid: %d: %v", gid, err)
}
if err := unix.Setresuid(uid, uid, uid); err != nil {
return fmt.Errorf("setresuid: %d: %v", uid, err)
}
return nil
}

View file

@ -1,92 +1,56 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || solaris //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package main package main
import ( import (
"errors"
"fmt" "fmt"
"math" "os/user"
osuser "os/user"
"strconv" "strconv"
"strings" "strings"
"syscall"
"golang.org/x/sys/unix"
) )
func chuser(user string) error { func chuser(input string) error {
group := "" givenUser, givenGroup, _ := strings.Cut(input, ":")
if i := strings.IndexByte(user, ':'); i >= 0 {
user, group = user[:i], user[i+1:] var (
err error
usr *user.User
grp *user.Group
uid, gid int
)
if usr, err = user.Lookup(givenUser); err != nil {
if usr, err = user.LookupId(givenUser); err != nil {
return err
}
}
if uid, err = strconv.Atoi(usr.Uid); err != nil {
return err
} }
u := (*osuser.User)(nil) if givenGroup != "" {
g := (*osuser.Group)(nil) if grp, err = user.LookupGroup(givenGroup); err != nil {
if grp, err = user.LookupGroupId(givenGroup); err != nil {
if user != "" { return err
if _, err := strconv.ParseUint(user, 10, 32); err == nil {
u, err = osuser.LookupId(user)
if err != nil {
return fmt.Errorf("failed to lookup user by id %q: %v", user, err)
} }
}
gid, _ = strconv.Atoi(grp.Gid)
} else { } else {
u, err = osuser.Lookup(user) gid, _ = strconv.Atoi(usr.Gid)
if err != nil {
return fmt.Errorf("failed to lookup user by name %q: %v", user, err)
}
}
}
if group != "" {
if _, err := strconv.ParseUint(group, 10, 32); err == nil {
g, err = osuser.LookupGroupId(group)
if err != nil {
return fmt.Errorf("failed to lookup group by id %q: %v", user, err)
}
} else {
g, err = osuser.LookupGroup(group)
if err != nil {
return fmt.Errorf("failed to lookup group by name %q: %v", user, err)
}
}
} }
if g != nil { if err := unix.Setgroups([]int{gid}); err != nil {
gid, _ := strconv.ParseUint(g.Gid, 10, 32) return fmt.Errorf("setgroups: %d: %v", gid, err)
var err error
if gid < math.MaxInt {
if err := syscall.Setgroups([]int{int(gid)}); err != nil {
return fmt.Errorf("failed to setgroups %d: %v", gid, err)
} }
err = syscall.Setgid(int(gid)) if err := unix.Setresgid(gid, gid, gid); err != nil {
} else { return fmt.Errorf("setresgid: %d: %v", gid, err)
err = errors.New("gid too big")
}
if err != nil {
return fmt.Errorf("failed to setgid %d: %v", gid, err)
}
} else if u != nil {
gid, _ := strconv.ParseUint(u.Gid, 10, 32)
if err := syscall.Setgroups([]int{int(uint32(gid))}); err != nil {
return fmt.Errorf("failed to setgroups %d: %v", gid, err)
}
err := syscall.Setgid(int(uint32(gid)))
if err != nil {
return fmt.Errorf("failed to setgid %d: %v", gid, err)
}
}
if u != nil {
uid, _ := strconv.ParseUint(u.Uid, 10, 32)
var err error
if uid < math.MaxInt {
err = syscall.Setuid(int(uid))
} else {
err = errors.New("uid too big")
}
if err != nil {
return fmt.Errorf("failed to setuid %d: %v", uid, err)
} }
if err := unix.Setresuid(uid, uid, uid); err != nil {
return fmt.Errorf("setresuid: %d: %v", uid, err)
} }
return nil return nil