-Added error checking for all stages of hkex.Conn.Accept() and GetStream()

-Server will log such errors without panic/exit
-Const added but not yet used for 'chaff' packets
This commit is contained in:
Russ Magee 2018-04-28 16:05:33 -07:00
parent c56d4d9ad9
commit 50f0433579
5 changed files with 132 additions and 110 deletions

View file

@ -15,11 +15,11 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"hash" "hash"
"log" "log"
"math/big" "math/big"
"os"
"golang.org/x/crypto/blowfish" "golang.org/x/crypto/blowfish"
"golang.org/x/crypto/twofish" "golang.org/x/crypto/twofish"
@ -46,12 +46,11 @@ const (
/* Support functionality to set up encryption after a channel has /* Support functionality to set up encryption after a channel has
been negotiated via hkexnet.go been negotiated via hkexnet.go
*/ */
func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) { func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash, err error) {
var key []byte var key []byte
var block cipher.Block var block cipher.Block
var iv []byte var iv []byte
var ivlen int var ivlen int
var err error
copts := hc.cipheropts & 0xFF copts := hc.cipheropts & 0xFF
// TODO: each cipher alg case should ensure len(keymat.Bytes()) // TODO: each cipher alg case should ensure len(keymat.Bytes())
@ -93,7 +92,8 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default: default:
log.Printf("[invalid cipher (%d)]\n", copts) log.Printf("[invalid cipher (%d)]\n", copts)
fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts) fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts)
os.Exit(1) err = errors.New("hkexchan: INVALID CIPHER ALG")
//os.Exit(1)
} }
hopts := (hc.cipheropts >> 8) & 0xFF hopts := (hc.cipheropts >> 8) & 0xFF
@ -109,20 +109,19 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default: default:
log.Printf("[invalid hmac (%d)]\n", hopts) log.Printf("[invalid hmac (%d)]\n", hopts)
fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts) fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts)
os.Exit(1) err = errors.New("hkexchan: INVALID HMAC ALG")
return
//os.Exit(1)
} }
if err != nil { if err != nil {
panic(err) // Feed the IV into the hmac: all traffic in the connection must
// feed its data into the hmac afterwards, so both ends can xor
// that with the stream to detect corruption.
_, _ = mc.Write(iv)
var currentHash []byte
currentHash = mc.Sum(currentHash)
log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash))
} }
// Feed the IV into the hmac: all traffic in the connection must
// feed its data into the hmac afterwards, so both ends can xor
// that with the stream to detect corruption.
_, _ = mc.Write(iv)
var currentHash []byte
currentHash = mc.Sum(currentHash)
log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash))
return return
} }

View file

