This commit is contained in:
HodlOnToYourButts 2025-08-29 00:54:46 -07:00 committed by GitHub
commit f58b66bcdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -12,8 +12,13 @@ import androidx.preference.PreferenceManager
import eu.neilalexander.yggdrasil.YggStateReceiver.Companion.YGG_STATE_INTENT
import mobile.Yggdrasil
import org.json.JSONArray
import java.io.FileDescriptor
import java.io.FileInputStream
import java.io.FileOutputStream
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.Inet6Address
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.thread
@ -35,6 +40,7 @@ open class PacketTunnelProvider: VpnService() {
private var started = AtomicBoolean()
private lateinit var config: ConfigurationProxy
private var customDnsPort: Int = 0
private var readerThread: Thread? = null
private var writerThread: Thread? = null
@ -114,6 +120,26 @@ open class PacketTunnelProvider: VpnService() {
yggdrasil.startJSON(config.getJSONByteArray())
val address = yggdrasil.addressString
var hasCustomDns = false
val preferences = PreferenceManager.getDefaultSharedPreferences(this.baseContext)
val serverString = preferences.getString(KEY_DNS_SERVERS, "")
// First, check if we have custom DNS servers to determine bypass behavior
if (serverString!!.isNotEmpty()) {
val servers = serverString.split(",")
servers.forEach { server ->
val trimmedServer = server.trim()
if (trimmedServer.startsWith("127.0.0.1:")) {
hasCustomDns = true
// Extract port number from 127.0.0.1:5353 format
val portString = trimmedServer.substring(10) // Remove "127.0.0.1:"
customDnsPort = portString.toIntOrNull() ?: 5353
Log.i(TAG, "Found custom IPv4 DNS server: $trimmedServer, port: $customDnsPort")
}
}
}
val builder = Builder()
.addAddress(address, 7)
.addRoute("200::", 7)
@ -125,7 +151,18 @@ open class PacketTunnelProvider: VpnService() {
// and we can't use DNS with Yggdrasil addresses.
.addRoute("2000::", 128)
.allowFamily(OsConstants.AF_INET)
.allowBypass()
// Only allow bypass if no custom DNS servers
if (!hasCustomDns) {
builder.allowBypass()
Log.d(TAG, "Allowing VPN bypass - no custom DNS")
} else {
Log.i(TAG, "Not allowing VPN bypass - forcing DNS through VPN")
// Add route only for our dummy DNS server
builder.addRoute("198.18.0.1", 32) // Private IPv4 DNS (dummy)
}
builder
.setBlocking(true)
.setMtu(yggdrasil.mtu.toInt())
.setSession("Yggdrasil")
@ -137,14 +174,20 @@ open class PacketTunnelProvider: VpnService() {
builder.setMetered(false)
}
val preferences = PreferenceManager.getDefaultSharedPreferences(this.baseContext)
val serverString = preferences.getString(KEY_DNS_SERVERS, "")
if (serverString!!.isNotEmpty()) {
// Now add the actual DNS servers
if (serverString.isNotEmpty()) {
val servers = serverString.split(",")
if (servers.isNotEmpty()) {
servers.forEach {
Log.i(TAG, "Using DNS server $it")
builder.addDnsServer(it)
servers.forEach { server ->
val trimmedServer = server.trim()
if (trimmedServer.startsWith("127.0.0.1:")) {
// Add only private DNS as dummy server to intercept
builder.addDnsServer("198.18.0.1") // Private IPv4 DNS (dummy)
Log.i(TAG, "Added dummy DNS server 198.18.0.1 to intercept for 127.0.0.1:$customDnsPort")
} else {
Log.i(TAG, "Using standard DNS server $trimmedServer")
builder.addDnsServer(trimmedServer)
}
}
}
}
@ -333,6 +376,37 @@ open class PacketTunnelProvider: VpnService() {
}
try {
val n = readerStream.read(b)
if (n > 0) {
if (n > 20) {
val version = (b[0].toInt() and 0xF0) shr 4
if (version == 4 && n >= 20) {
val protocol = b[9].toInt() and 0xFF
val srcIP = String.format("%d.%d.%d.%d",
b[12].toInt() and 0xFF, b[13].toInt() and 0xFF,
b[14].toInt() and 0xFF, b[15].toInt() and 0xFF)
val dstIP = String.format("%d.%d.%d.%d",
b[16].toInt() and 0xFF, b[17].toInt() and 0xFF,
b[18].toInt() and 0xFF, b[19].toInt() and 0xFF)
if (protocol == 17 && n >= 28) { // UDP
val ipHeaderLength = (b[0].toInt() and 0x0F) * 4
val srcPort = ((b[ipHeaderLength].toInt() and 0xFF) shl 8) or
(b[ipHeaderLength + 1].toInt() and 0xFF)
val destPort = ((b[ipHeaderLength + 2].toInt() and 0xFF) shl 8) or
(b[ipHeaderLength + 3].toInt() and 0xFF)
if (destPort == 53 && customDnsPort > 0) {
// Forward DNS query to custom server and inject response
forwardDnsQuery(b, n, ipHeaderLength + 8, true, srcIP, srcPort)
continue@reads // Skip normal processing
}
}
}
}
}
yggdrasil.sendBuffer(b, n.toLong())
} catch (e: Exception) {
Log.i(TAG, "Error in sendBuffer: $e")
@ -344,4 +418,149 @@ open class PacketTunnelProvider: VpnService() {
readerStream = null
}
}
private fun forwardDnsQuery(packet: ByteArray, packetLength: Int, dnsPayloadOffset: Int, isIPv4: Boolean, srcIP: String, srcPort: Int) {
try {
// Extract DNS payload
val dnsPayloadLength = packetLength - dnsPayloadOffset
val dnsPayload = ByteArray(dnsPayloadLength)
System.arraycopy(packet, dnsPayloadOffset, dnsPayload, 0, dnsPayloadLength)
// Forward to custom DNS server at 127.0.0.1:customDnsPort
thread {
var socket: DatagramSocket? = null
try {
// Create socket with no specific binding - let system choose any available port on any interface
socket = DatagramSocket()
socket.soTimeout = 1000 // 1 second timeout
val address = InetAddress.getByName("127.0.0.1")
val outPacket = DatagramPacket(dnsPayload, dnsPayload.size, address, customDnsPort)
socket.send(outPacket)
// Wait for response
val responseBuffer = ByteArray(1024)
val responsePacket = DatagramPacket(responseBuffer, responseBuffer.size)
socket.receive(responsePacket)
val responseData = ByteArray(responsePacket.length)
System.arraycopy(responseBuffer, 0, responseData, 0, responsePacket.length)
// Inject response back into VPN tunnel (IPv4 only)
injectDnsResponse(responseData, true, srcIP, srcPort)
} catch (e: java.net.SocketTimeoutException) {
Log.e(TAG, "Timeout waiting for DNS response from 127.0.0.1:$customDnsPort")
} catch (e: Exception) {
Log.e(TAG, "Error forwarding DNS query: $e")
} finally {
socket?.close()
}
}
} catch (e: Exception) {
Log.e(TAG, "Error in forwardDnsQuery: $e")
}
}
private fun injectDnsResponse(dnsResponse: ByteArray, isIPv4: Boolean, originalSrcIP: String, originalSrcPort: Int) {
try {
val writerStream = writerStream ?: return
// Create IPv4 UDP packet with DNS response from 198.18.0.1:53 back to client
val responsePacket = createIPv4UdpPacket(dnsResponse, "198.18.0.1", 53, originalSrcIP, originalSrcPort)
if (responsePacket.isNotEmpty()) {
writerStream.write(responsePacket)
} else {
Log.e(TAG, "Failed to create response packet")
}
} catch (e: Exception) {
Log.e(TAG, "Error injecting DNS response: $e")
}
}
private fun createIPv4UdpPacket(payload: ByteArray, srcIP: String, srcPort: Int, dstIP: String, dstPort: Int): ByteArray {
val totalLength = 20 + 8 + payload.size // IP header + UDP header + payload
val packet = ByteArray(totalLength)
// IPv4 Header (20 bytes)
packet[0] = 0x45 // Version 4, Header length 5 (20 bytes)
packet[1] = 0x00 // Type of Service
packet[2] = (totalLength shr 8).toByte() // Total length high byte
packet[3] = (totalLength and 0xFF).toByte() // Total length low byte
packet[4] = 0x00 // Identification high byte
packet[5] = 0x00 // Identification low byte
packet[6] = 0x40 // Flags: Don't fragment
packet[7] = 0x00 // Fragment offset
packet[8] = 64 // TTL
packet[9] = 17 // Protocol: UDP
packet[10] = 0x00 // Header checksum (will calculate)
packet[11] = 0x00 // Header checksum
// Source IP
val srcBytes = srcIP.split(".").map { it.toInt().toByte() }
packet[12] = srcBytes[0]
packet[13] = srcBytes[1]
packet[14] = srcBytes[2]
packet[15] = srcBytes[3]
// Destination IP
val dstBytes = dstIP.split(".").map { it.toInt().toByte() }
packet[16] = dstBytes[0]
packet[17] = dstBytes[1]
packet[18] = dstBytes[2]
packet[19] = dstBytes[3]
// Calculate IP header checksum
val checksum = calculateIPv4Checksum(packet, 0, 20)
packet[10] = (checksum shr 8).toByte()
packet[11] = (checksum and 0xFF).toByte()
// UDP Header (8 bytes)
val udpLength = 8 + payload.size
packet[20] = (srcPort shr 8).toByte() // Source port high byte
packet[21] = (srcPort and 0xFF).toByte() // Source port low byte
packet[22] = (dstPort shr 8).toByte() // Dest port high byte
packet[23] = (dstPort and 0xFF).toByte() // Dest port low byte
packet[24] = (udpLength shr 8).toByte() // UDP length high byte
packet[25] = (udpLength and 0xFF).toByte() // UDP length low byte
packet[26] = 0x00 // UDP checksum (optional for IPv4)
packet[27] = 0x00 // UDP checksum
// Copy payload
System.arraycopy(payload, 0, packet, 28, payload.size)
return packet
}
private fun calculateIPv4Checksum(data: ByteArray, offset: Int, length: Int): Int {
var sum = 0L
var i = offset
// Sum all 16-bit words
while (i < offset + length - 1) {
sum += ((data[i].toInt() and 0xFF) shl 8) + (data[i + 1].toInt() and 0xFF)
i += 2
}
// Add odd byte if present
if (i < offset + length) {
sum += (data[i].toInt() and 0xFF) shl 8
}
// Add carry bits
while (sum shr 16 != 0L) {
sum = (sum and 0xFFFF) + (sum shr 16)
}
return (sum.inv() and 0xFFFF).toInt()
}
}