-Added error checking for all stages of hkex.Conn.Accept() and GetStream()

-Server will log such errors without panic/exit
-Const added but not yet used for 'chaff' packets
This commit is contained in:
Russ Magee 2018-04-28 16:05:33 -07:00
parent c56d4d9ad9
commit 50f0433579
5 changed files with 132 additions and 110 deletions

View file

@ -15,11 +15,11 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"errors"
"fmt"
"hash"
"log"
"math/big"
"os"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/twofish"
@ -46,12 +46,11 @@ const (
/* Support functionality to set up encryption after a channel has
been negotiated via hkexnet.go
*/
func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash, err error) {
var key []byte
var block cipher.Block
var iv []byte
var ivlen int
var err error
copts := hc.cipheropts & 0xFF
// TODO: each cipher alg case should ensure len(keymat.Bytes())
@ -93,7 +92,8 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default:
log.Printf("[invalid cipher (%d)]\n", copts)
fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts)
os.Exit(1)
err = errors.New("hkexchan: INVALID CIPHER ALG")
//os.Exit(1)
}
hopts := (hc.cipheropts >> 8) & 0xFF
@ -109,20 +109,19 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default:
log.Printf("[invalid hmac (%d)]\n", hopts)
fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts)
os.Exit(1)
err = errors.New("hkexchan: INVALID HMAC ALG")
return
//os.Exit(1)
}
if err != nil {
panic(err)
// Feed the IV into the hmac: all traffic in the connection must
// feed its data into the hmac afterwards, so both ends can xor
// that with the stream to detect corruption.
_, _ = mc.Write(iv)
var currentHash []byte
currentHash = mc.Sum(currentHash)
log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash))
}
// Feed the IV into the hmac: all traffic in the connection must
// feed its data into the hmac afterwards, so both ends can xor
// that with the stream to detect corruption.
_, _ = mc.Write(iv)
var currentHash []byte
currentHash = mc.Sum(currentHash)
log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash))
return
}

View file

@ -28,6 +28,12 @@ import (
"time"
)
const (
csoNone = iota // No error, normal packet
csoHmacInvalid // HMAC mismatch detect on remote end
csoChaff // This packet is a dummy, do not process beyond decryption
)
/*---------------------------------------------------------------------*/
// Conn is a HKex connection - a drop-in replacement for net.Conn
@ -149,8 +155,8 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e
hc.h.FA()
log.Printf("**(c)** FA:%s\n", hc.h.fa)
hc.r, hc.rm = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa)
hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return
}
@ -262,11 +268,13 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
// d is value for Herradura key exchange
d := big.NewInt(0)
_, err = fmt.Fscanln(c, d)
log.Printf("[Got d:%v]", d)
if err != nil {
return hc, err
}
_, err = fmt.Fscanf(c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts)
log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts)
if err != nil {
return hc, err
}
@ -279,8 +287,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16),
hc.cipheropts, hc.opts)
hc.r, hc.rm = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa)
hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return
}
@ -303,9 +311,9 @@ func (c Conn) Read(b []byte) (n int, err error) {
var hmacIn [4]uint8
var payloadLen uint32
// Read ctrl/status opcode (for now, set nonzero on hmac mismatch)
// Read ctrl/status opcode (csoHmacInvalid on hmac mismatch)
err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp)
if ctrlStatOp != 0 {
if ctrlStatOp == csoHmacInvalid {
// Other side indicated channel tampering, close channel
c.Close()
return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **")
@ -328,14 +336,19 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil {
if err.Error() != "EOF" {
panic(err)
} // else {
// return 0, err
//}
// Cannot just return 0, err here - client won't hang up properly
// when 'exit' from shell. TODO: try server sending ctrlStatOp to
// indicate to Reader? -rlm 20180428
}
}
if payloadLen > 16384 {
panic("Insane payloadLen")
log.Printf("[Insane payloadLen:%v]\n", payloadLen)
c.Close()
return 1, errors.New("Insane payloadLen")
}
//log.Println("payloadLen:", payloadLen)
var payloadBytes = make([]byte, payloadLen)
n, err = io.ReadFull(c.c, payloadBytes)
//log.Print(" << Read ", n, " payloadBytes")
@ -364,8 +377,14 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil {
panic(err)
}
c.dBuf.Write(payloadBytes)
//log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes()))
// Throw away pkt if it's chaff (ie., caller to Read() won't see this data)
if ctrlStatOp == csoChaff {
log.Printf("[Chaff pkt]\n")
} else {
c.dBuf.Write(payloadBytes)
//log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes()))
}
// Re-calculate hmac, compare with received value
c.rm.Write(payloadBytes)
@ -374,10 +393,11 @@ func (c Conn) Read(b []byte) (n int, err error) {
// Log alert if hmac didn't match, corrupted channel
if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ {
fmt.Println("** ALERT - hmac mismatch, possible channel tampering **")
_, _ = c.c.Write([]byte{0x1})
fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **")
_, _ = c.c.Write([]byte{csoHmacInvalid})
}
}
retN := c.dBuf.Len()
if retN > len(b) {
retN = len(b)
@ -416,11 +436,11 @@ func (c Conn) Write(b []byte) (n int, err error) {
panic(err)
}
log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes()))
var ctrlStatOp byte
ctrlStatOp = 0x00
ctrlStatOp = csoNone
_ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp)
// Write hmac LSB, payloadLen followed by payload
_ = binary.Write(c.c, binary.BigEndian, hmacOut)
_ = binary.Write(c.c, binary.BigEndian, payloadLen)

