diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 297fccb..2c8a8e4 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -42,6 +42,7 @@ import ( "math/big" "math/rand" "net" + "os" "strings" "sync" "time" @@ -122,11 +123,8 @@ type ( dBuf *bytes.Buffer //decrypt buffer for Read() } - EscSeqs struct { - idx int - seqs []byte - outstr []string - } + EscHandler func(io.Writer) + EscSeqs map[byte]EscHandler ) var ( @@ -1136,60 +1134,40 @@ func Copy(dst io.Writer, src io.Reader) (written int64, err error) { return } -func escSeqScanner(s *EscSeqs, dst io.Writer, b byte) (passthru bool) { - passthru = true - if s.idx > 0 { - switch b { - case '~': - return - case s.seqs[0]: - dst.Write([]byte(s.outstr[0])) - passthru = false - b = '~' - case s.seqs[1]: - dst.Write([]byte(s.outstr[1])) - passthru = false - b = '~' - case s.seqs[2]: - dst.Write([]byte(s.outstr[2])) - passthru = false - b = '~' - } - } - - if b == '~' { - s.idx++ - } else { - s.idx = 0 - } - return -} - // copyBuffer is the actual implementation of Copy and CopyBuffer. // if buf is nil, one is allocated. func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { - escs := &EscSeqs{idx: 0, - seqs: []byte{ - 'i', - 't', - 'B', - }, - outstr: []string{ - "\x1b[s\x1b[2;1H\x1b[1;31m[HKEXSH]\x1b[39;49m\x1b[u", - "\x1b[1;32m[HKEXSH]\x1b[39;49m", - "\x1b[1;32m" + Bob + "\x1b[39;49m", - }, + // NOTE: using dst.Write() in these esc funcs will cause the output + // to function as a 'macro', outputting as if user typed the sequence. + // + // Using os.Stdout outputs to the client's term w/o it or the server + // 'seeing' the output. + // + // TODO: Devise a way to signal to main client thread that + // a goroutine should be spawned to do long-lived tasks for + // some esc sequences (eg., a time ticker in the corner of terminal, + // or tunnel traffic indicator - note we cannot just spawn a goroutine + // here, as copyBuffer() returns after each burst of data. Scope must + // outlive individual copyBuffer calls). + // (Note that since this custom copyBuffer func is used only by + // the hkexsh client, it should eventually be moved to client.) + escs := EscSeqs{ + 'i': func(io.Writer) { os.Stdout.Write([]byte("\x1b[s\x1b[2;1H\x1b[1;31m[HKEXSH]\x1b[39;49m\x1b[u")) }, + 't': func(io.Writer) { os.Stdout.Write([]byte("\x1b[1;32m[HKEXSH]\x1b[39;49m")) }, + 'B': func(io.Writer) { os.Stdout.Write([]byte("\x1b[1;32m" + Bob + "\x1b[39;49m")) }, } - // If the reader has a WriteTo method, use it to do the copy. - // Avoids an allocation and a copy. - if wt, ok := src.(io.WriterTo); ok { - return wt.WriteTo(dst) - } - // Similarly, if the writer has a ReadFrom method, use it to do the copy. - if rt, ok := dst.(io.ReaderFrom); ok { - return rt.ReadFrom(src) - } + /* + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) + } + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) + } + */ if buf == nil { size := 32 * 1024 if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N { @@ -1201,23 +1179,39 @@ func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err er } buf = make([]byte, size) } + + var seqPos int for { nr, er := src.Read(buf) if nr > 0 { // Look for sequences to trigger client-side diags - if escSeqScanner(escs, dst, buf[0]) { - nw, ew := dst.Write(buf[0:nr]) - if nw > 0 { - written += int64(nw) + // A repeat of 4 keys (conveniently 'dead' chars for most + // interactive shells; here CTRL-]) shall introduce + // some special responses or actions on the client side. + if seqPos < 4 { + if buf[0] == 0x1d { + seqPos++ } - if ew != nil { - err = ew - break - } - if nr != nw { - err = io.ErrShortWrite - break + } else /* seqPos > 0 */ { + if v, ok := escs[buf[0]]; ok { + v(dst) + nr-- + buf = buf[1:] } + seqPos = 0 + } + + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break } } if er != nil { diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index 065dc5c..26c4678 100755 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -249,7 +249,7 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, // pkg io/Copy expects EOF so normally this will // exit with inerr == nil - _, inerr := hkexnet.Copy(os.Stdout, conn) + _, inerr := io.Copy(os.Stdout, conn) if inerr != nil { _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // #nosec // Copy operations and user logging off will cause @@ -288,7 +288,7 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, _, outerr := func(conn *hkexnet.Conn, r io.Reader) (w int64, e error) { // Copy() expects EOF so this will // exit with outerr == nil - w, e = io.Copy(conn, r) + w, e = hkexnet.Copy(conn, r) return w, e }(conn, os.Stdin)