@ -28,6 +28,12 @@ import (
"time" "time"
) )
const (
csoNone = iota // No error, normal packet
csoHmacInvalid // HMAC mismatch detect on remote end
csoChaff // This packet is a dummy, do not process beyond decryption
)
/*---------------------------------------------------------------------*/ /*---------------------------------------------------------------------*/
// Conn is a HKex connection - a drop-in replacement for net.Conn // Conn is a HKex connection - a drop-in replacement for net.Conn
@ -149,8 +155,8 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e
hc.h.FA() hc.h.FA()
log.Printf("**(c)** FA:%s\n", hc.h.fa) log.Printf("**(c)** FA:%s\n", hc.h.fa)
hc.r, hc.rm = hc.getStream(hc.h.fa) hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa) hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return return
} }
@ -262,11 +268,13 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
// d is value for Herradura key exchange // d is value for Herradura key exchange
d := big.NewInt(0) d := big.NewInt(0)
_, err = fmt.Fscanln(c, d) _, err = fmt.Fscanln(c, d)
log.Printf("[Got d:%v]", d)
if err != nil { if err != nil {
return hc, err return hc, err
} }
_, err = fmt.Fscanf(c, "%08x:%08x\n", _, err = fmt.Fscanf(c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts) &hc.cipheropts, &hc.opts)
log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts)
if err != nil { if err != nil {
return hc, err return hc, err
} }
@ -279,8 +287,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16), fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16),
hc.cipheropts, hc.opts) hc.cipheropts, hc.opts)
hc.r, hc.rm = hc.getStream(hc.h.fa) hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa) hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return return
} }
@ -303,9 +311,9 @@ func (c Conn) Read(b []byte) (n int, err error) {
var hmacIn [4]uint8 var hmacIn [4]uint8
var payloadLen uint32 var payloadLen uint32
// Read ctrl/status opcode (for now, set nonzero on hmac mismatch) // Read ctrl/status opcode (csoHmacInvalid on hmac mismatch)
err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp) err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp)
if ctrlStatOp != 0 { if ctrlStatOp == csoHmacInvalid {
// Other side indicated channel tampering, close channel // Other side indicated channel tampering, close channel
c.Close() c.Close()
return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **")
@ -328,14 +336,19 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil { if err != nil {
if err.Error() != "EOF" { if err.Error() != "EOF" {
panic(err) panic(err)
} // else { // Cannot just return 0, err here - client won't hang up properly
// return 0, err // when 'exit' from shell. TODO: try server sending ctrlStatOp to
//} // indicate to Reader? -rlm 20180428
}
} }
if payloadLen > 16384 { if payloadLen > 16384 {
panic("Insane payloadLen") log.Printf("[Insane payloadLen:%v]\n", payloadLen)
c.Close()
return 1, errors.New("Insane payloadLen")
} }
//log.Println("payloadLen:", payloadLen) //log.Println("payloadLen:", payloadLen)
var payloadBytes = make([]byte, payloadLen) var payloadBytes = make([]byte, payloadLen)
n, err = io.ReadFull(c.c, payloadBytes) n, err = io.ReadFull(c.c, payloadBytes)
//log.Print(" << Read ", n, " payloadBytes") //log.Print(" << Read ", n, " payloadBytes")
@ -364,8 +377,14 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
c.dBuf.Write(payloadBytes)
//log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes())) // Throw away pkt if it's chaff (ie., caller to Read() won't see this data)
if ctrlStatOp == csoChaff {
log.Printf("[Chaff pkt]\n")
} else {
c.dBuf.Write(payloadBytes)
//log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes()))
}
// Re-calculate hmac, compare with received value // Re-calculate hmac, compare with received value
c.rm.Write(payloadBytes) c.rm.Write(payloadBytes)
@ -374,10 +393,11 @@ func (c Conn) Read(b []byte) (n int, err error) {
// Log alert if hmac didn't match, corrupted channel // Log alert if hmac didn't match, corrupted channel
if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ {
fmt.Println("** ALERT - hmac mismatch, possible channel tampering **") fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **")
_, _ = c.c.Write([]byte{0x1}) _, _ = c.c.Write([]byte{csoHmacInvalid})
} }
} }
retN := c.dBuf.Len() retN := c.dBuf.Len()
if retN > len(b) { if retN > len(b) {
retN = len(b) retN = len(b)
@ -418,7 +438,7 @@ func (c Conn) Write(b []byte) (n int, err error) {
log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes()))
var ctrlStatOp byte var ctrlStatOp byte
ctrlStatOp = 0x00 ctrlStatOp = csoNone
_ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp) _ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp)
// Write hmac LSB, payloadLen followed by payload // Write hmac LSB, payloadLen followed by payload

View file

