diff --git a/src/util/cancellation.go b/src/util/cancellation.go new file mode 100644 index 00000000..2a78c19d --- /dev/null +++ b/src/util/cancellation.go @@ -0,0 +1,96 @@ +package util + +import ( + "errors" + "runtime" + "sync" + "time" +) + +type Cancellation interface { + Finished() <-chan struct{} + Cancel(error) error + Error() error +} + +func CancellationFinalizer(c Cancellation) { + c.Cancel(errors.New("finalizer called")) +} + +type cancellation struct { + signal chan error + cancel chan struct{} + errMtx sync.RWMutex + err error +} + +func (c *cancellation) worker() { + // Launch this in a separate goroutine when creating a cancellation + err := <-c.signal + c.errMtx.Lock() + c.err = err + c.errMtx.Unlock() + close(c.cancel) +} + +func NewCancellation() Cancellation { + c := cancellation{ + signal: make(chan error), + cancel: make(chan struct{}), + } + runtime.SetFinalizer(&c, CancellationFinalizer) + go c.worker() + return &c +} + +func (c *cancellation) Finished() <-chan struct{} { + return c.cancel +} + +func (c *cancellation) Cancel(err error) error { + select { + case c.signal <- err: + return nil + case <-c.cancel: + return c.Error() + } +} + +func (c *cancellation) Error() error { + c.errMtx.RLock() + err := c.err + c.errMtx.RUnlock() + return err +} + +func CancellationChild(parent Cancellation) Cancellation { + child := NewCancellation() + go func() { + select { + case <-child.Finished(): + case <-parent.Finished(): + child.Cancel(parent.Error()) + } + }() + return child +} + +var CancellationTimeoutError = errors.New("timeout") + +func CancellationWithTimeout(parent Cancellation, timeout time.Duration) Cancellation { + child := CancellationChild(parent) + go func() { + timer := time.NewTimer(timeout) + defer TimerStop(timer) + select { + case <-child.Finished(): + case <-timer.C: + child.Cancel(CancellationTimeoutError) + } + }() + return child +} + +func CancellationWithDeadline(parent Cancellation, deadline time.Time) Cancellation { + return CancellationWithTimeout(parent, deadline.Sub(time.Now())) +} diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index 1d686f83..bc884fb3 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -46,13 +46,13 @@ func (e *ConnError) Closed() bool { type Conn struct { core *Core - nodeID *crypto.NodeID - nodeMask *crypto.NodeID - mutex sync.RWMutex - close chan bool - session *sessionInfo readDeadline atomic.Value // time.Time // TODO timer writeDeadline atomic.Value // time.Time // TODO timer + cancel util.Cancellation + mutex sync.RWMutex // protects the below + nodeID *crypto.NodeID + nodeMask *crypto.NodeID + session *sessionInfo } // TODO func NewConn() that initializes additional fields as needed @@ -62,12 +62,14 @@ func newConn(core *Core, nodeID *crypto.NodeID, nodeMask *crypto.NodeID, session nodeID: nodeID, nodeMask: nodeMask, session: session, - close: make(chan bool), + cancel: util.NewCancellation(), } return &conn } func (c *Conn) String() string { + c.mutex.RLock() + defer c.mutex.RUnlock() return fmt.Sprintf("conn=%p", c) } @@ -111,28 +113,31 @@ func (c *Conn) search() error { return nil } -func getDeadlineTimer(value *atomic.Value) *time.Timer { - timer := time.NewTimer(24 * 365 * time.Hour) // FIXME for some reason setting this to 0 doesn't always let it stop and drain the channel correctly - util.TimerStop(timer) +func (c *Conn) getDeadlineCancellation(value *atomic.Value) util.Cancellation { if deadline, ok := value.Load().(time.Time); ok { - timer.Reset(time.Until(deadline)) + // A deadline is set, so return a Cancellation that uses it + return util.CancellationWithDeadline(c.cancel, deadline) + } else { + // No cancellation was set, so return a child cancellation with no timeout + return util.CancellationChild(c.cancel) } - return timer } func (c *Conn) Read(b []byte) (int, error) { // Take a copy of the session object sinfo := c.session - timer := getDeadlineTimer(&c.readDeadline) - defer util.TimerStop(timer) + cancel := c.getDeadlineCancellation(&c.readDeadline) + defer cancel.Cancel(nil) var bs []byte for { // Wait for some traffic to come through from the session select { - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } case p, ok := <-sinfo.recv: // If the session is closed then do nothing if !ok { @@ -172,18 +177,22 @@ func (c *Conn) Read(b []byte) (int, error) { // Send to worker select { case sinfo.worker <- workerFunc: - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Wait for the worker to finish select { case <-done: // Wait for the worker to finish, failing this can cause memory errors (util.[Get||Put]Bytes stuff) - case <-c.close: - return 0, ConnError{errors.New("session closed"), false, false, true, 0} - case <-timer.C: - return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("read timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Something went wrong in the session worker so abort if err != nil { @@ -256,8 +265,8 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { } } // Set up a timer so this doesn't block forever - timer := getDeadlineTimer(&c.writeDeadline) - defer util.TimerStop(timer) + cancel := c.getDeadlineCancellation(&c.writeDeadline) + defer cancel.Cancel(nil) // Hand over to the session worker defer func() { if recover() != nil { @@ -267,8 +276,12 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { }() // In case we're racing with a close select { // Send to worker case sinfo.worker <- workerFunc: - case <-timer.C: - return 0, ConnError{errors.New("write timeout"), true, false, false, 0} + case <-cancel.Finished(): + if cancel.Error() == util.CancellationTimeoutError { + return 0, ConnError{errors.New("write timeout"), true, false, false, 0} + } else { + return 0, ConnError{errors.New("session closed"), false, false, true, 0} + } } // Wait for the worker to finish, otherwise there are memory errors ([Get||Put]Bytes stuff) <-done @@ -287,13 +300,9 @@ func (c *Conn) Close() (err error) { // Close the session, if it hasn't been closed already c.core.router.doAdmin(c.session.close) } - func() { - defer func() { - recover() - err = ConnError{errors.New("close failed, session already closed"), false, false, true, 0} - }() - close(c.close) // Closes reader/writer goroutines - }() + if e := c.cancel.Cancel(errors.New("connection closed")); e != nil { + err = ConnError{errors.New("close failed, session already closed"), false, false, true, 0} + } return }