diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 852d8c5..91d8c4e 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -123,11 +123,11 @@ func Init(d bool, c string, f logger.Priority) { } func (hc *Conn) Lock() { - hc.m.Lock() + hc.m.Lock() } func (hc *Conn) Unlock() { - hc.m.Unlock() + hc.m.Unlock() } func (hc Conn) GetStatus() CSOType { @@ -1136,7 +1136,9 @@ func (hc Conn) Read(b []byte) (n int, err error) { _ = binary.BigEndian.Uint16(payloadBytes[0:2]) //logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunKeepAlive")) for _, t := range *hc.tuns { + hc.Lock() t.KeepAlive = 0 + hc.Unlock() } } else if ctrlStatOp == CSONone { hc.dBuf.Write(payloadBytes) diff --git a/hkexnet/hkextun.go b/hkexnet/hkextun.go index 1f1133c..2709aa2 100644 --- a/hkexnet/hkextun.go +++ b/hkexnet/hkextun.go @@ -16,7 +16,6 @@ import ( "net" "strings" "sync" - "sync/atomic" "time" "blitter.com/go/hkexsh/logger" @@ -156,7 +155,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunHangup) } - hc.ShutdownTun(rport) + hc.ShutdownTun(rport) // FIXME: race-C break } else if strings.Contains(e.Error(), "i/o timeout") { if !hc.TunIsAlive(rport) { @@ -200,7 +199,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { logger.LogDebug("[ClientTun] worker B: starting") for { - bytes, ok := <-(*hc.tuns)[rport].Data + bytes, ok := <-(*hc.tuns)[rport].Data // FIXME: race-C w/ShutdownTun calls if ok { c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) _, e := c.Write(bytes) @@ -231,11 +230,22 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { } func (hc *Conn) AgeTunnel(endp uint16) uint32 { - return atomic.AddUint32(&(*hc.tuns)[endp].KeepAlive, 1) + hc.Lock() + defer hc.Unlock() + (*hc.tuns)[endp].KeepAlive += 1 + return (*hc.tuns)[endp].KeepAlive } func (hc *Conn) ResetTunnelAge(endp uint16) { - atomic.StoreUint32(&(*hc.tuns)[endp].KeepAlive, 0) + hc.Lock() + defer hc.Unlock() + (*hc.tuns)[endp].KeepAlive = 0 +} + +func (hc *Conn) TunIsNil(endp uint16) bool { + hc.Lock() + defer hc.Unlock() + return (*hc.tuns)[endp] == nil } func (hc *Conn) TunIsAlive(endp uint16) bool { @@ -279,7 +289,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { defer wg.Done() for { time.Sleep(100 * time.Millisecond) - if (*hc.tuns)[rport] == nil { + if hc.TunIsNil(rport) { logger.LogDebug("[ServerTun] worker A: Client endpoint removed.") break } @@ -297,7 +307,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug(fmt.Sprintf("[ServerTun] got Ctl '%c'.", cmd)) if cmd == 'd' { // if re-using tunnel, re-init it - if (*hc.tuns)[rport] == nil { + if hc.TunIsNil(rport) { hc.InitTunEndpoint(lport, "", rport) } logger.LogDebug("[ServerTun] dialling...") @@ -345,12 +355,12 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) } - hc.ShutdownTun(rport) + hc.ShutdownTun(rport) // FIXME: race-A break } else if strings.Contains(e.Error(), "i/o timeout") { if !hc.TunIsAlive(rport) { logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) - hc.ShutdownTun(rport) + hc.ShutdownTun(rport) // FIXME: race-B break } } else { @@ -358,7 +368,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { if hc.TunIsAlive(rport) { hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) } - hc.ShutdownTun(rport) + hc.ShutdownTun(rport) // FIXME: race-C break } } @@ -383,7 +393,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { logger.LogDebug("[ServerTun] worker B: starting") for { - rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race w/ShutdownTun() calls + rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race-A, race-B, race-C (w/ShutdownTun() calls) if ok { c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) _, e := c.Write(rData)