View file

@ -19,7 +19,7 @@ import (
"os/user"
"github.com/jameskeane/bcrypt"
hkexsh "blitter.com/hkexsh"
hkexsh "blitter.com/go/hkexsh"
)
func main() {

View file

@ -18,7 +18,7 @@ import (
"strings"
"sync"
hkexsh "blitter.com/hkexsh"
hkexsh "blitter.com/go/hkexsh"
isatty "github.com/mattn/go-isatty"
)

View file

@ -18,8 +18,8 @@ import (
"os/user"
"syscall"
hkexsh "blitter.com/hkexsh"
"blitter.com/hkexsh/spinsult"
hkexsh "blitter.com/go/hkexsh"
"blitter.com/go/hkexsh/spinsult"
"github.com/kr/pty"
)
@ -171,90 +171,93 @@ func main() {
// Wait for a connection.
conn, err := l.Accept()
if err != nil {
log.Fatal(err)
}
log.Println("Accepted client")
log.Printf("Accept() got error(%v), hanging up.\n", err)
conn.Close()
//log.Fatal(err)
} else {
log.Println("Accepted client")
// Handle the connection in a new goroutine.
// The loop then returns to accepting, so that
// multiple connections may be served concurrently.
go func(c hkexsh.Conn) (e error) {
defer c.Close()
// Handle the connection in a new goroutine.
// The loop then returns to accepting, so that
// multiple connections may be served concurrently.
go func(c hkexsh.Conn) (e error) {
defer c.Close()
//We use io.ReadFull() here to guarantee we consume
//just the data we want for the cmdSpec, and no more.
//Otherwise data will be sitting in the channel that isn't
//passed down to the command handlers.
var rec cmdSpec
var len1, len2, len3, len4 uint32
//We use io.ReadFull() here to guarantee we consume
//just the data we want for the cmdSpec, and no more.
//Otherwise data will be sitting in the channel that isn't
//passed down to the command handlers.
var rec cmdSpec
var len1, len2, len3, len4 uint32
n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4)
log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4)
n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4)
log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4)
if err != nil || n < 4 {
log.Println("[Bad cmdSpec fmt]")
return err
}
//fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4)
if err != nil || n < 4 {
log.Println("[Bad cmdSpec fmt]")
return err
}
//fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4)
rec.op = make([]byte, len1, len1)
_, err = io.ReadFull(c, rec.op)
if err != nil {
log.Println("[Bad cmdSpec.op]")
return err
}
rec.who = make([]byte, len2, len2)
_, err = io.ReadFull(c, rec.who)
if err != nil {
log.Println("[Bad cmdSpec.who]")
return err
}
rec.op = make([]byte, len1, len1)
_, err = io.ReadFull(c, rec.op)
if err != nil {
log.Println("[Bad cmdSpec.op]")
return err
}
rec.who = make([]byte, len2, len2)
_, err = io.ReadFull(c, rec.who)
if err != nil {
log.Println("[Bad cmdSpec.who]")
return err
}
rec.cmd = make([]byte, len3, len3)
_, err = io.ReadFull(c, rec.cmd)
if err != nil {
log.Println("[Bad cmdSpec.cmd]")
return err
}
rec.cmd = make([]byte, len3, len3)
_, err = io.ReadFull(c, rec.cmd)
if err != nil {
log.Println("[Bad cmdSpec.cmd]")
return err
}
rec.authCookie = make([]byte, len4, len4)
_, err = io.ReadFull(c, rec.authCookie)
if err != nil {
log.Println("[Bad cmdSpec.authCookie]")
return err
}
rec.authCookie = make([]byte, len4, len4)
_, err = io.ReadFull(c, rec.authCookie)
if err != nil {
log.Println("[Bad cmdSpec.authCookie]")
return err
}
log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n",
rec.op[0], string(rec.who), string(rec.cmd))
log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n",
rec.op[0], string(rec.who), string(rec.cmd))
valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd")
if !valid {
log.Println("Invalid user", string(rec.who))
c.Write([]byte(rejectUserMsg()))
valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd")
if !valid {
log.Println("Invalid user", string(rec.who))
c.Write([]byte(rejectUserMsg()))
return
}
log.Printf("[allowedCmds:%s]\n", allowedCmds)
if rec.op[0] == 'c' {
// Non-interactive command
log.Println("[Running command]")
runShellAs(string(rec.who), string(rec.cmd), false, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Command complete]")
} else if rec.op[0] == 's' {
log.Println("[Running shell]")
runShellAs(string(rec.who), string(rec.cmd), true, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Exiting shell]")
} else {
log.Println("[Bad cmdSpec]")
}
return
}
log.Printf("[allowedCmds:%s]\n", allowedCmds)
if rec.op[0] == 'c' {
// Non-interactive command
log.Println("[Running command]")
runShellAs(string(rec.who), string(rec.cmd), false, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Command complete]")
} else if rec.op[0] == 's' {
log.Println("[Running shell]")
runShellAs(string(rec.who), string(rec.cmd), true, conn)
// Returned hopefully via an EOF or exit/logout;
// Clear current op so user can enter next, or EOF
rec.op[0] = 0
log.Println("[Exiting shell]")
} else {
log.Println("[Bad cmdSpec]")
}
return
}(conn)
}(conn)
} // Accept() success
} //endfor
log.Println("[Exiting]")
}