mirror of
				https://gogs.blitter.com/RLabs/xs
				synced 2024-08-14 10:26:42 +00:00 
			
		
		
		
	Added locking APIs for most Conn/Tun fields, save <- Data/ShutdownTun() race
Signed-off-by: Russ Magee <rmagee@gmail.com>
This commit is contained in:
		
							parent
							
								
									1057a78df3
								
							
						
					
					
						commit
						68b8c48e4d
					
				
					 2 changed files with 70 additions and 61 deletions
				
			
		|  | @ -122,6 +122,14 @@ func Init(d bool, c string, f logger.Priority) { | ||||||
| 	_initLogging(d, c, f) | 	_initLogging(d, c, f) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (hc *Conn) Lock() { | ||||||
|  | 		hc.m.Lock() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hc *Conn) Unlock() { | ||||||
|  | 		hc.m.Unlock() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (hc Conn) GetStatus() CSOType { | func (hc Conn) GetStatus() CSOType { | ||||||
| 	return *hc.closeStat | 	return *hc.closeStat | ||||||
| } | } | ||||||
|  | @ -1084,7 +1092,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { | ||||||
| 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | ||||||
| 				logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) | 				logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) | ||||||
| 				if _, ok := (*hc.tuns)[rport]; ok { | 				if _, ok := (*hc.tuns)[rport]; ok { | ||||||
| 					(*hc.tuns)[rport].Died = true | 					hc.MarkTunDead(rport) | ||||||
| 				} else { | 				} else { | ||||||
| 					logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport)) | 					logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport)) | ||||||
| 				} | 				} | ||||||
|  | @ -1094,7 +1102,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { | ||||||
| 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | ||||||
| 				logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport)) | 				logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport)) | ||||||
| 				if _, ok := (*hc.tuns)[rport]; ok { | 				if _, ok := (*hc.tuns)[rport]; ok { | ||||||
| 					(*hc.tuns)[rport].Died = true | 					hc.MarkTunDead(rport) | ||||||
| 				} else { | 				} else { | ||||||
| 					logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport)) | 					logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport)) | ||||||
| 				} | 				} | ||||||
|  | @ -1104,7 +1112,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { | ||||||
| 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | 				rport := binary.BigEndian.Uint16(payloadBytes[2:4]) | ||||||
| 				logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport)) | 				logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport)) | ||||||
| 				if _, ok := (*hc.tuns)[rport]; ok { | 				if _, ok := (*hc.tuns)[rport]; ok { | ||||||
| 					(*hc.tuns)[rport].Died = true | 					hc.MarkTunDead(rport) | ||||||
| 				} else { | 				} else { | ||||||
| 					logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport)) | 					logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport)) | ||||||
| 				} | 				} | ||||||
|  | @ -1117,7 +1125,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { | ||||||
| 						logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport)) | 						logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport)) | ||||||
| 					} | 					} | ||||||
| 					(*hc.tuns)[rport].Data <- payloadBytes[4:] | 					(*hc.tuns)[rport].Data <- payloadBytes[4:] | ||||||
| 					(*hc.tuns)[rport].KeepAlive = 0 | 					hc.ResetTunnelAge(rport) | ||||||
| 				} else { | 				} else { | ||||||
| 					logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport)) | 					logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport)) | ||||||
| 				} | 				} | ||||||
|  | @ -1212,7 +1220,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) { | ||||||
| 	// | 	// | ||||||
| 	// Would be nice to determine if the mutex scope | 	// Would be nice to determine if the mutex scope | ||||||
| 	// could be tightened. | 	// could be tightened. | ||||||
| 	hc.m.Lock() | 	hc.Lock() | ||||||
| 	payloadLen = uint32(len(b)) | 	payloadLen = uint32(len(b)) | ||||||
| 	//!fmt.Printf("  --== payloadLen:%d\n", payloadLen) | 	//!fmt.Printf("  --== payloadLen:%d\n", payloadLen) | ||||||
| 	if hc.logPlainText { | 	if hc.logPlainText { | ||||||
|  | @ -1254,7 +1262,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) { | ||||||
| 	} else { | 	} else { | ||||||
| 		//fmt.Println("[a]WriteError!") | 		//fmt.Println("[a]WriteError!") | ||||||
| 	} | 	} | ||||||
| 	hc.m.Unlock() | 	hc.Unlock() | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Println(err) | 		log.Println(err) | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"blitter.com/go/hkexsh/logger" | 	"blitter.com/go/hkexsh/logger" | ||||||
|  | @ -46,7 +47,7 @@ type ( | ||||||
| 		Lport     uint16    // ... ie., RPort is on server, LPort is on client | 		Lport     uint16    // ... ie., RPort is on server, LPort is on client | ||||||
| 		Peer      string    //net.Addr | 		Peer      string    //net.Addr | ||||||
| 		Died      bool      // set by client upon receipt of a CSOTunDisconn | 		Died      bool      // set by client upon receipt of a CSOTunDisconn | ||||||
| 		KeepAlive uint      // must be reset by client to keep server dial() alive | 		KeepAlive uint32    // must be reset by client to keep server dial() alive | ||||||
| 		Ctl       chan rune //See TunCtl_* consts | 		Ctl       chan rune //See TunCtl_* consts | ||||||
| 		Data      chan []byte | 		Data      chan []byte | ||||||
| 	} | 	} | ||||||
|  | @ -67,6 +68,8 @@ func (hc *Conn) CollapseAllTunnels(client bool) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { | func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { | ||||||
|  | 	hc.Lock() | ||||||
|  | 	defer hc.Unlock() | ||||||
| 	if (*hc.tuns) == nil { | 	if (*hc.tuns) == nil { | ||||||
| 		(*hc.tuns) = make(map[uint16]*TunEndpoint) | 		(*hc.tuns) = make(map[uint16]*TunEndpoint) | ||||||
| 	} | 	} | ||||||
|  | @ -87,6 +90,7 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { | ||||||
| 			// data channel removed on closure. Re-create it | 			// data channel removed on closure. Re-create it | ||||||
| 			(*hc.tuns)[rp].Data = make(chan []byte, 1) | 			(*hc.tuns)[rp].Data = make(chan []byte, 1) | ||||||
| 		} | 		} | ||||||
|  | 		(*hc.tuns)[rp].KeepAlive = 0 | ||||||
| 		(*hc.tuns)[rp].Died = false | 		(*hc.tuns)[rp].Died = false | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
|  | @ -149,37 +153,23 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { | ||||||
| 										if e == io.EOF { | 										if e == io.EOF { | ||||||
| 											logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) | 											logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) | ||||||
| 											// if Died was already set, server-side already is gone. | 											// if Died was already set, server-side already is gone. | ||||||
| 											if !(*hc.tuns)[rport].Died { | 											if hc.TunIsAlive(rport) { | ||||||
| 												hc.WritePacket(tunDst.Bytes(), CSOTunHangup) | 												hc.WritePacket(tunDst.Bytes(), CSOTunHangup) | ||||||
| 											} | 											} | ||||||
| 											(*hc.tuns)[rport].Died = true | 											hc.ShutdownTun(rport) | ||||||
| 											if (*hc.tuns)[rport].Data != nil { |  | ||||||
| 												close((*hc.tuns)[rport].Data) |  | ||||||
| 												(*hc.tuns)[rport].Data = nil |  | ||||||
| 											} |  | ||||||
| 											delete((*hc.tuns), rport) |  | ||||||
| 											break | 											break | ||||||
| 										} else if strings.Contains(e.Error(), "i/o timeout") { | 										} else if strings.Contains(e.Error(), "i/o timeout") { | ||||||
| 											if (*hc.tuns)[rport].Died { | 											if !hc.TunIsAlive(rport) { | ||||||
| 												logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) | 												logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) | ||||||
| 												if (*hc.tuns)[rport].Data != nil { | 												hc.ShutdownTun(rport) | ||||||
| 													close((*hc.tuns)[rport].Data) |  | ||||||
| 													(*hc.tuns)[rport].Data = nil |  | ||||||
| 												} |  | ||||||
| 												delete((*hc.tuns), rport) |  | ||||||
| 												break | 												break | ||||||
| 											} | 											} | ||||||
| 										} else { | 										} else { | ||||||
| 											logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e)) | 											logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e)) | ||||||
| 											if !(*hc.tuns)[rport].Died { | 											if hc.TunIsAlive(rport) { | ||||||
| 												hc.WritePacket(tunDst.Bytes(), CSOTunHangup) | 												hc.WritePacket(tunDst.Bytes(), CSOTunHangup) | ||||||
| 											} | 											} | ||||||
| 											(*hc.tuns)[rport].Died = true | 											hc.ShutdownTun(rport) | ||||||
| 											if (*hc.tuns)[rport].Data != nil { |  | ||||||
| 												close((*hc.tuns)[rport].Data) |  | ||||||
| 												(*hc.tuns)[rport].Data = nil |  | ||||||
| 											} |  | ||||||
| 											delete((*hc.tuns), rport) |  | ||||||
| 											break | 											break | ||||||
| 										} | 										} | ||||||
| 									} | 									} | ||||||
|  | @ -232,7 +222,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { | ||||||
| 						// When both workers have exited due to a disconnect or other | 						// When both workers have exited due to a disconnect or other | ||||||
| 						// condition, it's safe to remove the tunnel descriptor. | 						// condition, it's safe to remove the tunnel descriptor. | ||||||
| 						logger.LogDebug("[ClientTun] workers exited") | 						logger.LogDebug("[ClientTun] workers exited") | ||||||
| 						delete((*hc.tuns), rport) | 						hc.ShutdownTun(rport) | ||||||
| 					} // end for-accept | 					} // end for-accept | ||||||
| 				} // end Listen() block | 				} // end Listen() block | ||||||
| 			} | 			} | ||||||
|  | @ -240,6 +230,39 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { | ||||||
| 	}() | 	}() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (hc *Conn) AgeTunnel(endp uint16) uint32 { | ||||||
|  | 	return atomic.AddUint32(&(*hc.tuns)[endp].KeepAlive, 1) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hc *Conn) ResetTunnelAge(endp uint16) { | ||||||
|  | 	atomic.StoreUint32(&(*hc.tuns)[endp].KeepAlive, 0) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hc *Conn) TunIsAlive(endp uint16) bool { | ||||||
|  | 	hc.Lock() | ||||||
|  | 	defer hc.Unlock() | ||||||
|  | 	return !(*hc.tuns)[endp].Died | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hc *Conn) MarkTunDead(endp uint16) { | ||||||
|  | 	hc.Lock() | ||||||
|  | 	defer hc.Unlock() | ||||||
|  | 	(*hc.tuns)[endp].Died = true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hc *Conn) ShutdownTun(endp uint16) { | ||||||
|  | 	hc.Lock() | ||||||
|  | 	defer hc.Unlock() | ||||||
|  | 	if (*hc.tuns)[endp] != nil { | ||||||
|  | 		(*hc.tuns)[endp].Died = true | ||||||
|  | 		if (*hc.tuns)[endp].Data != nil { | ||||||
|  | 			close((*hc.tuns)[endp].Data) | ||||||
|  | 			(*hc.tuns)[endp].Data = nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	delete((*hc.tuns), endp) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (hc *Conn) StartServerTunnel(lport, rport uint16) { | func (hc *Conn) StartServerTunnel(lport, rport uint16) { | ||||||
| 	hc.InitTunEndpoint(lport, "", rport) | 	hc.InitTunEndpoint(lport, "", rport) | ||||||
| 	var err error | 	var err error | ||||||
|  | @ -260,9 +283,9 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { | ||||||
| 					logger.LogDebug("[ServerTun] worker A: Client endpoint removed.") | 					logger.LogDebug("[ServerTun] worker A: Client endpoint removed.") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				(*hc.tuns)[rport].KeepAlive += 1 | 				age := hc.AgeTunnel(rport) | ||||||
| 				if (*hc.tuns)[rport].KeepAlive > 25 { | 				if age > 25 { | ||||||
| 					(*hc.tuns)[rport].Died = true | 					hc.MarkTunDead(rport) | ||||||
| 					logger.LogDebug("[ServerTun] worker A: Client died, hanging up.") | 					logger.LogDebug("[ServerTun] worker A: Client died, hanging up.") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
|  | @ -319,37 +342,23 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { | ||||||
| 							if e != nil { | 							if e != nil { | ||||||
| 								if e == io.EOF { | 								if e == io.EOF { | ||||||
| 									logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) | 									logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport])) | ||||||
| 									if !(*hc.tuns)[rport].Died { | 									if hc.TunIsAlive(rport) { | ||||||
| 										hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) | 										hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) | ||||||
| 									} | 									} | ||||||
| 									(*hc.tuns)[rport].Died = true | 									hc.ShutdownTun(rport) | ||||||
| 									if (*hc.tuns)[rport].Data != nil { |  | ||||||
| 										close((*hc.tuns)[rport].Data) |  | ||||||
| 										(*hc.tuns)[rport].Data = nil |  | ||||||
| 									} |  | ||||||
| 									delete((*hc.tuns), rport) |  | ||||||
| 									break | 									break | ||||||
| 								} else if strings.Contains(e.Error(), "i/o timeout") { | 								} else if strings.Contains(e.Error(), "i/o timeout") { | ||||||
| 									if (*hc.tuns)[rport].Died { | 									if !hc.TunIsAlive(rport) { | ||||||
| 										logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) | 											logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) | ||||||
| 										if (*hc.tuns)[rport].Data != nil { | 											hc.ShutdownTun(rport) | ||||||
| 											close((*hc.tuns)[rport].Data) |  | ||||||
| 											(*hc.tuns)[rport].Data = nil |  | ||||||
| 										} |  | ||||||
| 										delete((*hc.tuns), rport) |  | ||||||
| 										break | 										break | ||||||
| 									} | 									} | ||||||
| 								} else { | 								} else { | ||||||
| 									logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e)) | 									logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e)) | ||||||
| 									if !(*hc.tuns)[rport].Died { | 									if hc.TunIsAlive(rport) { | ||||||
| 										hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) | 										hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) | ||||||
| 									} | 									} | ||||||
| 									(*hc.tuns)[rport].Died = true | 									hc.ShutdownTun(rport) | ||||||
| 									if (*hc.tuns)[rport].Data != nil { |  | ||||||
| 										close((*hc.tuns)[rport].Data) |  | ||||||
| 										(*hc.tuns)[rport].Data = nil |  | ||||||
| 									} |  | ||||||
| 									delete((*hc.tuns), rport) |  | ||||||
| 									break | 									break | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
|  | @ -357,14 +366,6 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { | ||||||
| 								rBuf = append(tunDst.Bytes(), rBuf[:n]...) | 								rBuf = append(tunDst.Bytes(), rBuf[:n]...) | ||||||
| 								hc.WritePacket(rBuf[:n+4], CSOTunData) | 								hc.WritePacket(rBuf[:n+4], CSOTunData) | ||||||
| 							} | 							} | ||||||
| 
 |  | ||||||
| 							//if (*hc.tuns)[rport].KeepAlive > 50000 { |  | ||||||
| 							//	(*hc.tuns)[rport].Died = true |  | ||||||
| 							//	logger.LogDebug("[ServerTun] worker A: Client died, hanging up.") |  | ||||||
| 							//} else { |  | ||||||
| 							//	(*hc.tuns)[rport].KeepAlive += 1 |  | ||||||
| 							//} |  | ||||||
| 
 |  | ||||||
| 						} | 						} | ||||||
| 						logger.LogDebug("[ServerTun] worker A: exiting") | 						logger.LogDebug("[ServerTun] worker A: exiting") | ||||||
| 					}() | 					}() | ||||||
|  | @ -382,7 +383,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { | ||||||
| 
 | 
 | ||||||
| 						logger.LogDebug("[ServerTun] worker B: starting") | 						logger.LogDebug("[ServerTun] worker B: starting") | ||||||
| 						for { | 						for { | ||||||
| 							rData, ok := <-(*hc.tuns)[rport].Data | 							rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race w/ShutdownTun() calls | ||||||
| 							if ok { | 							if ok { | ||||||
| 								c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) | 								c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) | ||||||
| 								_, e := c.Write(rData) | 								_, e := c.Write(rData) | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue