From b45784e07b0fd05f45b1e6d4ca5da580c7b9471b Mon Sep 17 00:00:00 2001 From: Russ Magee Date: Sun, 15 Apr 2018 12:58:24 -0700 Subject: [PATCH] Minimal hmac channel verification w/close on tampering --- hkexnet.go | 34 ++++++++++++++++------------------ hkexsh/hkexsh.go | 5 ++++- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/hkexnet.go b/hkexnet.go index 63a318b..0405a86 100644 --- a/hkexnet.go +++ b/hkexnet.go @@ -298,10 +298,10 @@ func (c Conn) Read(b []byte) (n int, err error) { break } - var hmacIn uint8 + var hmacIn [4]uint8 var payloadLen uint32 - // Read the hmac LSB and payload len first + // Read the hmac and payload len first err = binary.Read(c.c, binary.BigEndian, &hmacIn) // Normal client 'exit' from interactive session will cause // (on server side) err.Error() == ": use of closed network connection" @@ -314,14 +314,6 @@ func (c Conn) Read(b []byte) (n int, err error) { return 0, err } - //if err != nil { - // if err.Error() != "EOF" { - // log.Println("Error was:", err.Error()) - // } else { - // return 0, err - // } - //} - err = binary.Read(c.c, binary.BigEndian, &payloadLen) if err != nil { if err.Error() != "EOF" { @@ -348,7 +340,7 @@ func (c Conn) Read(b []byte) (n int, err error) { } } - log.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) //EncodeToString(b[:n])) // print only used portion + log.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) db := bytes.NewBuffer(payloadBytes[:n]) //copying payloadBytes to db // The StreamReader acts like a pipe, decrypting @@ -358,7 +350,7 @@ func (c Conn) Read(b []byte) (n int, err error) { // The caller isn't necessarily reading the full payload so we need // to decrypt ot an intermediate buffer, draining it on demand of caller decryptN, err := rs.Read(payloadBytes) - log.Printf(" <-ptext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) //EncodeToString(b[:n])) + log.Printf(" <-ptext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) if err != nil { panic(err) } @@ -367,8 +359,14 @@ func (c Conn) Read(b []byte) (n int, err error) { // Re-calculate hmac, compare with received value c.rm.Write(payloadBytes) - hTmp := c.rm.Sum(nil)[0] - log.Printf("<%04x) HMAC:(i)%02x (c)%02x\r\n", decryptN, hmacIn, hTmp) + hTmp := c.rm.Sum(nil)[0:4] + log.Printf("<%04x) HMAC:(i)%s (c)%02x\r\n", decryptN, hex.EncodeToString([]byte(hmacIn[0:])), hTmp) + + // Puke if hmac didn't match, corrupted channel + if !bytes.Equal(hTmp, []byte(hmacIn[0:])) || hmacIn[0] > 0xf8 { + fmt.Println("** ALERT - hmac mismatch, possible channel tampering **") + c.Close() + } } retN := c.dBuf.Len() if retN > len(b) { @@ -386,18 +384,18 @@ func (c Conn) Read(b []byte) (n int, err error) { // See go doc io.Writer func (c Conn) Write(b []byte) (n int, err error) { //log.Printf("[Encrypting...]\r\n") - var hmacOut uint8 + var hmacOut []uint8 var payloadLen uint32 - log.Printf(" :>ptext:\r\n%s\r\n", hex.Dump(b)) //EncodeToString(b)) + log.Printf(" :>ptext:\r\n%s\r\n", hex.Dump(b)) payloadLen = uint32(len(b)) // Calculate hmac on payload c.wm.Write(b) - hmacOut = uint8(c.wm.Sum(nil)[0]) + hmacOut = c.wm.Sum(nil)[0:4] - log.Printf(" (%04x> HMAC(o):%02x\r\n", payloadLen, hmacOut) + log.Printf(" (%04x> HMAC(o):%s\r\n", payloadLen, hex.EncodeToString(hmacOut)) var wb bytes.Buffer // The StreamWriter acts like a pipe, forwarding whatever is diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index cd969b6..e0657d7 100644 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -80,8 +80,9 @@ func main() { // Set stdin in raw mode if it's an interactive session // TODO: send flag to server side indicating this // affects shell command used + var oldState *hkexsh.State if isatty.IsTerminal(os.Stdin.Fd()) { - oldState, err := hkexsh.MakeRaw(int(os.Stdin.Fd())) + oldState, err = hkexsh.MakeRaw(int(os.Stdin.Fd())) if err != nil { panic(err) } @@ -155,6 +156,7 @@ func main() { if inerr != nil { if inerr.Error() != "EOF" { fmt.Println(inerr) + _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // Best effort. os.Exit(1) } } @@ -178,6 +180,7 @@ func main() { log.Println(outerr) if outerr.Error() != "EOF" { fmt.Println(outerr) + _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // Best effort. os.Exit(2) } }