diff --git a/hkexnet.go b/hkexnet.go index e967868..81f8bc8 100644 --- a/hkexnet.go +++ b/hkexnet.go @@ -29,19 +29,28 @@ import ( ) const ( - csoNone = iota // No error, normal packet - csoHmacInvalid // HMAC mismatch detect on remote end - csoChaff // This packet is a dummy, do not process beyond decryption + CSONone = iota // No error, normal packet + CSOHmacInvalid // HMAC mismatch detected on remote end + CSOTermSize // set term size (rows:cols) + CSOChaff // Dummy packet, do not pass beyond decryption ) /*---------------------------------------------------------------------*/ +type WinSize struct { + Rows uint16 + Cols uint16 +} + // Conn is a HKex connection - a drop-in replacement for net.Conn type Conn struct { c net.Conn // which also implements io.Reader, io.Writer, ... h *HerraduraKEx - cipheropts uint32 // post-KEx cipher/hmac options - opts uint32 // post-KEx protocol options (caller-defined) + cipheropts uint32 // post-KEx cipher/hmac options + opts uint32 // post-KEx protocol options (caller-defined) + WinCh chan WinSize + Rows uint16 + Cols uint16 r cipher.Stream //read cipherStream rm hash.Hash w cipher.Stream //write cipherStream @@ -262,7 +271,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) { } log.Println("[Accepted]") - hc = Conn{c: c, h: New(0, 0), dBuf: new(bytes.Buffer)} + hc = Conn{c: c, h: New(0, 0), WinCh: make(chan WinSize, 1), + dBuf: new(bytes.Buffer)} // Read in hkexnet.Conn parameters over raw Conn c // d is value for Herradura key exchange @@ -311,9 +321,10 @@ func (c Conn) Read(b []byte) (n int, err error) { var hmacIn [4]uint8 var payloadLen uint32 - // Read ctrl/status opcode (csoHmacInvalid on hmac mismatch) + // Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch) err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp) - if ctrlStatOp == csoHmacInvalid { + log.Printf("[ctrlStatOp: %v]\n", ctrlStatOp) + if ctrlStatOp == CSOHmacInvalid { // Other side indicated channel tampering, close channel c.Close() return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") @@ -379,8 +390,12 @@ func (c Conn) Read(b []byte) (n int, err error) { } // Throw away pkt if it's chaff (ie., caller to Read() won't see this data) - if ctrlStatOp == csoChaff { + if ctrlStatOp == CSOChaff { log.Printf("[Chaff pkt]\n") + } else if ctrlStatOp == CSOTermSize { + fmt.Sscanf(string(payloadBytes), "%d %d", &c.Rows, &c.Cols) + log.Printf("[TermSize pkt: rows %v cols %v]\n", c.Rows, c.Cols) + c.WinCh <- WinSize{c.Rows, c.Cols} } else { c.dBuf.Write(payloadBytes) //log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes())) @@ -394,7 +409,7 @@ func (c Conn) Read(b []byte) (n int, err error) { // Log alert if hmac didn't match, corrupted channel if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **") - _, _ = c.c.Write([]byte{csoHmacInvalid}) + _, _ = c.c.Write([]byte{CSOHmacInvalid}) } } @@ -413,6 +428,11 @@ func (c Conn) Read(b []byte) (n int, err error) { // // See go doc io.Writer func (c Conn) Write(b []byte) (n int, err error) { + return c.WritePacket(b, CSONone) +} + +// Write a byte slice with specified ctrlStatusOp byte +func (c Conn) WritePacket(b []byte, op byte) (n int, err error) { //log.Printf("[Encrypting...]\r\n") var hmacOut []uint8 var payloadLen uint32 @@ -437,8 +457,7 @@ func (c Conn) Write(b []byte) (n int, err error) { } log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) - var ctrlStatOp byte - ctrlStatOp = csoNone + ctrlStatOp := op _ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp) // Write hmac LSB, payloadLen followed by payload diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index c90e20b..0de4ce2 100644 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -14,9 +14,12 @@ import ( "io/ioutil" "log" "os" + "os/exec" + "os/signal" "os/user" "strings" "sync" + "syscall" hkexsh "blitter.com/go/hkexsh" isatty "github.com/mattn/go-isatty" @@ -30,6 +33,23 @@ type cmdSpec struct { status int } +// get terminal size using 'stty' command +// (Most portable btwn Linux and MSYS/win32, but +// TODO: remove external dep on 'stty' utility) +func getTermSize() (rows int, cols int, err error) { + cmd := exec.Command("stty", "size") + cmd.Stdin = os.Stdin + out, err := cmd.Output() + //fmt.Printf("out: %#v\n", string(out)) + //fmt.Printf("err: %#v\n", err) + + fmt.Sscanf(string(out), "%d %d\n", &rows, &cols) + if err != nil { + log.Fatal(err) + } + return +} + // Demo of a simple client that dials up to a simple test server to // send data. // @@ -77,6 +97,9 @@ func main() { defer conn.Close() // From this point on, conn is a secure encrypted channel + rows := 0 + cols := 0 + // Set stdin in raw mode if it's an interactive session // TODO: send flag to server side indicating this // affects shell command used @@ -131,7 +154,9 @@ func main() { authCookie: []byte(authCookie), status: 0} - _, err = fmt.Fprintf(conn, "%d %d %d %d\n", len(rec.op), len(rec.who), len(rec.cmd), len(rec.authCookie)) + _, err = fmt.Fprintf(conn, "%d %d %d %d\n", + len(rec.op), len(rec.who), len(rec.cmd), len(rec.authCookie)) + _, err = conn.Write(rec.op) _, err = conn.Write(rec.who) _, err = conn.Write(rec.cmd) @@ -168,6 +193,27 @@ func main() { }() if isInteractive { + // Handle pty resizes (notify server side) + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGWINCH) + wg.Add(1) + go func() { + defer wg.Done() + + for range ch { + // Query client's term size so we can communicate it to server + // pty after interactive session starts + rows, cols, err = getTermSize() + log.Printf("[rows %v cols %v]\n", rows, cols) + if err != nil { + panic(err) + } + termSzPacket := fmt.Sprintf("%d %d", rows, cols) + conn.WritePacket([]byte(termSzPacket), hkexsh.CSOTermSize) + } + }() + ch <- syscall.SIGWINCH // Initial resize. + // client writer (to server) goroutine wg.Add(1) go func() { diff --git a/hkexshd/hkexshd.go b/hkexshd/hkexshd.go index d5906f9..eea06aa 100644 --- a/hkexshd/hkexshd.go +++ b/hkexshd/hkexshd.go @@ -28,6 +28,8 @@ type cmdSpec struct { who []byte cmd []byte authCookie []byte + termRows []byte + termCols []byte status int } @@ -119,6 +121,15 @@ func runShellAs(who string, cmd string, interactive bool, conn hkexsh.Conn) (err } // Make sure to close the pty at the end. defer func() { _ = ptmx.Close() }() // Best effort. + + // Watch for term resizes + go func() { + for sz := range conn.WinCh { + log.Printf("[Setting term size to: %v %v]\n", sz.Rows, sz.Cols) + pty.Setsize(ptmx, &pty.Winsize{Rows: sz.Rows, Cols: sz.Cols}) + } + }() + // Copy stdin to the pty.. (bgnd goroutine) go func() { _, _ = io.Copy(ptmx, conn) @@ -171,8 +182,8 @@ func main() { // Wait for a connection. conn, err := l.Accept() if err != nil { - log.Printf("Accept() got error(%v), hanging up.\n", err) - conn.Close() + log.Printf("Accept() got error(%v), hanging up.\n", err) + conn.Close() //log.Fatal(err) } else { log.Println("Accepted client")