diff --git a/syscalls_windows.go b/syscalls_windows.go index 81ab6ad..5ecea9a 100644 --- a/syscalls_windows.go +++ b/syscalls_windows.go @@ -11,8 +11,9 @@ import ( "errors" "fmt" "net" - "os" + "sync" "syscall" + "unsafe" "golang.org/x/sys/windows/registry" ) @@ -35,8 +36,104 @@ var ( // Driver maker specified ComponentId // ComponentId is defined here: https://github.com/OpenVPN/tap-windows6/blob/master/version.m4#L5 componentId = "tap0901" + nCreateEvent, + nResetEvent, + nGetOverlappedResult uintptr ) +func init() { + k32, err := syscall.LoadLibrary("kernel32.dll") + if err != nil { + panic("LoadLibrary " + err.Error()) + } + defer syscall.FreeLibrary(k32) + + nCreateEvent = getProcAddr(k32, "CreateEventW") + nResetEvent = getProcAddr(k32, "ResetEvent") + nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult") +} + +func getProcAddr(lib syscall.Handle, name string) uintptr { + addr, err := syscall.GetProcAddress(lib, name) + if err != nil { + panic(name + " " + err.Error()) + } + return addr +} + +func resetEvent(h syscall.Handle) error { + r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0) + if r == 0 { + return err + } + return nil +} + +func getOverlappedResult(h syscall.Handle, overlapped *syscall.Overlapped) (int, error) { + var n int + r, _, err := syscall.Syscall6(nGetOverlappedResult, 4, + uintptr(h), + uintptr(unsafe.Pointer(overlapped)), + uintptr(unsafe.Pointer(&n)), 1, 0, 0) + if r == 0 { + return n, err + } + + return n, nil +} + +func newOverlapped() (*syscall.Overlapped, error) { + var overlapped syscall.Overlapped + r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0) + if r == 0 { + return nil, err + } + overlapped.HEvent = syscall.Handle(r) + return &overlapped, nil +} + +type wfile struct { + fd syscall.Handle + rl sync.Mutex + wl sync.Mutex + ro *syscall.Overlapped + wo *syscall.Overlapped +} + +func (f *wfile) Close() error { + return syscall.Close(f.fd) +} + +func (f *wfile) Write(b []byte) (int, error) { + f.wl.Lock() + defer f.wl.Unlock() + + if err := resetEvent(f.wo.HEvent); err != nil { + return 0, err + } + var n uint32 + err := syscall.WriteFile(f.fd, b, &n, f.wo) + if err != nil && err != syscall.ERROR_IO_PENDING { + return int(n), err + } + return getOverlappedResult(f.fd, f.wo) +} + +func (f *wfile) Read(b []byte) (int, error) { + f.rl.Lock() + defer f.rl.Unlock() + + if err := resetEvent(f.ro.HEvent); err != nil { + return 0, err + } + var done uint32 + err := syscall.ReadFile(f.fd, b, &done, f.ro) + if err != nil && err != syscall.ERROR_IO_PENDING { + return int(done), err + } + return getOverlappedResult(f.fd, f.ro) +} + func ctl_code(device_type, function, method, access uint32) uint32 { return (device_type << 16) | (access << 14) | (function << 2) | method } @@ -98,7 +195,7 @@ func openDev(isTAP bool) (ifce *Interface, err error) { return nil, err } // type Handle uintptr - file, err := syscall.CreateFile(pathp, syscall.GENERIC_READ|syscall.GENERIC_WRITE, uint32(syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE), nil, syscall.OPEN_EXISTING, syscall.FILE_ATTRIBUTE_SYSTEM, 0) + file, err := syscall.CreateFile(pathp, syscall.GENERIC_READ|syscall.GENERIC_WRITE, uint32(syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE), nil, syscall.OPEN_EXISTING, syscall.FILE_ATTRIBUTE_SYSTEM|syscall.FILE_FLAG_OVERLAPPED, 0) // if err hanppens, close the interface. defer func() { if err != nil { @@ -121,7 +218,16 @@ func openDev(isTAP bool) (ifce *Interface, err error) { return nil, err } - fd := os.NewFile(uintptr(file), path) + // fd := os.NewFile(uintptr(file), path) + ro, err := newOverlapped() + if err != nil { + return + } + wo, err := newOverlapped() + if err != nil { + return + } + fd := &wfile{fd: file, ro: ro, wo: wo} ifce = &Interface{isTAP: isTAP, ReadWriteCloser: fd} // bring up device.