diff --git a/src/core/link.go b/src/core/link.go index b646605c..f45c2cee 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -37,6 +37,8 @@ type links struct { unix *linkUNIX // UNIX interface support socks *linkSOCKS // SOCKS interface support quic *linkQUIC // QUIC interface support + ws *linkWS // WS interface support + wss *linkWSS // WSS interface support // _links can only be modified safely from within the links actor _links map[linkInfo]*link // *link is nil if connection in progress } @@ -97,6 +99,8 @@ func (l *links) init(c *Core) error { l.unix = l.newLinkUNIX() l.socks = l.newLinkSOCKS() l.quic = l.newLinkQUIC() + l.ws = l.newLinkWS() + l.wss = l.newLinkWSS() l._links = make(map[linkInfo]*link) var listeners []ListenAddress @@ -417,6 +421,10 @@ func (l *links) listen(u *url.URL, sintf string) (*Listener, error) { protocol = l.unix case "quic": protocol = l.quic + case "ws": + protocol = l.ws + case "wss": + protocol = l.wss default: cancel() return nil, ErrLinkUnrecognisedSchema @@ -545,6 +553,10 @@ func (l *links) connect(ctx context.Context, u *url.URL, info linkInfo, options dialer = l.unix case "quic": dialer = l.quic + case "ws": + dialer = l.ws + case "wss": + dialer = l.wss default: return nil, ErrLinkUnrecognisedSchema } diff --git a/src/core/link_ws.go b/src/core/link_ws.go new file mode 100644 index 00000000..4a65f979 --- /dev/null +++ b/src/core/link_ws.go @@ -0,0 +1,123 @@ +package core + +import ( + "context" + "net" + "net/http" + "net/url" + "time" + + "github.com/Arceliar/phony" + "nhooyr.io/websocket" +) + +type linkWS struct { + phony.Inbox + *links +} + +type linkWSConn struct { + net.Conn +} + +type linkWSListener struct { + ch chan *linkWSConn + ctx context.Context + httpServer *http.Server + listener net.Listener +} + +type wsServer struct { + ch chan *linkWSConn + ctx context.Context +} + +func (l *linkWSListener) Accept() (net.Conn, error) { + qs := <-l.ch + if qs == nil { + return nil, context.Canceled + } + return qs, nil +} + +func (l *linkWSListener) Addr() net.Addr { + return l.listener.Addr() +} + +func (l *linkWSListener) Close() error { + if err := l.httpServer.Shutdown(l.ctx); err != nil { + return err + } + + return l.listener.Close() +} + +func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"ygg-ws"}, + }) + + if err != nil { + return + } + + if c.Subprotocol() != "ygg-ws" { + c.Close(websocket.StatusPolicyViolation, "client must speak the ygg-ws subprotocol") + return + } + + netconn := websocket.NetConn(s.ctx, c, websocket.MessageBinary) + + ch := s.ch + ch <- &linkWSConn{ + Conn: netconn, + } +} + +func (l *links) newLinkWS() *linkWS { + lt := &linkWS{ + links: l, + } + + return lt +} + +func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { + wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ + Subprotocols: []string{"ygg-ws"}, + }) + if err != nil { + return nil, err + } + netconn := websocket.NetConn(ctx, wsconn, websocket.MessageBinary) + return &linkWSConn{ + Conn: netconn, + }, nil +} + +func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { + nl, err := net.Listen("tcp", url.Host) + if err != nil { + return nil, err + } + + ch := make(chan *linkWSConn) + + httpServer := &http.Server{ + Handler: &wsServer{ + ch: ch, + ctx: ctx, + }, + ReadTimeout: time.Second * 10, + WriteTimeout: time.Second * 10, + } + + lwl := &linkWSListener{ + ch: ch, + ctx: ctx, + httpServer: httpServer, + listener: nl, + } + go lwl.httpServer.Serve(nl) + return lwl, nil +} diff --git a/src/core/link_wss.go b/src/core/link_wss.go new file mode 100644 index 00000000..896c4750 --- /dev/null +++ b/src/core/link_wss.go @@ -0,0 +1,132 @@ +package core + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "github.com/Arceliar/phony" + "nhooyr.io/websocket" +) + +type linkWSS struct { + phony.Inbox + tlsconfig *tls.Config + *links +} + +type linkWSSConn struct { + net.Conn +} + +type linkWSSListener struct { + ch chan *linkWSSConn + ctx context.Context + httpServer *http.Server + listener net.Listener + tlslistener net.Listener +} + +type wssServer struct { + ch chan *linkWSSConn + ctx context.Context +} + +func (l *linkWSSListener) Accept() (net.Conn, error) { + qs := <-l.ch + if qs == nil { + return nil, context.Canceled + } + return qs, nil +} + +func (l *linkWSSListener) Addr() net.Addr { + return l.listener.Addr() +} + +func (l *linkWSSListener) Close() error { + if err := l.httpServer.Shutdown(l.ctx); err != nil { + return err + } + if err := l.tlslistener.Close(); err != nil { + return err + } + return l.listener.Close() +} + +func (s *wssServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"ygg-ws"}, + }) + + if err != nil { + return + } + + if c.Subprotocol() != "ygg-ws" { + c.Close(websocket.StatusPolicyViolation, "client must speak the ygg-ws subprotocol") + return + } + + netconn := websocket.NetConn(s.ctx, c, websocket.MessageBinary) + + ch := s.ch + ch <- &linkWSSConn{ + Conn: netconn, + } +} + +func (l *links) newLinkWSS() *linkWSS { + lwss := &linkWSS{ + links: l, + tlsconfig: l.core.config.tls.Clone(), + } + + return lwss +} + +func (l *linkWSS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { + wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ + Subprotocols: []string{"ygg-ws"}, + }) + if err != nil { + return nil, err + } + netconn := websocket.NetConn(ctx, wsconn, websocket.MessageBinary) + return &linkWSSConn{ + Conn: netconn, + }, nil +} + +func (l *linkWSS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { + nl, err := net.Listen("tcp", url.Host) + if err != nil { + return nil, err + } + + tl := tls.NewListener(nl, l.tlsconfig) + + ch := make(chan *linkWSSConn) + + httpServer := &http.Server{ + Handler: &wssServer{ + ch: ch, + ctx: ctx, + }, + ReadTimeout: time.Second * 10, + WriteTimeout: time.Second * 10, + } + + lwl := &linkWSSListener{ + ch: ch, + ctx: ctx, + httpServer: httpServer, + listener: nl, + tlslistener: tl, + } + go lwl.httpServer.Serve(tl) + return lwl, nil +}