@ -19,7 +19,7 @@ import (
"os/user" "os/user"
"github.com/jameskeane/bcrypt" "github.com/jameskeane/bcrypt"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
) )
func main() { func main() {

View file

@ -18,7 +18,7 @@ import (
"strings" "strings"
"sync" "sync"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
isatty "github.com/mattn/go-isatty" isatty "github.com/mattn/go-isatty"
) )

View file

@ -18,8 +18,8 @@ import (
"os/user" "os/user"
"syscall" "syscall"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
"blitter.com/hkexsh/spinsult" "blitter.com/go/hkexsh/spinsult"
"github.com/kr/pty" "github.com/kr/pty"
) )
@ -171,90 +171,93 @@ func main() {
// Wait for a connection. // Wait for a connection.
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
log.Fatal(err) log.Printf("Accept() got error(%v), hanging up.\n", err)
} conn.Close()
log.Println("Accepted client") //log.Fatal(err)
} else {
log.Println("Accepted client")
// Handle the connection in a new goroutine. // Handle the connection in a new goroutine.
// The loop then returns to accepting, so that // The loop then returns to accepting, so that
// multiple connections may be served concurrently. // multiple connections may be served concurrently.
go func(c hkexsh.Conn) (e error) { go func(c hkexsh.Conn) (e error) {
defer c.Close() defer c.Close()
//We use io.ReadFull() here to guarantee we consume //We use io.ReadFull() here to guarantee we consume
//just the data we want for the cmdSpec, and no more. //just the data we want for the cmdSpec, and no more.
//Otherwise data will be sitting in the channel that isn't //Otherwise data will be sitting in the channel that isn't
//passed down to the command handlers. //passed down to the command handlers.
var rec cmdSpec var rec cmdSpec
var len1, len2, len3, len4 uint32 var len1, len2, len3, len4 uint32
n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4) n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4)
log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4) log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4)
if err != nil || n < 4 { if err != nil || n < 4 {
log.Println("[Bad cmdSpec fmt]") log.Println("[Bad cmdSpec fmt]")
return err return err
} }
//fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4) //fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4)
rec.op = make([]byte, len1, len1) rec.op = make([]byte, len1, len1)
_, err = io.ReadFull(c, rec.op) _, err = io.ReadFull(c, rec.op)
if err != nil { if err != nil {
log.Println("[Bad cmdSpec.op]") log.Println("[Bad cmdSpec.op]")
return err return err
} }
rec.who = make([]byte, len2, len2) rec.who = make([]byte, len2, len2)
_, err = io.ReadFull(c, rec.who) _, err = io.ReadFull(c, rec.who)
if err != nil { if err != nil {
log.Println("[Bad cmdSpec.who]") log.Println("[Bad cmdSpec.who]")
return err return err
} }
rec.cmd = make([]byte, len3, len3) rec.cmd = make([]byte, len3, len3)
_, err = io.ReadFull(c, rec.cmd) _, err = io.ReadFull(c, rec.cmd)
if err != nil { if err != nil {
log.Println("[Bad cmdSpec.cmd]") log.Println("[Bad cmdSpec.cmd]")
return err return err
} }
rec.authCookie = make([]byte, len4, len4) rec.authCookie = make([]byte, len4, len4)
_, err = io.ReadFull(c, rec.authCookie) _, err = io.ReadFull(c, rec.authCookie)
if err != nil { if err != nil {
log.Println("[Bad cmdSpec.authCookie]") log.Println("[Bad cmdSpec.authCookie]")
return err return err
} }
log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n", log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n",
rec.op[0], string(rec.who), string(rec.cmd)) rec.op[0], string(rec.who), string(rec.cmd))
valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd") valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd")
if !valid { if !valid {
log.Println("Invalid user", string(rec.who)) log.Println("Invalid user", string(rec.who))
c.Write([]byte(rejectUserMsg())) c.Write([]byte(rejectUserMsg()))
return
}
log.Printf("[allowedCmds:%s]\n", allowedCmds)
if rec.op[0] == 'c' {
// Non-interactive command
log.Println("[Running command]")
runShellAs(string(rec.who), string(rec.cmd), false, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Command complete]")
} else if rec.op[0] == 's' {
log.Println("[Running shell]")
runShellAs(string(rec.who), string(rec.cmd), true, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Exiting shell]")
} else {
log.Println("[Bad cmdSpec]")
}
return return
} }(conn)
log.Printf("[allowedCmds:%s]\n", allowedCmds) } // Accept() success
if rec.op[0] == 'c' {
// Non-interactive command
log.Println("[Running command]")
runShellAs(string(rec.who), string(rec.cmd), false, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Command complete]")
} else if rec.op[0] == 's' {
log.Println("[Running shell]")
runShellAs(string(rec.who), string(rec.cmd), true, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Exiting shell]")
} else {
log.Println("[Bad cmdSpec]")
}
return
}(conn)
} //endfor } //endfor
log.Println("[Exiting]") log.Println("[Exiting]")
} }