diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 21da20f..852d8c5 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -122,6 +122,14 @@ func Init(d bool, c string, f logger.Priority) { _initLogging(d, c, f) } +func (hc *Conn) Lock() { + hc.m.Lock() +} + +func (hc *Conn) Unlock() { + hc.m.Unlock() +} + func (hc Conn) GetStatus() CSOType { return *hc.closeStat } @@ -1084,7 +1092,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) if _, ok := (*hc.tuns)[rport]; ok { - (*hc.tuns)[rport].Died = true + hc.MarkTunDead(rport) } else { logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport)) } @@ -1094,7 +1102,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport)) if _, ok := (*hc.tuns)[rport]; ok { - (*hc.tuns)[rport].Died = true + hc.MarkTunDead(rport) } else { logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport)) } @@ -1104,7 +1112,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport)) if _, ok := (*hc.tuns)[rport]; ok { - (*hc.tuns)[rport].Died = true + hc.MarkTunDead(rport) } else { logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport)) } @@ -1117,7 +1125,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport)) } (*hc.tuns)[rport].Data <- payloadBytes[4:] - (*hc.tuns)[rport].KeepAlive = 0 + hc.ResetTunnelAge(rport) } else { logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport)) } @@ -1212,7 +1220,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) { // // Would be nice to determine if the mutex scope // could be tightened. - hc.m.Lock() + hc.Lock() payloadLen = uint32(len(b)) //!fmt.Printf(" --== payloadLen:%d\n", payloadLen) if hc.logPlainText { @@ -1254,7 +1262,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) { } else { //fmt.Println("[a]WriteError!") } - hc.m.Unlock() + hc.Unlock() if err != nil { log.Println(err) diff --git a/hkexnet/hkextun.go b/hkexnet/hkextun.go index 87ef049..38fbd0f 100644 --- a/hkexnet/hkextun.go +++ b/hkexnet/hkextun.go @@ -16,6 +16,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "blitter.com/go/hkexsh/logger" @@ -46,7 +47,7 @@ type ( Lport uint16 // ... ie., RPort is on server, LPort is on client Peer string //net.Addr Died bool // set by client upon receipt of a CSOTunDisconn - KeepAlive uint // must be reset by client to keep server dial() alive + KeepAlive uint32 // must be reset by client to keep server dial() alive Ctl chan rune //See TunCtl_* consts Data chan []byte } @@ -67,6 +68,8 @@ func (hc *Conn) CollapseAllTunnels(client bool) { } func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { + hc.Lock() + defer hc.Unlock() if (*hc.tuns) == nil { (*hc.tuns) = make(map[uint16]*TunEndpoint) } @@ -87,6 +90,7 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { // data channel removed on closure. Re-create it (*hc.tuns)[rp].Data = make(chan []byte, 1) } + (*hc.tuns)[rp].KeepAlive = 0 (*hc.tuns)[rp].Died = false } return @@ -149,37 +153,23 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { if e == io.EOF { logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) // if Died was already set, server-side already is gone. - if !(*hc.tuns)[rport].Died { + if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunHangup) } - (*hc.tuns)[rport].Died = true - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) break } else if strings.Contains(e.Error(), "i/o timeout") { - if (*hc.tuns)[rport].Died { + if !hc.TunIsAlive(rport) { logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) break } } else { logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e)) - if !(*hc.tuns)[rport].Died { + if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunHangup) } - (*hc.tuns)[rport].Died = true - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) break } } @@ -232,7 +222,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { // When both workers have exited due to a disconnect or other // condition, it's safe to remove the tunnel descriptor. logger.LogDebug("[ClientTun] workers exited") - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) } // end for-accept } // end Listen() block } @@ -240,6 +230,39 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { }() } +func (hc *Conn) AgeTunnel(endp uint16) uint32 { + return atomic.AddUint32(&(*hc.tuns)[endp].KeepAlive, 1) +} + +func (hc *Conn) ResetTunnelAge(endp uint16) { + atomic.StoreUint32(&(*hc.tuns)[endp].KeepAlive, 0) +} + +func (hc *Conn) TunIsAlive(endp uint16) bool { + hc.Lock() + defer hc.Unlock() + return !(*hc.tuns)[endp].Died +} + +func (hc *Conn) MarkTunDead(endp uint16) { + hc.Lock() + defer hc.Unlock() + (*hc.tuns)[endp].Died = true +} + +func (hc *Conn) ShutdownTun(endp uint16) { + hc.Lock() + defer hc.Unlock() + if (*hc.tuns)[endp] != nil { + (*hc.tuns)[endp].Died = true + if (*hc.tuns)[endp].Data != nil { + close((*hc.tuns)[endp].Data) + (*hc.tuns)[endp].Data = nil + } + } + delete((*hc.tuns), endp) +} + func (hc *Conn) StartServerTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) var err error @@ -260,9 +283,9 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug("[ServerTun] worker A: Client endpoint removed.") break } - (*hc.tuns)[rport].KeepAlive += 1 - if (*hc.tuns)[rport].KeepAlive > 25 { - (*hc.tuns)[rport].Died = true + age := hc.AgeTunnel(rport) + if age > 25 { + hc.MarkTunDead(rport) logger.LogDebug("[ServerTun] worker A: Client died, hanging up.") break } @@ -319,37 +342,23 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { if e != nil { if e == io.EOF { logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) - if !(*hc.tuns)[rport].Died { + if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) } - (*hc.tuns)[rport].Died = true - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) break } else if strings.Contains(e.Error(), "i/o timeout") { - if (*hc.tuns)[rport].Died { - logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + if !hc.TunIsAlive(rport) { + logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) + hc.ShutdownTun(rport) break } } else { logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e)) - if !(*hc.tuns)[rport].Died { + if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) } - (*hc.tuns)[rport].Died = true - if (*hc.tuns)[rport].Data != nil { - close((*hc.tuns)[rport].Data) - (*hc.tuns)[rport].Data = nil - } - delete((*hc.tuns), rport) + hc.ShutdownTun(rport) break } } @@ -357,14 +366,6 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { rBuf = append(tunDst.Bytes(), rBuf[:n]...) hc.WritePacket(rBuf[:n+4], CSOTunData) } - - //if (*hc.tuns)[rport].KeepAlive > 50000 { - // (*hc.tuns)[rport].Died = true - // logger.LogDebug("[ServerTun] worker A: Client died, hanging up.") - //} else { - // (*hc.tuns)[rport].KeepAlive += 1 - //} - } logger.LogDebug("[ServerTun] worker A: exiting") }() @@ -382,7 +383,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug("[ServerTun] worker B: starting") for { - rData, ok := <-(*hc.tuns)[rport].Data + rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race w/ShutdownTun() calls if ok { c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) _, e := c.Write(rData)