Added locking APIs for most Conn/Tun fields, save <- Data/ShutdownTun() race

Signed-off-by: Russ Magee <rmagee@gmail.com>
This commit is contained in:
Russ Magee 2019-06-27 22:10:59 -07:00
parent c327b2ec72
commit 8f5366fff4
2 changed files with 70 additions and 61 deletions

View file

@ -122,6 +122,14 @@ func Init(d bool, c string, f logger.Priority) {
_initLogging(d, c, f)
}
func (hc *Conn) Lock() {
hc.m.Lock()
}
func (hc *Conn) Unlock() {
hc.m.Unlock()
}
func (hc Conn) GetStatus() CSOType {
return *hc.closeStat
}
@ -1084,7 +1092,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
rport := binary.BigEndian.Uint16(payloadBytes[2:4])
logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport))
if _, ok := (*hc.tuns)[rport]; ok {
(*hc.tuns)[rport].Died = true
hc.MarkTunDead(rport)
} else {
logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport))
}
@ -1094,7 +1102,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
rport := binary.BigEndian.Uint16(payloadBytes[2:4])
logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport))
if _, ok := (*hc.tuns)[rport]; ok {
(*hc.tuns)[rport].Died = true
hc.MarkTunDead(rport)
} else {
logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport))
}
@ -1104,7 +1112,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
rport := binary.BigEndian.Uint16(payloadBytes[2:4])
logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport))
if _, ok := (*hc.tuns)[rport]; ok {
(*hc.tuns)[rport].Died = true
hc.MarkTunDead(rport)
} else {
logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport))
}
@ -1117,7 +1125,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport))
}
(*hc.tuns)[rport].Data <- payloadBytes[4:]
(*hc.tuns)[rport].KeepAlive = 0
hc.ResetTunnelAge(rport)
} else {
logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport))
}
@ -1212,7 +1220,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) {
//
// Would be nice to determine if the mutex scope
// could be tightened.
hc.m.Lock()
hc.Lock()
payloadLen = uint32(len(b))
//!fmt.Printf(" --== payloadLen:%d\n", payloadLen)
if hc.logPlainText {
@ -1254,7 +1262,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) {
} else {
//fmt.Println("[a]WriteError!")
}
hc.m.Unlock()
hc.Unlock()
if err != nil {
log.Println(err)

View file

@ -16,6 +16,7 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"time"
"blitter.com/go/hkexsh/logger"
@ -46,7 +47,7 @@ type (
Lport uint16 // ... ie., RPort is on server, LPort is on client
Peer string //net.Addr
Died bool // set by client upon receipt of a CSOTunDisconn
KeepAlive uint // must be reset by client to keep server dial() alive
KeepAlive uint32 // must be reset by client to keep server dial() alive
Ctl chan rune //See TunCtl_* consts
Data chan []byte
}
@ -67,6 +68,8 @@ func (hc *Conn) CollapseAllTunnels(client bool) {
}
func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) {
hc.Lock()
defer hc.Unlock()
if (*hc.tuns) == nil {
(*hc.tuns) = make(map[uint16]*TunEndpoint)
}
@ -87,6 +90,7 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) {
// data channel removed on closure. Re-create it
(*hc.tuns)[rp].Data = make(chan []byte, 1)
}
(*hc.tuns)[rp].KeepAlive = 0
(*hc.tuns)[rp].Died = false
}
return
@ -149,37 +153,23 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
if e == io.EOF {
logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport]))
// if Died was already set, server-side already is gone.
if !(*hc.tuns)[rport].Died {
if hc.TunIsAlive(rport) {
hc.WritePacket(tunDst.Bytes(), CSOTunHangup)
}
(*hc.tuns)[rport].Died = true
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
} else if strings.Contains(e.Error(), "i/o timeout") {
if (*hc.tuns)[rport].Died {
if !hc.TunIsAlive(rport) {
logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport]))
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
}
} else {
logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e))
if !(*hc.tuns)[rport].Died {
if hc.TunIsAlive(rport) {
hc.WritePacket(tunDst.Bytes(), CSOTunHangup)
}
(*hc.tuns)[rport].Died = true
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
}
}
@ -232,7 +222,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
// When both workers have exited due to a disconnect or other
// condition, it's safe to remove the tunnel descriptor.
logger.LogDebug("[ClientTun] workers exited")
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
} // end for-accept
} // end Listen() block
}
@ -240,6 +230,39 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
}()
}
func (hc *Conn) AgeTunnel(endp uint16) uint32 {
return atomic.AddUint32(&(*hc.tuns)[endp].KeepAlive, 1)
}
func (hc *Conn) ResetTunnelAge(endp uint16) {
atomic.StoreUint32(&(*hc.tuns)[endp].KeepAlive, 0)
}
func (hc *Conn) TunIsAlive(endp uint16) bool {
hc.Lock()
defer hc.Unlock()
return !(*hc.tuns)[endp].Died
}
func (hc *Conn) MarkTunDead(endp uint16) {
hc.Lock()
defer hc.Unlock()
(*hc.tuns)[endp].Died = true
}
func (hc *Conn) ShutdownTun(endp uint16) {
hc.Lock()
defer hc.Unlock()
if (*hc.tuns)[endp] != nil {
(*hc.tuns)[endp].Died = true
if (*hc.tuns)[endp].Data != nil {
close((*hc.tuns)[endp].Data)
(*hc.tuns)[endp].Data = nil
}
}
delete((*hc.tuns), endp)
}
func (hc *Conn) StartServerTunnel(lport, rport uint16) {
hc.InitTunEndpoint(lport, "", rport)
var err error
@ -260,9 +283,9 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
logger.LogDebug("[ServerTun] worker A: Client endpoint removed.")
break
}
(*hc.tuns)[rport].KeepAlive += 1
if (*hc.tuns)[rport].KeepAlive > 25 {
(*hc.tuns)[rport].Died = true
age := hc.AgeTunnel(rport)
if age > 25 {
hc.MarkTunDead(rport)
logger.LogDebug("[ServerTun] worker A: Client died, hanging up.")
break
}
@ -319,37 +342,23 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
if e != nil {
if e == io.EOF {
logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport]))
if !(*hc.tuns)[rport].Died {
if hc.TunIsAlive(rport) {
hc.WritePacket(tunDst.Bytes(), CSOTunDisconn)
}
(*hc.tuns)[rport].Died = true
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
} else if strings.Contains(e.Error(), "i/o timeout") {
if (*hc.tuns)[rport].Died {
if !hc.TunIsAlive(rport) {
logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport]))
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
}
} else {
logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e))
if !(*hc.tuns)[rport].Died {
if hc.TunIsAlive(rport) {
hc.WritePacket(tunDst.Bytes(), CSOTunDisconn)
}
(*hc.tuns)[rport].Died = true
if (*hc.tuns)[rport].Data != nil {
close((*hc.tuns)[rport].Data)
(*hc.tuns)[rport].Data = nil
}
delete((*hc.tuns), rport)
hc.ShutdownTun(rport)
break
}
}
@ -357,14 +366,6 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
rBuf = append(tunDst.Bytes(), rBuf[:n]...)
hc.WritePacket(rBuf[:n+4], CSOTunData)
}
//if (*hc.tuns)[rport].KeepAlive > 50000 {
// (*hc.tuns)[rport].Died = true
// logger.LogDebug("[ServerTun] worker A: Client died, hanging up.")
//} else {
// (*hc.tuns)[rport].KeepAlive += 1
//}
}
logger.LogDebug("[ServerTun] worker A: exiting")
}()
@ -382,7 +383,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
logger.LogDebug("[ServerTun] worker B: starting")
for {
rData, ok := <-(*hc.tuns)[rport].Data
rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race w/ShutdownTun() calls
if ok {
c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
_, e := c.Write(rData)