diff --git a/src/tun/tun_windows.go b/src/tun/tun_windows.go index c3e36596..9aa65f74 100644 --- a/src/tun/tun_windows.go +++ b/src/tun/tun_windows.go @@ -4,10 +4,9 @@ package tun import ( - "bytes" "errors" "log" - "net" + "net/netip" "github.com/yggdrasil-network/yggdrasil-go/src/defaults" "golang.org/x/sys/windows" @@ -83,22 +82,18 @@ func (tun *TunAdapter) setupAddress(addr string) error { return errors.New("Can't configure IPv6 address as TUN adapter is not present") } if intf, ok := tun.iface.(*wgtun.NativeTun); ok { - if ipaddr, ipnet, err := net.ParseCIDR(addr); err == nil { - luid := winipcfg.LUID(intf.LUID()) - addresses := append([]net.IPNet{}, net.IPNet{ - IP: ipaddr, - Mask: ipnet.Mask, - }) - - err := luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) - if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET6, addresses) - err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) - } - if err != nil { - return err - } - } else { + prefix, err := netip.ParsePrefix(addr) + if err != nil { + return err + } + luid := winipcfg.LUID(intf.LUID()) + addresses := []netip.Prefix{prefix} + err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET6, addresses) + err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) + } + if err != nil { return err } } else { @@ -112,24 +107,13 @@ func (tun *TunAdapter) setupAddress(addr string) error { * SPDX-License-Identifier: MIT * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. */ -func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) { if len(addresses) == 0 { return } - includedInAddresses := func(a net.IPNet) bool { - // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! - for _, addr := range addresses { - ip := addr.IP - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } - mA, _ := addr.Mask.Size() - mB, _ := a.Mask.Size() - if bytes.Equal(ip, a.IP) && mA == mB { - return true - } - } - return false + addrHash := make(map[netip.Addr]bool, len(addresses)) + for i := range addresses { + addrHash[addresses[i].Addr()] = true } interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { @@ -140,11 +124,10 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add continue } for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - if includedInAddresses(ipnet) { - log.Printf("Cleaning up stale address %s from interface ā€˜%s’", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) + if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] { + prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) + log.Printf("Cleaning up stale address %s from interface %q", prefix.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(prefix) } } }