diff --git a/src/tuntap/conn.go b/src/tuntap/conn.go
index 1d47b378..ab3179bf 100644
--- a/src/tuntap/conn.go
+++ b/src/tuntap/conn.go
@@ -53,15 +53,14 @@ func (s *tunConn) reader() (err error) {
 	}
 	s.tun.log.Debugln("Starting conn reader for", s.conn.String())
 	defer s.tun.log.Debugln("Stopping conn reader for", s.conn.String())
-	var n int
-	b := make([]byte, 65535)
 	for {
 		select {
 		case <-s.stop:
 			return nil
 		default:
 		}
-		if n, err = s.conn.Read(b); err != nil {
+		var bs []byte
+		if bs, err = s.conn.ReadNoCopy(); err != nil {
 			if e, eok := err.(yggdrasil.ConnError); eok && !e.Temporary() {
 				if e.Closed() {
 					s.tun.log.Debugln(s.conn.String(), "TUN/TAP conn read debug:", err)
@@ -70,14 +69,11 @@ func (s *tunConn) reader() (err error) {
 				}
 				return e
 			}
-		} else if n > 0 {
-			bs := append(util.GetBytes(), b[:n]...)
-			select {
-			case s.tun.send <- bs:
-			default:
-				util.PutBytes(bs)
-			}
+		} else if len(bs) > 0 {
+			s.tun.send <- bs
 			s.stillAlive()
+		} else {
+			util.PutBytes(bs)
 		}
 	}
 }
@@ -96,12 +92,12 @@ func (s *tunConn) writer() error {
 		select {
 		case <-s.stop:
 			return nil
-		case b, ok := <-s.send:
+		case bs, ok := <-s.send:
 			if !ok {
 				return errors.New("send closed")
 			}
 			// TODO write timeout and close
-			if _, err := s.conn.Write(b); err != nil {
+			if err := s.conn.WriteNoCopy(bs); err != nil {
 				if e, eok := err.(yggdrasil.ConnError); !eok {
 					if e.Closed() {
 						s.tun.log.Debugln(s.conn.String(), "TUN/TAP generic write debug:", err)
@@ -112,9 +108,9 @@
 					// TODO: This currently isn't aware of IPv4 for CKR
 					ptb := &icmp.PacketTooBig{
 						MTU:  int(e.PacketMaximumSize()),
-						Data: b[:900],
+						Data: bs[:900],
 					}
-					if packet, err := CreateICMPv6(b[8:24], b[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {
+					if packet, err := CreateICMPv6(bs[8:24], bs[24:40], ipv6.ICMPTypePacketTooBig, 0, ptb); err == nil {
 						s.tun.send <- packet
 					}
 				} else {
@@ -127,7 +123,6 @@
 			} else {
 				s.stillAlive()
 			}
-			util.PutBytes(b)
 		}
 	}
 }
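The reader above stops copying each packet into a fresh pooled buffer and instead hands the slice returned by ReadNoCopy straight to s.tun.send, blocking rather than dropping when the channel is full; whoever receives from that channel then owns the buffer and is responsible for util.PutBytes. A minimal stand-alone sketch of that ownership handoff (the pool, getBytes/putBytes and the channel here are local stand-ins, not the repository's types):

// Sketch only: the "hand the buffer off, the receiver recycles it" contract
// that ReadNoCopy/PutBytes assume. getBytes/putBytes/send are stand-ins.
package main

import (
	"fmt"
	"sync"
)

var pool = sync.Pool{New: func() interface{} { return []byte(nil) }}

func getBytes() []byte  { return pool.Get().([]byte)[:0] }
func putBytes(b []byte) { pool.Put(b) }

func main() {
	send := make(chan []byte, 32)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for bs := range send { // the consumer now owns bs...
			fmt.Println("got", len(bs), "bytes")
			putBytes(bs) // ...so it is the one that returns it to the pool
		}
	}()
	for i := 0; i < 4; i++ {
		bs := append(getBytes(), make([]byte, 100+i)...)
		send <- bs // blocks (instead of dropping) if the consumer is slow; bs must not be reused after this
	}
	close(send)
	<-done
}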
diff --git a/src/tuntap/iface.go b/src/tuntap/iface.go
index a95dfae4..670f7829 100644
--- a/src/tuntap/iface.go
+++ b/src/tuntap/iface.go
@@ -139,8 +139,10 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
 				continue
 			}
 		}
-		// Shift forward to avoid leaking bytes off the front of the slide when we eventually store it
-		bs = append(recvd[:0], bs...)
+		if offset != 0 {
+			// Shift forward to avoid leaking bytes off the front of the slice when we eventually store it
+			bs = append(recvd[:0], bs...)
+		}
 		// From the IP header, work out what our source and destination addresses
 		// and node IDs are. We will need these in order to work out where to send
 		// the packet
@@ -260,11 +262,8 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
 			tun.mutex.Unlock()
 			if tc != nil {
 				for _, packet := range packets {
-					select {
-					case tc.send <- packet:
-					default:
-						util.PutBytes(packet)
-					}
+					p := packet // Possibly required because of how range reuses the loop variable
+					tc.send <- p
 				}
 			}
 		}()
@@ -274,21 +273,18 @@ func (tun *TunAdapter) readerPacketHandler(ch chan []byte) {
 		}
 		// If we have a connection now, try writing to it
 		if isIn && session != nil {
-			select {
-			case session.send <- bs:
-			default:
-				util.PutBytes(bs)
-			}
+			session.send <- bs
 		}
 	}
 }
 
 func (tun *TunAdapter) reader() error {
-	recvd := make([]byte, 65535+tun_ETHER_HEADER_LENGTH)
 	toWorker := make(chan []byte, 32)
 	defer close(toWorker)
 	go tun.readerPacketHandler(toWorker)
 	for {
+		// Get a slice to store the packet in
+		recvd := util.ResizeBytes(util.GetBytes(), 65535+tun_ETHER_HEADER_LENGTH)
 		// Wait for a packet to be delivered to us through the TUN/TAP adapter
 		n, err := tun.iface.Read(recvd)
 		if err != nil {
@@ -298,9 +294,10 @@
 			panic(err)
 		}
 		if n == 0 {
+			util.PutBytes(recvd)
 			continue
 		}
-		bs := append(util.GetBytes(), recvd[:n]...)
-		toWorker <- bs
+		// Send the packet to the worker
+		toWorker <- recvd[:n]
 	}
 }
diff --git a/src/util/util.go b/src/util/util.go
index 6fb515c8..1158156c 100644
--- a/src/util/util.go
+++ b/src/util/util.go
@@ -26,27 +26,25 @@ func UnlockThread() {
 }
 
 // This is used to buffer recently used slices of bytes, to prevent allocations in the hot loops.
-var byteStoreMutex sync.Mutex
-var byteStore [][]byte
+var byteStore = sync.Pool{New: func() interface{} { return []byte(nil) }}
 
 // Gets an empty slice from the byte store.
 func GetBytes() []byte {
-	byteStoreMutex.Lock()
-	defer byteStoreMutex.Unlock()
-	if len(byteStore) > 0 {
-		var bs []byte
-		bs, byteStore = byteStore[len(byteStore)-1][:0], byteStore[:len(byteStore)-1]
-		return bs
-	} else {
-		return nil
-	}
+	return byteStore.Get().([]byte)[:0]
}
 
 // Puts a slice in the store.
 func PutBytes(bs []byte) {
-	byteStoreMutex.Lock()
-	defer byteStoreMutex.Unlock()
-	byteStore = append(byteStore, bs)
+	byteStore.Put(bs)
+}
+
+// Gets a slice of the appropriate length, reusing existing slice capacity when possible
+func ResizeBytes(bs []byte, length int) []byte {
+	if cap(bs) >= length {
+		return bs[:length]
+	} else {
+		return make([]byte, length)
+	}
 }
 
 // This is a workaround to go's broken timer implementation
diff --git a/src/util/workerpool.go b/src/util/workerpool.go
new file mode 100644
index 00000000..fd37f397
--- /dev/null
+++ b/src/util/workerpool.go
@@ -0,0 +1,29 @@
+package util
+
+import "runtime"
+
+var workerPool chan func()
+
+func init() {
+	maxProcs := runtime.GOMAXPROCS(0)
+	if maxProcs < 1 {
+		maxProcs = 1
+	}
+	workerPool = make(chan func(), maxProcs)
+	for idx := 0; idx < maxProcs; idx++ {
+		go func() {
+			for f := range workerPool {
+				f()
+			}
+		}()
+	}
+}
+
+// WorkerGo submits a job to a pool of GOMAXPROCS worker goroutines.
+// This is meant for short non-blocking functions f() where you could just go f(),
+// but you want some kind of backpressure to prevent spawning endless goroutines.
+// WorkerGo returns as soon as the function is queued to run, not when it finishes.
+// In Yggdrasil, these workers are used for certain cryptographic operations.
+func WorkerGo(f func()) {
+	workerPool <- f
+}
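The new util.WorkerGo is the backpressure primitive that the session workers rely on: jobs queue on a channel sized to GOMAXPROCS and a fixed set of goroutines drains it. A hedged, stand-alone sketch of the intended usage (workerGo below re-creates the pool locally so the example compiles on its own; it is not the util package itself):

// Sketch only: a local copy of the WorkerGo idea, shown with a WaitGroup so
// the example can wait for its own jobs.
package main

import (
	"fmt"
	"runtime"
	"sync"
)

var workerPool chan func()

func init() {
	n := runtime.GOMAXPROCS(0)
	if n < 1 {
		n = 1
	}
	workerPool = make(chan func(), n)
	for i := 0; i < n; i++ {
		go func() {
			for f := range workerPool {
				f()
			}
		}()
	}
}

// workerGo queues f on the bounded pool; it only blocks when all workers are
// busy and the queue is full, which is the backpressure described above.
func workerGo(f func()) { workerPool <- f }

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		i := i
		wg.Add(1)
		workerGo(func() {
			defer wg.Done()
			fmt.Println("job", i, "done")
		})
	}
	wg.Wait()
}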
diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go
index d1cb7609..1a05bd83 100644
--- a/src/yggdrasil/conn.go
+++ b/src/yggdrasil/conn.go
@@ -82,7 +82,7 @@ func (c *Conn) String() string {
 	return fmt.Sprintf("conn=%p", c)
 }
 
-// This should never be called from the router goroutine
+// This should never be called from the router goroutine; it is used by the dial functions
 func (c *Conn) search() error {
 	var sinfo *searchInfo
 	var isIn bool
@@ -122,6 +122,23 @@ func (c *Conn) search() error {
 	return nil
 }
 
+// Used in session keep-alive traffic in Conn.Write
+func (c *Conn) doSearch() {
+	routerWork := func() {
+		// Check to see if there is a search already matching the destination
+		sinfo, isIn := c.core.searches.searches[*c.nodeID]
+		if !isIn {
+			// Nothing was found, so create a new search
+			searchCompleted := func(sinfo *sessionInfo, e error) {}
+			sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
+			c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
+		}
+		// Continue the search
+		sinfo.continueSearch()
+	}
+	go func() { c.core.router.admin <- routerWork }()
+}
+
 func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
 	if deadline, ok := value.Load().(time.Time); ok {
 		// A deadline is set, so return a Cancellation that uses it
@@ -132,123 +149,90 @@ func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation {
 	}
 }
 
-func (c *Conn) Read(b []byte) (int, error) {
-	// Take a copy of the session object
-	sinfo := c.session
+// Used internally by Read; the caller is responsible for util.PutBytes when they're done.
+func (c *Conn) ReadNoCopy() ([]byte, error) {
 	cancel := c.getDeadlineCancellation(&c.readDeadline)
 	defer cancel.Cancel(nil)
-	var bs []byte
-	for {
-		// Wait for some traffic to come through from the session
-		select {
-		case <-cancel.Finished():
-			if cancel.Error() == util.CancellationTimeoutError {
-				return 0, ConnError{errors.New("read timeout"), true, false, false, 0}
-			} else {
-				return 0, ConnError{errors.New("session closed"), false, false, true, 0}
-			}
-		case p, ok := <-sinfo.recv:
-			// If the session is closed then do nothing
-			if !ok {
-				return 0, ConnError{errors.New("session closed"), false, false, true, 0}
-			}
-			var err error
-			sessionFunc := func() {
-				defer util.PutBytes(p.Payload)
-				// If the nonce is bad then drop the packet and return an error
-				if !sinfo.nonceIsOK(&p.Nonce) {
-					err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
-					return
-				}
-				// Decrypt the packet
-				var isOK bool
-				bs, isOK = crypto.BoxOpen(&sinfo.sharedSesKey, p.Payload, &p.Nonce)
-				// Check if we were unable to decrypt the packet for some reason and
-				// return an error if we couldn't
-				if !isOK {
-					err = ConnError{errors.New("packet dropped due to decryption failure"), false, true, false, 0}
-					return
-				}
-				// Update the session
-				sinfo.updateNonce(&p.Nonce)
-				sinfo.time = time.Now()
-				sinfo.bytesRecvd += uint64(len(bs))
-			}
-			sinfo.doFunc(sessionFunc)
-			// Something went wrong in the session worker so abort
-			if err != nil {
-				if ce, ok := err.(*ConnError); ok && ce.Temporary() {
-					continue
-				}
-				return 0, err
-			}
-			// Copy results to the output slice and clean up
-			copy(b, bs)
-			util.PutBytes(bs)
-			// If we've reached this point then everything went to plan, return the
-			// number of bytes we populated back into the given slice
-			return len(bs), nil
+	// Wait for some traffic to come through from the session
+	select {
+	case <-cancel.Finished():
+		if cancel.Error() == util.CancellationTimeoutError {
+			return nil, ConnError{errors.New("read timeout"), true, false, false, 0}
+		} else {
+			return nil, ConnError{errors.New("session closed"), false, false, true, 0}
 		}
+	case bs := <-c.session.recv:
+		return bs, nil
 	}
 }
 
-func (c *Conn) Write(b []byte) (bytesWritten int, err error) {
-	sinfo := c.session
-	var packet []byte
-	written := len(b)
+// Implements net.Conn.Read
+func (c *Conn) Read(b []byte) (int, error) {
+	bs, err := c.ReadNoCopy()
+	if err != nil {
+		return 0, err
+	}
+	n := len(bs)
+	if len(bs) > len(b) {
+		n = len(b)
+		err = ConnError{errors.New("read buffer too small for entire packet"), false, true, false, 0}
+	}
+	// Copy results to the output slice and clean up
+	copy(b, bs)
+	util.PutBytes(bs)
+	// Return the number of bytes copied to the slice, along with any error
+	return n, err
+}
+
+// Used internally by Write; the caller must not reuse the argument bytes when no error occurs
+func (c *Conn) WriteNoCopy(bs []byte) error {
+	var err error
 	sessionFunc := func() {
 		// Does the packet exceed the permitted size for the session?
-		if uint16(len(b)) > sinfo.getMTU() {
-			written, err = 0, ConnError{errors.New("packet too big"), true, false, false, int(sinfo.getMTU())}
+		if uint16(len(bs)) > c.session.getMTU() {
+			err = ConnError{errors.New("packet too big"), true, false, false, int(c.session.getMTU())}
 			return
 		}
-		// Encrypt the packet
-		payload, nonce := crypto.BoxSeal(&sinfo.sharedSesKey, b, &sinfo.myNonce)
-		defer util.PutBytes(payload)
-		// Construct the wire packet to send to the router
-		p := wire_trafficPacket{
-			Coords:  sinfo.coords,
-			Handle:  sinfo.theirHandle,
-			Nonce:   *nonce,
-			Payload: payload,
-		}
-		packet = p.encode()
-		sinfo.bytesSent += uint64(len(b))
 		// The rest of this work is session keep-alive traffic
-		doSearch := func() {
-			routerWork := func() {
-				// Check to see if there is a search already matching the destination
-				sinfo, isIn := c.core.searches.searches[*c.nodeID]
-				if !isIn {
-					// Nothing was found, so create a new search
-					searchCompleted := func(sinfo *sessionInfo, e error) {}
-					sinfo = c.core.searches.newIterSearch(c.nodeID, c.nodeMask, searchCompleted)
-					c.core.log.Debugf("%s DHT search started: %p", c.String(), sinfo)
-				}
-				// Continue the search
-				sinfo.continueSearch()
-			}
-			go func() { c.core.router.admin <- routerWork }()
-		}
 		switch {
-		case time.Since(sinfo.time) > 6*time.Second:
-			if sinfo.time.Before(sinfo.pingTime) && time.Since(sinfo.pingTime) > 6*time.Second {
+		case time.Since(c.session.time) > 6*time.Second:
+			if c.session.time.Before(c.session.pingTime) && time.Since(c.session.pingTime) > 6*time.Second {
 				// TODO double check that the above condition is correct
-				doSearch()
+				c.doSearch()
 			} else {
-				sinfo.core.sessions.ping(sinfo)
+				c.core.sessions.ping(c.session)
 			}
-		case sinfo.reset && sinfo.pingTime.Before(sinfo.time):
-			sinfo.core.sessions.ping(sinfo)
+		case c.session.reset && c.session.pingTime.Before(c.session.time):
+			c.core.sessions.ping(c.session)
 		default:
 			// Don't do anything, to keep traffic throttled
 		}
 	}
-	sinfo.doFunc(sessionFunc)
-	// Give the packet to the router
-	if written > 0 {
-		sinfo.core.router.out(packet)
+	c.session.doFunc(sessionFunc)
+	if err == nil {
+		cancel := c.getDeadlineCancellation(&c.writeDeadline)
+		defer cancel.Cancel(nil)
+		select {
+		case <-cancel.Finished():
+			if cancel.Error() == util.CancellationTimeoutError {
+				err = ConnError{errors.New("write timeout"), true, false, false, 0}
+			} else {
+				err = ConnError{errors.New("session closed"), false, false, true, 0}
+			}
+		case c.session.send <- bs:
+		}
+	}
+	return err
+}
+
+// Implements net.Conn.Write
+func (c *Conn) Write(b []byte) (int, error) {
+	written := len(b)
+	bs := append(util.GetBytes(), b...)
+	err := c.WriteNoCopy(bs)
+	if err != nil {
+		util.PutBytes(bs)
+		written = 0
 	}
-	// Finally return the number of bytes we wrote
 	return written, err
 }
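ReadNoCopy and WriteNoCopy shift buffer ownership across the API: a successful ReadNoCopy hands the caller a pooled slice to release with util.PutBytes, and a successful WriteNoCopy takes ownership of its argument. A hedged sketch of a caller honouring that contract, assuming a *yggdrasil.Conn obtained elsewhere (echoOnce is a hypothetical helper, not code from this patch):

// Sketch only: zero-copy read/write usage under the ownership rules above.
package example

import (
	"github.com/yggdrasil-network/yggdrasil-go/src/util"
	"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
)

func echoOnce(c *yggdrasil.Conn) error {
	bs, err := c.ReadNoCopy() // bs is now ours; we must PutBytes it unless we hand it off
	if err != nil {
		return err
	}
	out := append(util.GetBytes(), bs...) // copy into a fresh pooled slice for the reply
	util.PutBytes(bs)                     // done with the received buffer
	if err := c.WriteNoCopy(out); err != nil {
		util.PutBytes(out) // on error the buffer is still ours to recycle
		return err
	}
	// On success WriteNoCopy owns out, so it must not be touched or recycled here.
	return nil
}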
diff --git a/src/yggdrasil/dialer.go b/src/yggdrasil/dialer.go
index 6b24cfb4..6ce2e8ac 100644
--- a/src/yggdrasil/dialer.go
+++ b/src/yggdrasil/dialer.go
@@ -69,6 +69,7 @@ func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, error) {
 	defer t.Stop()
 	select {
 	case <-conn.session.init:
+		conn.session.startWorkers(conn.cancel)
 		return conn, nil
 	case <-t.C:
 		conn.Close()
diff --git a/src/yggdrasil/router.go b/src/yggdrasil/router.go
index c5e1dde0..2df7684f 100644
--- a/src/yggdrasil/router.go
+++ b/src/yggdrasil/router.go
@@ -127,7 +127,6 @@ func (r *router) mainLoop() {
 			r.core.switchTable.doMaintenance()
 			r.core.dht.doMaintenance()
 			r.core.sessions.cleanup()
-			util.GetBytes() // To slowly drain things
 		}
 	case f := <-r.admin:
 		f()
@@ -166,8 +165,8 @@ func (r *router) handleTraffic(packet []byte) {
 		return
 	}
 	select {
-	case sinfo.recv <- &p: // FIXME ideally this should be front drop
-	default:
+	case sinfo.fromRouter <- &p:
+	case <-sinfo.cancel.Finished():
 		util.PutBytes(p.Payload)
 	}
 }
diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go
index eca3bb00..c39f60de 100644
--- a/src/yggdrasil/session.go
+++ b/src/yggdrasil/session.go
@@ -6,11 +6,13 @@ package yggdrasil
 
 import (
 	"bytes"
+	"errors"
 	"sync"
 	"time"
 
 	"github.com/yggdrasil-network/yggdrasil-go/src/address"
 	"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
+	"github.com/yggdrasil-network/yggdrasil-go/src/util"
 )
 
 // All the information we know about an active session.
@@ -44,8 +46,11 @@ type sessionInfo struct {
 	tstamp     int64                    // ATOMIC - tstamp from their last session ping, replay attack mitigation
 	bytesSent  uint64                   // Bytes of real traffic sent in this session
 	bytesRecvd uint64                   // Bytes of real traffic received in this session
-	recv       chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
+	fromRouter chan *wire_trafficPacket // Received packets go here, picked up by the associated Conn
 	init       chan struct{}            // Closed when the first session pong arrives, used to signal that the session is ready for initial use
+	cancel     util.Cancellation        // Used to terminate workers
+	recv       chan []byte
+	send       chan []byte
 }
 
 func (sinfo *sessionInfo) doFunc(f func()) {
@@ -222,7 +227,9 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo {
 	sinfo.myHandle = *crypto.NewHandle()
 	sinfo.theirAddr = *address.AddrForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
 	sinfo.theirSubnet = *address.SubnetForNodeID(crypto.GetNodeID(&sinfo.theirPermPub))
-	sinfo.recv = make(chan *wire_trafficPacket, 32)
+	sinfo.fromRouter = make(chan *wire_trafficPacket, 1)
+	sinfo.recv = make(chan []byte, 32)
+	sinfo.send = make(chan []byte, 32)
 	ss.sinfos[sinfo.myHandle] = &sinfo
 	ss.byTheirPerm[sinfo.theirPermPub] = &sinfo.myHandle
 	return &sinfo
@@ -355,6 +362,7 @@ func (ss *sessions) handlePing(ping *sessionPing) {
 			for i := range conn.nodeMask {
 				conn.nodeMask[i] = 0xFF
 			}
+			conn.session.startWorkers(conn.cancel)
 			ss.listener.conn <- conn
 		}
 		ss.listenerMutex.Unlock()
@@ -418,3 +426,150 @@ func (ss *sessions) reset() {
 		})
 	}
 }
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////////////// Worker Functions Below ////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+func (sinfo *sessionInfo) startWorkers(cancel util.Cancellation) {
+	sinfo.cancel = cancel
+	go sinfo.recvWorker()
+	go sinfo.sendWorker()
+}
+
+func (sinfo *sessionInfo) recvWorker() {
+	// TODO move theirNonce etc into a struct that gets stored here, passed in over a channel
+	// Since there's no reason for anywhere else in the session code to need to *read* it...
+	// Only needs to be updated from the outside if a ping resets it...
+	// That would get rid of the need to take a mutex for the sessionFunc
+	var callbacks []chan func()
+	doRecv := func(p *wire_trafficPacket) {
+		var bs []byte
+		var err error
+		var k crypto.BoxSharedKey
+		sessionFunc := func() {
+			if !sinfo.nonceIsOK(&p.Nonce) {
+				err = ConnError{errors.New("packet dropped due to invalid nonce"), false, true, false, 0}
+				return
+			}
+			k = sinfo.sharedSesKey
+		}
+		sinfo.doFunc(sessionFunc)
+		if err != nil {
+			util.PutBytes(p.Payload)
+			return
+		}
+		var isOK bool
+		ch := make(chan func(), 1)
+		poolFunc := func() {
+			bs, isOK = crypto.BoxOpen(&k, p.Payload, &p.Nonce)
+			callback := func() {
+				util.PutBytes(p.Payload)
+				if !isOK {
+					util.PutBytes(bs)
+					return
+				}
+				sessionFunc = func() {
+					if k != sinfo.sharedSesKey || !sinfo.nonceIsOK(&p.Nonce) {
+						// The session updated in the meantime, so return an error
+						err = ConnError{errors.New("session updated during crypto operation"), false, true, false, 0}
+						return
+					}
+					sinfo.updateNonce(&p.Nonce)
+					sinfo.time = time.Now()
+					sinfo.bytesRecvd += uint64(len(bs))
+				}
+				sinfo.doFunc(sessionFunc)
+				if err != nil {
+					// Not sure what else to do with this packet, I guess just drop it
+					util.PutBytes(bs)
+				} else {
+					// Pass the packet to the buffer for Conn.Read
+					sinfo.recv <- bs
+				}
+			}
+			ch <- callback
+		}
+		// Queue the job on the worker pool; the callbacks FIFO below waits for it in order
+		util.WorkerGo(poolFunc)
+		callbacks = append(callbacks, ch)
+	}
+	for {
+		for len(callbacks) > 0 {
+			select {
+			case f := <-callbacks[0]:
+				callbacks = callbacks[1:]
+				f()
+			case <-sinfo.cancel.Finished():
+				return
+			case p := <-sinfo.fromRouter:
+				doRecv(p)
+			}
+		}
+		select {
+		case <-sinfo.cancel.Finished():
+			return
+		case p := <-sinfo.fromRouter:
+			doRecv(p)
+		}
+	}
+}
+
+func (sinfo *sessionInfo) sendWorker() {
+	// TODO move info that this worker needs here, send updates via a channel
+	// Otherwise we need to take a mutex to avoid races with update()
+	var callbacks []chan func()
+	doSend := func(bs []byte) {
+		var p wire_trafficPacket
+		var k crypto.BoxSharedKey
+		sessionFunc := func() {
+			sinfo.bytesSent += uint64(len(bs))
+			p = wire_trafficPacket{
+				Coords: append([]byte(nil), sinfo.coords...),
+				Handle: sinfo.theirHandle,
+				Nonce:  sinfo.myNonce,
+			}
+			sinfo.myNonce.Increment()
+			k = sinfo.sharedSesKey
+		}
+		// Get the mutex-protected info needed to encrypt the packet
+		sinfo.doFunc(sessionFunc)
+		ch := make(chan func(), 1)
+		poolFunc := func() {
+			// Encrypt the packet
+			p.Payload, _ = crypto.BoxSeal(&k, bs, &p.Nonce)
+			packet := p.encode()
+			// The callback will send the packet
+			callback := func() {
+				// Cleanup
+				util.PutBytes(bs)
+				util.PutBytes(p.Payload)
+				// Send the packet
+				sinfo.core.router.out(packet)
+			}
+			ch <- callback
+		}
+		// Queue the job on the worker pool; the callbacks FIFO below waits for it in order
+		util.WorkerGo(poolFunc)
+		callbacks = append(callbacks, ch)
+	}
+	for {
+		for len(callbacks) > 0 {
+			select {
+			case f := <-callbacks[0]:
+				callbacks = callbacks[1:]
+				f()
+			case <-sinfo.cancel.Finished():
+				return
+			case bs := <-sinfo.send:
+				doSend(bs)
+			}
+		}
+		select {
+		case <-sinfo.cancel.Finished():
+			return
+		case bs := <-sinfo.send:
+			doSend(bs)
+		}
+	}
+}
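recvWorker and sendWorker fan the expensive BoxOpen/BoxSeal calls out to the worker pool but still deliver results in submission order, because each job gets its own one-shot callback channel and the loop only ever waits on the head of the callbacks FIFO. A stand-alone sketch of just that ordering idiom, with a trivial computation standing in for the crypto (local names, not the repository's):

// Sketch only: bounded parallel work with strictly ordered delivery.
package main

import (
	"fmt"
	"runtime"
)

var workerPool = make(chan func(), runtime.GOMAXPROCS(0))

func init() {
	for i := 0; i < cap(workerPool); i++ {
		go func() {
			for f := range workerPool {
				f()
			}
		}()
	}
}

func main() {
	var callbacks []chan func()
	for job := 0; job < 5; job++ {
		job := job
		ch := make(chan func(), 1)
		workerPool <- func() {
			result := job * job // stand-in for the expensive crypto step
			ch <- func() { fmt.Println("job", job, "->", result) }
		}
		callbacks = append(callbacks, ch)
	}
	for len(callbacks) > 0 {
		f := <-callbacks[0] // always the head of the FIFO, so output order matches submit order
		callbacks = callbacks[1:]
		f()
	}
}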
diff --git a/src/yggdrasil/stream.go b/src/yggdrasil/stream.go
index 4d73844f..011943f5 100644
--- a/src/yggdrasil/stream.go
+++ b/src/yggdrasil/stream.go
@@ -1,9 +1,11 @@
 package yggdrasil
 
 import (
+	"bufio"
 	"errors"
 	"fmt"
 	"io"
+	"net"
 
 	"github.com/yggdrasil-network/yggdrasil-go/src/util"
 )
@@ -13,9 +15,8 @@ var _ = linkInterfaceMsgIO(&stream{})
 
 type stream struct {
 	rwc          io.ReadWriteCloser
-	inputBuffer  []byte                  // Incoming packet stream
-	frag         [2 * streamMsgSize]byte // Temporary data read off the underlying rwc, on its way to the inputBuffer
-	outputBuffer [2 * streamMsgSize]byte // Temporary data about to be written to the rwc
+	inputBuffer  *bufio.Reader
+	outputBuffer net.Buffers
 }
 
 func (s *stream) close() error {
@@ -30,19 +31,23 @@ func (s *stream) init(rwc io.ReadWriteCloser) {
 	// TODO have this also do the metadata handshake and create the peer struct
 	s.rwc = rwc
 	// TODO call something to do the metadata exchange
+	s.inputBuffer = bufio.NewReaderSize(s.rwc, 2*streamMsgSize)
 }
 
 // writeMsg writes a message with stream padding, and is *not* thread safe.
 func (s *stream) writeMsg(bs []byte) (int, error) {
 	buf := s.outputBuffer[:0]
-	buf = append(buf, streamMsg[:]...)
-	buf = wire_put_uint64(uint64(len(bs)), buf)
-	padLen := len(buf)
-	buf = append(buf, bs...)
+	buf = append(buf, streamMsg[:])
+	l := wire_put_uint64(uint64(len(bs)), util.GetBytes())
+	defer util.PutBytes(l)
+	buf = append(buf, l)
+	padLen := len(buf[0]) + len(buf[1])
+	buf = append(buf, bs)
+	totalLen := padLen + len(bs)
 	var bn int
-	for bn < len(buf) {
-		n, err := s.rwc.Write(buf[bn:])
-		bn += n
+	for bn < totalLen {
+		n, err := buf.WriteTo(s.rwc)
+		bn += int(n)
 		if err != nil {
 			l := bn - padLen
 			if l < 0 {
@@ -57,26 +62,11 @@
 
 // readMsg reads a message from the stream, accounting for stream padding, and is *not* thread safe.
 func (s *stream) readMsg() ([]byte, error) {
 	for {
-		buf := s.inputBuffer
-		msg, ok, err := stream_chopMsg(&buf)
-		switch {
-		case err != nil:
-			// Something in the stream format is corrupt
+		bs, err := s.readMsgFromBuffer()
+		if err != nil {
 			return nil, fmt.Errorf("message error: %v", err)
-		case ok:
-			// Copy the packet into bs, shift the buffer, and return
-			msg = append(util.GetBytes(), msg...)
-			s.inputBuffer = append(s.inputBuffer[:0], buf...)
-			return msg, nil
-		default:
-			// Wait for the underlying reader to return enough info for us to proceed
-			n, err := s.rwc.Read(s.frag[:])
-			if n > 0 {
-				s.inputBuffer = append(s.inputBuffer, s.frag[:n]...)
-			} else if err != nil {
-				return nil, err
-			}
 		}
+		return bs, err
 	}
 }
@@ -108,34 +98,30 @@ func (s *stream) _recvMetaBytes() ([]byte, error) {
 	return metaBytes, nil
 }
 
-// This takes a pointer to a slice as an argument. It checks if there's a
-// complete message and, if so, slices out those parts and returns the message,
-// true, and nil. If there's no error, but also no complete message, it returns
-// nil, false, and nil. If there's an error, it returns nil, false, and the
-// error, which the reader then handles (currently, by returning from the
-// reader, which causes the connection to close).
-func stream_chopMsg(bs *[]byte) ([]byte, bool, error) {
-	// Returns msg, ok, err
-	if len(*bs) < len(streamMsg) {
-		return nil, false, nil
+// Reads bytes from the underlying rwc and returns 1 full message
+func (s *stream) readMsgFromBuffer() ([]byte, error) {
+	pad := streamMsg // Copy
+	_, err := io.ReadFull(s.inputBuffer, pad[:])
+	if err != nil {
+		return nil, err
+	} else if pad != streamMsg {
+		return nil, errors.New("bad message")
 	}
-	for idx := range streamMsg {
-		if (*bs)[idx] != streamMsg[idx] {
-			return nil, false, errors.New("bad message")
+	lenSlice := make([]byte, 0, 10)
+	// FIXME this nextByte stuff depends on wire.go format, kind of ugly to have it here
+	nextByte := byte(0xff)
+	for nextByte > 127 {
+		nextByte, err = s.inputBuffer.ReadByte()
+		if err != nil {
+			return nil, err
		}
+		lenSlice = append(lenSlice, nextByte)
 	}
-	msgLen, msgLenLen := wire_decode_uint64((*bs)[len(streamMsg):])
+	msgLen, _ := wire_decode_uint64(lenSlice)
 	if msgLen > streamMsgSize {
-		return nil, false, errors.New("oversized message")
+		return nil, errors.New("oversized message")
 	}
-	msgBegin := len(streamMsg) + msgLenLen
-	msgEnd := msgBegin + int(msgLen)
-	if msgLenLen == 0 || len(*bs) < msgEnd {
-		// We don't have the full message
-		// Need to buffer this and wait for the rest to come in
-		return nil, false, nil
-	}
-	msg := (*bs)[msgBegin:msgEnd]
-	(*bs) = (*bs)[msgEnd:]
-	return msg, true, nil
+	msg := util.ResizeBytes(util.GetBytes(), int(msgLen))
+	_, err = io.ReadFull(s.inputBuffer, msg)
+	return msg, err
 }
diff --git a/src/yggdrasil/switch.go b/src/yggdrasil/switch.go
index 1bc40501..b53229cc 100644
--- a/src/yggdrasil/switch.go
+++ b/src/yggdrasil/switch.go
@@ -814,17 +814,23 @@ func (t *switchTable) doWorker() {
 	go func() {
 		// Keep taking packets from the idle worker and sending them to the above whenever it's idle, keeping anything extra in a (fifo, head-drop) buffer
 		var buf [][]byte
+		var size int
 		for {
-			buf = append(buf, <-t.toRouter)
+			bs := <-t.toRouter
+			size += len(bs)
+			buf = append(buf, bs)
 			for len(buf) > 0 {
 				select {
 				case bs := <-t.toRouter:
+					size += len(bs)
 					buf = append(buf, bs)
-					for len(buf) > 32 {
+					for size > int(t.queueTotalMaxSize) {
+						size -= len(buf[0])
 						util.PutBytes(buf[0])
 						buf = buf[1:]
 					}
 				case sendingToRouter <- buf[0]:
+					size -= len(buf[0])
 					buf = buf[1:]
 				}
 			}
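The switch change above replaces the fixed 32-packet cap on the router-bound buffer with a cap on total queued bytes, dropping from the head when the budget is exceeded. A stand-alone sketch of that byte-counted head-drop FIFO (queueTotalMaxSize is a local constant here, not the switchTable field):

// Sketch only: a FIFO bounded by total byte size rather than packet count.
package main

import "fmt"

const queueTotalMaxSize = 1024

type headDropFIFO struct {
	buf  [][]byte
	size int
}

func (q *headDropFIFO) push(bs []byte) {
	q.buf = append(q.buf, bs)
	q.size += len(bs)
	for q.size > queueTotalMaxSize { // drop from the front until back under the byte budget
		q.size -= len(q.buf[0])
		q.buf = q.buf[1:]
	}
}

func (q *headDropFIFO) pop() ([]byte, bool) {
	if len(q.buf) == 0 {
		return nil, false
	}
	bs := q.buf[0]
	q.buf = q.buf[1:]
	q.size -= len(bs)
	return bs, true
}

func main() {
	var q headDropFIFO
	for i := 0; i < 10; i++ {
		q.push(make([]byte, 200)) // 10*200 bytes exceeds the budget, so older packets fall off the front
	}
	n := 0
	for _, ok := q.pop(); ok; _, ok = q.pop() {
		n++
	}
	fmt.Println("packets kept:", n) // 5 with a 1024-byte budget and 200-byte packets
}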