diff --git a/hkexnet/consts.go b/hkexnet/consts.go index f37ea29..6188515 100644 --- a/hkexnet/consts.go +++ b/hkexnet/consts.go @@ -59,23 +59,26 @@ const ( CSOTunSetupAck // server -> client tunnel setup ack CSOTunAccept // client -> server: tunnel client got an Accept() // (Do we need a CSOTunAcceptAck server->client?) - CSOTunRefused // server -> client: tunnel rport connection refused - CSOTunData // packet contains tunnel data [rport:data] - CSOTunDisconn // server -> client: tunnel rport disconnected - CSOTunHangup // client -> server: tunnel lport hung up + CSOTunRefused // server -> client: tunnel rport connection refused + CSOTunData // packet contains tunnel data [rport:data] + CSOTunDisconn // server -> client: tunnel rport disconnected + CSOTunHangup // client -> server: tunnel lport hung up ) // TunEndpoint.tunCtl control values - used to control workers for client or server tunnels // depending on the code const ( TunCtl_Client_Listen = 'a' - + TunCtl_Server_Dial = 'd' // server has dialled OK, client side can accept() conns // [CSOTunAccept] // status: client listen() worker accepted conn on lport // action:server side should dial() rport on client's behalf - TunCtl_Info_Hangup = 'h' // client side has hung up + // -rlm 20181111 - useless as serverTun worker might in within a Read() or Write(), + // so timeouts must be used and tun.Died flag + // -- + //TunCtl_Info_Hangup = 'h' // client side has hung up // [CSOTunHangup] // status: client side conn hung up from lport // action:server side should hang up on rport, on client's behalf @@ -83,8 +86,11 @@ const ( TunCtl_Info_ConnRefused = 'r' // server side couldn't complete tunnel // [CSOTunRefused] // status:server side could not dial() remote side - - TunCtl_Info_LostConn = 'x' // server side disconnected + + // -rlm 20181111 - useless as clientTun worker might in within a Read() or Write(), + // so timeouts must be used and tun.Died flag + // -- + //TunCtl_Info_LostConn = 'x' // server side disconnected // [CSOTunDisconn] // status:server side lost connection to rport // action:client should disconnect accepted lport connection diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 441b7c3..955591e 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -842,21 +842,26 @@ func (hc Conn) Read(b []byte) (n int, err error) { lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) - //(*hc.tuns)[rport].Ctl <- 'r' // client should NOT Listen() } else if ctrlStatOp == CSOTunDisconn { // server side's rport has disconnected (server lost) lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport)) - // 20181111 rlm: I think we need to kick client workers out of pending Read()s here, - // only way is by forcibly closing the net conn. - (*hc.tuns)[rport].Ctl <- 'x' // client should hangup on current lport conn + if _, ok := (*hc.tuns)[rport]; ok { + (*hc.tuns)[rport].Died = true + } else { + logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport)) + } } else if ctrlStatOp == CSOTunHangup { // client side's lport has hung up lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport)) - (*hc.tuns)[rport].Ctl <- 'h' // server should hang up on currently-dialled rport + if _, ok := (*hc.tuns)[rport]; ok { + (*hc.tuns)[rport].Died = true + } else { + logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport)) + } } else if ctrlStatOp == CSOTunData { lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) diff --git a/hkexnet/hkextun.go b/hkexnet/hkextun.go index 5e6986b..96042db 100644 --- a/hkexnet/hkextun.go +++ b/hkexnet/hkextun.go @@ -14,6 +14,8 @@ import ( "fmt" "io" "net" + "strings" + "sync" "time" "blitter.com/go/hkexsh/logger" @@ -43,11 +45,26 @@ type ( Rport uint16 // Names are from client's perspective Lport uint16 // ... ie., RPort is on server, LPort is on client Peer string //net.Addr + Died bool // set by client upon receipt of a CSOTunDisconn Ctl chan rune //See TunCtl_* consts Data chan []byte } ) +func (hc *Conn) CollapseAllTunnels(client bool) { + for k,t := range *hc.tuns { + var tunDst bytes.Buffer + binary.Write(&tunDst, binary.BigEndian, t.Lport) + binary.Write(&tunDst, binary.BigEndian, t.Rport) + if client { + hc.WritePacket(tunDst.Bytes(), CSOTunHangup) + } else { + hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) + } + delete(*hc.tuns, k) + } +} + func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { if (*hc.tuns) == nil { (*hc.tuns) = make(map[uint16]*TunEndpoint) @@ -64,29 +81,40 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { logger.LogDebug(fmt.Sprintf("InitTunEndpoint [%d:%s:%d]", lp, p, rp)) } else { logger.LogDebug(fmt.Sprintf("InitTunEndpoint [reusing] [%d:%s:%d]", (*hc.tuns)[rp].Lport, (*hc.tuns)[rp].Peer, (*hc.tuns)[rp].Rport)) + if (*hc.tuns)[rp].Data == nil { + // When re-using a tunnel it will have its + // data channel removed on closure. Re-create it + (*hc.tuns)[rp].Data = make(chan []byte, 1) + } + (*hc.tuns)[rp].Died = false } return } func (hc *Conn) StartClientTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) - t := (*hc.tuns)[rport] // for convenience var l HKExListener go func() { + var wg sync.WaitGroup weAreListening := false - for cmd := range t.Ctl { + for cmd := range (*hc.tuns)[rport].Ctl { logger.LogDebug(fmt.Sprintf("[ClientTun] Listening for client tunnel port %d", lport)) if cmd == 'a' && !weAreListening { - l, e := net.Listen("tcp", fmt.Sprintf(":%d", lport)) + l, e := net.Listen("tcp4", fmt.Sprintf(":%d", lport)) if e != nil { logger.LogDebug(fmt.Sprintf("[ClientTun] Could not get lport %d! (%s)", lport, e)) } else { weAreListening = true for { + // If tunnel is being re-used, re-init it + if (*hc.tuns)[rport] == nil { + hc.InitTunEndpoint(lport, "", rport) + } c, e := l.Accept() - var tunDst bytes.Buffer + // ask server to dial() its side, rport + var tunDst bytes.Buffer binary.Write(&tunDst, binary.BigEndian, lport) binary.Write(&tunDst, binary.BigEndian, rport) hc.WritePacket(tunDst.Bytes(), CSOTunSetup) @@ -95,15 +123,18 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { logger.LogDebug(fmt.Sprintf("[ClientTun] Accept() got error(%v), hanging up.", e)) //break } else { - logger.LogDebug(fmt.Sprintf("[ClientTun] Accepted tunnel client %v", t)) + logger.LogDebug(fmt.Sprintf("[ClientTun] Accepted tunnel client %v", (*hc.tuns)[rport])) - c.SetDeadline(time.Now().Add(10 * time.Second)) // outside client -> tunnel lport + wg.Add(1) go func() { defer func() { if c.Close() != nil { logger.LogDebug("[ClientTun] worker A: conn c already closed") + } else { + logger.LogDebug("[ClientTun] worker A: closed conn c") } + wg.Done() }() logger.LogDebug("[ClientTun] worker A: starting") @@ -114,22 +145,48 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { for { rBuf := make([]byte, 1024) //Read data from c, encrypt/write via hc to client(lport) + c.SetReadDeadline(time.Now().Add(20 * time.Second)) n, e := c.Read(rBuf) if e != nil { if e == io.EOF { - logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", t)) + logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) + // if Died was already set, server-side already is gone. + if !(*hc.tuns)[rport].Died { + hc.WritePacket(tunDst.Bytes(), CSOTunHangup) + } + (*hc.tuns)[rport].Died = true + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break + } else if strings.Contains(e.Error(), "i/o timeout") { + if (*hc.tuns)[rport].Died { + logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break + } } else { - logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", t, e)) + logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e)) + if !(*hc.tuns)[rport].Died { + hc.WritePacket(tunDst.Bytes(), CSOTunHangup) + } + (*hc.tuns)[rport].Died = true + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break } - hc.WritePacket(tunDst.Bytes(), CSOTunHangup) - break } - c.SetDeadline(time.Now().Add(10 * time.Second)) if n > 0 { rBuf = append(tunDst.Bytes(), rBuf[:n]...) _, de := hc.WritePacket(rBuf[:n+4], CSOTunData) if de != nil { - logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Error writing to tunnel %v, %s]\n", t, de)) + logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Error writing to tunnel %v, %s]\n", (*hc.tuns)[rport], de)) break } } @@ -138,25 +195,30 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { }() // tunnel lport -> outside client (c) + wg.Add(1) go func() { defer func() { if c.Close() != nil { logger.LogDebug("[ClientTun] worker B: conn c already closed") + } else { + logger.LogDebug("[ClientTun] worker B: closed conn c") } + wg.Done() }() logger.LogDebug("[ClientTun] worker B: starting") for { - bytes, ok := <-t.Data + bytes, ok := <-(*hc.tuns)[rport].Data if ok { + c.SetWriteDeadline(time.Now().Add(20 * time.Second)) _, e := c.Write(bytes) if e != nil { logger.LogDebug(fmt.Sprintf("[ClientTun] worker B: lport conn closed")) break } } else { - logger.LogDebug(fmt.Sprintf("[ClientTun] worker B: Channel closed?")) + logger.LogDebug(fmt.Sprintf("[ClientTun] worker B: Channel was closed?")) break } } @@ -164,50 +226,60 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { }() } // end Accept() worker block + wg.Wait() + logger.LogDebug("[ClientTun] workers exited") + delete((*hc.tuns), rport) } // end for-accept } // end Listen() block } else if cmd == 'r' { - logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunRefused %v\n", t)) - } else if cmd == 'x' { - logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunDisconn, closing lport %v\n", t)) - l.Close() - weAreListening = false + logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunRefused %v\n", (*hc.tuns)[rport])) } + _ = l //else if cmd == 'x' { + //logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunDisconn, closing lport %v\n", t)) + //l.Close() + //weAreListening = false + //} } // end t.Ctl for }() } func (hc *Conn) StartServerTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) - t := (*hc.tuns)[rport] // for convenience var err error go func() { + var wg sync.WaitGroup + weAreDialled := false - for cmd := range t.Ctl { + for cmd := range (*hc.tuns)[rport].Ctl { var c net.Conn logger.LogDebug(fmt.Sprintf("[ServerTun] got Ctl '%c'. weAreDialled: %v", cmd, weAreDialled)) if cmd == 'd' && !weAreDialled { + // if re-using tunnel, re-init it + if (*hc.tuns)[rport] == nil { + hc.InitTunEndpoint(lport, "", rport) + } logger.LogDebug("[ServerTun] dialling...") - c, err = net.Dial("tcp", fmt.Sprintf(":%d", rport)) + c, err = net.Dial("tcp4", fmt.Sprintf(":%d", rport)) if err != nil { - logger.LogDebug(fmt.Sprintf("[ServerTun] Dial() error for tun %v: %s", t, err)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Dial() error for tun %v: %s", (*hc.tuns)[rport], err)) var resp bytes.Buffer binary.Write(&resp, binary.BigEndian /*lport*/, uint16(0)) binary.Write(&resp, binary.BigEndian, rport) hc.WritePacket(resp.Bytes(), CSOTunRefused) } else { - logger.LogDebug(fmt.Sprintf("[ServerTun] Tunnel Opened - %v", t)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Tunnel Opened - %v", (*hc.tuns)[rport])) weAreDialled = true var resp bytes.Buffer binary.Write(&resp, binary.BigEndian, lport) binary.Write(&resp, binary.BigEndian, rport) - logger.LogDebug(fmt.Sprintf("[ServerTun] Writing CSOTunSetupAck %v", t)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Writing CSOTunSetupAck %v", (*hc.tuns)[rport])) hc.WritePacket(resp.Bytes(), CSOTunSetupAck) // // worker to read data from the rport (to encrypt & send to client) // + wg.Add(1) go func() { defer func() { logger.LogDebug("[ServerTun] worker A: deferred hangup") @@ -215,29 +287,54 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug("[ServerTun] workerA: conn c already closed") } weAreDialled = false + wg.Done() }() logger.LogDebug("[ServerTun] worker A: starting") var tunDst bytes.Buffer - binary.Write(&tunDst, binary.BigEndian, t.Lport) - binary.Write(&tunDst, binary.BigEndian, t.Rport) + binary.Write(&tunDst, binary.BigEndian, (*hc.tuns)[rport].Lport) + binary.Write(&tunDst, binary.BigEndian, (*hc.tuns)[rport].Rport) for { rBuf := make([]byte, 1024) // Read data from c, encrypt/write via hc to client(lport) + c.SetReadDeadline(time.Now().Add(20 * time.Second)) n, e := c.Read(rBuf) if e != nil { if e == io.EOF { - logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", t)) + logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) + if !(*hc.tuns)[rport].Died { + hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) + } + (*hc.tuns)[rport].Died = true + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break + } else if strings.Contains(e.Error(), "i/o timeout") { + if (*hc.tuns)[rport].Died { + logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) + //hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) + //(*hc.tuns)[rport].Died = true + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break + } } else { - logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", t, e)) + logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e)) + if !(*hc.tuns)[rport].Died { + hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) + } + (*hc.tuns)[rport].Died = true + if (*hc.tuns)[rport].Data != nil { + close((*hc.tuns)[rport].Data) + (*hc.tuns)[rport].Data = nil + } + break } - var resp bytes.Buffer - binary.Write(&resp, binary.BigEndian, lport) - binary.Write(&resp, binary.BigEndian, rport) - hc.WritePacket(resp.Bytes(), CSOTunDisconn) - logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Closing server rport %d net.Dial()", t.Rport)) - break } if n > 0 { rBuf = append(tunDst.Bytes(), rBuf[:n]...) @@ -248,6 +345,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { }() // worker to read data from client (already decrypted) & fwd to rport + wg.Add(1) go func() { defer func() { logger.LogDebug("[ServerTun] worker B: deferred hangup") @@ -255,30 +353,32 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug("[ServerTun] worker B: conn c already closed") } weAreDialled = false + wg.Done() }() logger.LogDebug("[ServerTun] worker B: starting") for { - rData, ok := <-t.Data + rData, ok := <-(*hc.tuns)[rport].Data if ok { + c.SetWriteDeadline(time.Now().Add(20 * time.Second)) _, e := c.Write(rData) if e != nil { logger.LogDebug(fmt.Sprintf("[ServerTun] worker B: ERROR writing to rport conn")) break } } else { - logger.LogDebug("[ServerTun] worker B: ERROR reading from hc.tuns[] channel - closed?") + logger.LogDebug(fmt.Sprintf("[ServerTun] worker B: Channel was closed?")) break } } logger.LogDebug("[ServerTun] worker B: exiting") }() - } - } else if cmd == 'h' { - // client side has hung up - logger.LogDebug(fmt.Sprintf("[ServerTun] Client hung up on rport %v", t)) + wg.Wait() + } // end if Dialled successfully + delete((*hc.tuns), rport) } } // t.Ctl read loop logger.LogDebug("[ServerTun] Tunnel exiting t.Ctl read loop - channel closed??") + //wg.Wait() }() } diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index d6bc24e..5c1ea4a 100755 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -273,6 +273,7 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, // gracefully here if !strings.HasSuffix(inerr.Error(), "use of closed network connection") { log.Println(inerr) + conn.CollapseAllTunnels(true) os.Exit(1) } } @@ -310,6 +311,7 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, fmt.Println(outerr) _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // Best effort. log.Println("[Hanging up]") + conn.CollapseAllTunnels(true) os.Exit(0) } }() @@ -646,6 +648,7 @@ func main() { doShellMode(isInteractive, &conn, oldState, rec) } else { // copyMode _, s := doCopyMode(&conn, pathIsDest, fileArgs, rec) + conn.CollapseAllTunnels(true) rec.SetStatus(s) }