From ca6df8928973885eb1bb095412679759fd109ed3 Mon Sep 17 00:00:00 2001 From: HodlOnToYourButts Date: Mon, 25 Aug 2025 14:46:28 -0700 Subject: [PATCH] Use custom DNS interception for locally hosted dns resolver with non-standard port MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Uses RFC 2544 test network address (198.18.0.1) as dummy DNS server - Create route to 198.18.0.1 and set it as DNS server - Intercept DNS queries to 198.18.0.1 and forward to 127.0.0.1:5353 - Inject DNS responses back into VPN tunnel as packets from 198.18.0.1:53 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../yggdrasil/PacketTunnelProvider.kt | 233 +++++++++++++++++- 1 file changed, 226 insertions(+), 7 deletions(-) diff --git a/app/src/main/java/eu/neilalexander/yggdrasil/PacketTunnelProvider.kt b/app/src/main/java/eu/neilalexander/yggdrasil/PacketTunnelProvider.kt index aa84803..02af06b 100644 --- a/app/src/main/java/eu/neilalexander/yggdrasil/PacketTunnelProvider.kt +++ b/app/src/main/java/eu/neilalexander/yggdrasil/PacketTunnelProvider.kt @@ -12,8 +12,13 @@ import androidx.preference.PreferenceManager import eu.neilalexander.yggdrasil.YggStateReceiver.Companion.YGG_STATE_INTENT import mobile.Yggdrasil import org.json.JSONArray +import java.io.FileDescriptor import java.io.FileInputStream import java.io.FileOutputStream +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetAddress +import java.net.Inet6Address import java.util.concurrent.atomic.AtomicBoolean import kotlin.concurrent.thread @@ -35,6 +40,7 @@ open class PacketTunnelProvider: VpnService() { private var started = AtomicBoolean() private lateinit var config: ConfigurationProxy + private var customDnsPort: Int = 0 private var readerThread: Thread? = null private var writerThread: Thread? = null @@ -114,6 +120,26 @@ open class PacketTunnelProvider: VpnService() { yggdrasil.startJSON(config.getJSONByteArray()) val address = yggdrasil.addressString + var hasCustomDns = false + + val preferences = PreferenceManager.getDefaultSharedPreferences(this.baseContext) + val serverString = preferences.getString(KEY_DNS_SERVERS, "") + + // First, check if we have custom DNS servers to determine bypass behavior + if (serverString!!.isNotEmpty()) { + val servers = serverString.split(",") + servers.forEach { server -> + val trimmedServer = server.trim() + if (trimmedServer.startsWith("127.0.0.1:")) { + hasCustomDns = true + // Extract port number from 127.0.0.1:5353 format + val portString = trimmedServer.substring(10) // Remove "127.0.0.1:" + customDnsPort = portString.toIntOrNull() ?: 5353 + Log.i(TAG, "Found custom IPv4 DNS server: $trimmedServer, port: $customDnsPort") + } + } + } + val builder = Builder() .addAddress(address, 7) .addRoute("200::", 7) @@ -125,7 +151,18 @@ open class PacketTunnelProvider: VpnService() { // and we can't use DNS with Yggdrasil addresses. .addRoute("2000::", 128) .allowFamily(OsConstants.AF_INET) - .allowBypass() + + // Only allow bypass if no custom DNS servers + if (!hasCustomDns) { + builder.allowBypass() + Log.d(TAG, "Allowing VPN bypass - no custom DNS") + } else { + Log.i(TAG, "Not allowing VPN bypass - forcing DNS through VPN") + // Add route only for our dummy DNS server + builder.addRoute("198.18.0.1", 32) // Private IPv4 DNS (dummy) + } + + builder .setBlocking(true) .setMtu(yggdrasil.mtu.toInt()) .setSession("Yggdrasil") @@ -137,14 +174,20 @@ open class PacketTunnelProvider: VpnService() { builder.setMetered(false) } - val preferences = PreferenceManager.getDefaultSharedPreferences(this.baseContext) - val serverString = preferences.getString(KEY_DNS_SERVERS, "") - if (serverString!!.isNotEmpty()) { + // Now add the actual DNS servers + if (serverString.isNotEmpty()) { val servers = serverString.split(",") if (servers.isNotEmpty()) { - servers.forEach { - Log.i(TAG, "Using DNS server $it") - builder.addDnsServer(it) + servers.forEach { server -> + val trimmedServer = server.trim() + if (trimmedServer.startsWith("127.0.0.1:")) { + // Add only private DNS as dummy server to intercept + builder.addDnsServer("198.18.0.1") // Private IPv4 DNS (dummy) + Log.i(TAG, "Added dummy DNS server 198.18.0.1 to intercept for 127.0.0.1:$customDnsPort") + } else { + Log.i(TAG, "Using standard DNS server $trimmedServer") + builder.addDnsServer(trimmedServer) + } } } } @@ -333,6 +376,37 @@ open class PacketTunnelProvider: VpnService() { } try { val n = readerStream.read(b) + + if (n > 0) { + if (n > 20) { + val version = (b[0].toInt() and 0xF0) shr 4 + + if (version == 4 && n >= 20) { + val protocol = b[9].toInt() and 0xFF + val srcIP = String.format("%d.%d.%d.%d", + b[12].toInt() and 0xFF, b[13].toInt() and 0xFF, + b[14].toInt() and 0xFF, b[15].toInt() and 0xFF) + val dstIP = String.format("%d.%d.%d.%d", + b[16].toInt() and 0xFF, b[17].toInt() and 0xFF, + b[18].toInt() and 0xFF, b[19].toInt() and 0xFF) + + if (protocol == 17 && n >= 28) { // UDP + val ipHeaderLength = (b[0].toInt() and 0x0F) * 4 + val srcPort = ((b[ipHeaderLength].toInt() and 0xFF) shl 8) or + (b[ipHeaderLength + 1].toInt() and 0xFF) + val destPort = ((b[ipHeaderLength + 2].toInt() and 0xFF) shl 8) or + (b[ipHeaderLength + 3].toInt() and 0xFF) + + if (destPort == 53 && customDnsPort > 0) { + // Forward DNS query to custom server and inject response + forwardDnsQuery(b, n, ipHeaderLength + 8, true, srcIP, srcPort) + continue@reads // Skip normal processing + } + } + } + } + } + yggdrasil.sendBuffer(b, n.toLong()) } catch (e: Exception) { Log.i(TAG, "Error in sendBuffer: $e") @@ -344,4 +418,149 @@ open class PacketTunnelProvider: VpnService() { readerStream = null } } + + private fun forwardDnsQuery(packet: ByteArray, packetLength: Int, dnsPayloadOffset: Int, isIPv4: Boolean, srcIP: String, srcPort: Int) { + try { + // Extract DNS payload + val dnsPayloadLength = packetLength - dnsPayloadOffset + val dnsPayload = ByteArray(dnsPayloadLength) + System.arraycopy(packet, dnsPayloadOffset, dnsPayload, 0, dnsPayloadLength) + + + // Forward to custom DNS server at 127.0.0.1:customDnsPort + thread { + var socket: DatagramSocket? = null + try { + + // Create socket with no specific binding - let system choose any available port on any interface + socket = DatagramSocket() + socket.soTimeout = 1000 // 1 second timeout + + + val address = InetAddress.getByName("127.0.0.1") + val outPacket = DatagramPacket(dnsPayload, dnsPayload.size, address, customDnsPort) + + socket.send(outPacket) + + // Wait for response + val responseBuffer = ByteArray(1024) + val responsePacket = DatagramPacket(responseBuffer, responseBuffer.size) + + socket.receive(responsePacket) + + val responseData = ByteArray(responsePacket.length) + System.arraycopy(responseBuffer, 0, responseData, 0, responsePacket.length) + + + // Inject response back into VPN tunnel (IPv4 only) + injectDnsResponse(responseData, true, srcIP, srcPort) + + } catch (e: java.net.SocketTimeoutException) { + Log.e(TAG, "Timeout waiting for DNS response from 127.0.0.1:$customDnsPort") + } catch (e: Exception) { + Log.e(TAG, "Error forwarding DNS query: $e") + } finally { + socket?.close() + } + } + } catch (e: Exception) { + Log.e(TAG, "Error in forwardDnsQuery: $e") + } + } + + private fun injectDnsResponse(dnsResponse: ByteArray, isIPv4: Boolean, originalSrcIP: String, originalSrcPort: Int) { + try { + val writerStream = writerStream ?: return + + + // Create IPv4 UDP packet with DNS response from 198.18.0.1:53 back to client + val responsePacket = createIPv4UdpPacket(dnsResponse, "198.18.0.1", 53, originalSrcIP, originalSrcPort) + if (responsePacket.isNotEmpty()) { + writerStream.write(responsePacket) + } else { + Log.e(TAG, "Failed to create response packet") + } + } catch (e: Exception) { + Log.e(TAG, "Error injecting DNS response: $e") + } + } + + private fun createIPv4UdpPacket(payload: ByteArray, srcIP: String, srcPort: Int, dstIP: String, dstPort: Int): ByteArray { + val totalLength = 20 + 8 + payload.size // IP header + UDP header + payload + val packet = ByteArray(totalLength) + + // IPv4 Header (20 bytes) + packet[0] = 0x45 // Version 4, Header length 5 (20 bytes) + packet[1] = 0x00 // Type of Service + packet[2] = (totalLength shr 8).toByte() // Total length high byte + packet[3] = (totalLength and 0xFF).toByte() // Total length low byte + packet[4] = 0x00 // Identification high byte + packet[5] = 0x00 // Identification low byte + packet[6] = 0x40 // Flags: Don't fragment + packet[7] = 0x00 // Fragment offset + packet[8] = 64 // TTL + packet[9] = 17 // Protocol: UDP + packet[10] = 0x00 // Header checksum (will calculate) + packet[11] = 0x00 // Header checksum + + // Source IP + val srcBytes = srcIP.split(".").map { it.toInt().toByte() } + packet[12] = srcBytes[0] + packet[13] = srcBytes[1] + packet[14] = srcBytes[2] + packet[15] = srcBytes[3] + + // Destination IP + val dstBytes = dstIP.split(".").map { it.toInt().toByte() } + packet[16] = dstBytes[0] + packet[17] = dstBytes[1] + packet[18] = dstBytes[2] + packet[19] = dstBytes[3] + + // Calculate IP header checksum + val checksum = calculateIPv4Checksum(packet, 0, 20) + packet[10] = (checksum shr 8).toByte() + packet[11] = (checksum and 0xFF).toByte() + + // UDP Header (8 bytes) + val udpLength = 8 + payload.size + packet[20] = (srcPort shr 8).toByte() // Source port high byte + packet[21] = (srcPort and 0xFF).toByte() // Source port low byte + packet[22] = (dstPort shr 8).toByte() // Dest port high byte + packet[23] = (dstPort and 0xFF).toByte() // Dest port low byte + packet[24] = (udpLength shr 8).toByte() // UDP length high byte + packet[25] = (udpLength and 0xFF).toByte() // UDP length low byte + packet[26] = 0x00 // UDP checksum (optional for IPv4) + packet[27] = 0x00 // UDP checksum + + // Copy payload + System.arraycopy(payload, 0, packet, 28, payload.size) + + return packet + } + + + private fun calculateIPv4Checksum(data: ByteArray, offset: Int, length: Int): Int { + var sum = 0L + var i = offset + + // Sum all 16-bit words + while (i < offset + length - 1) { + sum += ((data[i].toInt() and 0xFF) shl 8) + (data[i + 1].toInt() and 0xFF) + i += 2 + } + + // Add odd byte if present + if (i < offset + length) { + sum += (data[i].toInt() and 0xFF) shl 8 + } + + // Add carry bits + while (sum shr 16 != 0L) { + sum = (sum and 0xFFFF) + (sum shr 16) + } + + return (sum.inv() and 0xFFFF).toInt() + } + }