Removed channel-based server loop goroutine, solving eaten initial byte issue.

Made receivers on hkex.Conn mutators *Conn again (whoops)
TODO: Padding in ciphertext data!
This commit is contained in:
Russ Magee 2018-01-20 20:37:27 -08:00
parent 732005d9bf
commit 3efdd5cfbd
5 changed files with 110 additions and 138 deletions

View File

@ -14,6 +14,14 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type cmdSpec struct {
op []byte
who []byte
cmd []byte
authCookie []byte
status int
}
// Demo of a simple client that dials up to a simple test server to // Demo of a simple client that dials up to a simple test server to
// send data. // send data.
// //
@ -39,6 +47,7 @@ func main() {
flag.StringVar(&server, "s", "localhost:2000", "server hostname/address[:port]") flag.StringVar(&server, "s", "localhost:2000", "server hostname/address[:port]")
flag.Parse() flag.Parse()
//log.SetOutput(os.Stdout)
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
conn, err := hkex.Dial("tcp", server, cAlg, hAlg) conn, err := hkex.Dial("tcp", server, cAlg, hAlg)
@ -60,6 +69,19 @@ func main() {
fmt.Println("NOT A TTY") fmt.Println("NOT A TTY")
} }
rec := &cmdSpec{op: []byte{'s'},
who: []byte("ABCD"),
cmd: []byte("EFGH"),
authCookie: []byte("99"),
status: 0}
_, err = fmt.Fprintf(conn, "%d %d %d %d\n", len(rec.op), len(rec.who), len(rec.cmd), len(rec.authCookie))
_, err = conn.Write(rec.op)
_, err = conn.Write(rec.who)
_, err = conn.Write(rec.cmd)
_, err = conn.Write(rec.authCookie)
//client reader (from server) goroutine
wg.Add(1) wg.Add(1)
go func() { go func() {
// By deferring a call to wg.Done(), // By deferring a call to wg.Done(),
@ -82,11 +104,12 @@ func main() {
} }
} }
if isInteractive { if isInteractive {
log.Println("[Got Write EOF]") log.Println("[Got EOF]")
wg.Done() // client hanging up, close WaitGroup to exit client wg.Done() // server hung up, close WaitGroup to exit client
} }
}() }()
// client writer (to server) goroutine
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -100,8 +123,8 @@ func main() {
os.Exit(2) os.Exit(2)
} }
} }
log.Println("[Got Read EOF]") log.Println("[Sent EOF]")
wg.Done() // server hung up, close WaitGroup to exit client wg.Done() // client hung up, close WaitGroup to exit client
}() }()
// Wait until both stdin and stdout goroutines finish // Wait until both stdin and stdout goroutines finish

View File

@ -62,7 +62,7 @@ func main() {
} }
}(ch, eCh) }(ch, eCh)
ticker := time.Tick(time.Second/100) ticker := time.Tick(time.Second / 100)
Term: Term:
// continuously read from the connection // continuously read from the connection
for { for {

View File

@ -4,13 +4,12 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"os"
"os/exec" "os/exec"
"os/user" "os/user"
"strings" "strings"
"syscall" "syscall"
"time"
hkex "blitter.com/herradurakex" hkex "blitter.com/herradurakex"
"github.com/kr/pty" "github.com/kr/pty"
@ -32,13 +31,13 @@ const (
OpX = 'x' // exec OpX = 'x' // exec
) )
type Op uint8 //type Op uint8
type cmdRunner struct { type cmdSpec struct {
op Op op []byte
who string who []byte
arg string cmd []byte
authCookie string authCookie []byte
status int status int
} }
@ -95,7 +94,7 @@ func main() {
flag.StringVar(&laddr, "l", ":2000", "interface[:port] to listen") flag.StringVar(&laddr, "l", ":2000", "interface[:port] to listen")
flag.Parse() flag.Parse()
log.SetOutput(ioutil.Discard) log.SetOutput(os.Stdout /*ioutil.Discard*/)
// Listen on TCP port 2000 on all available unicast and // Listen on TCP port 2000 on all available unicast and
// anycast IP addresses of the local system. // anycast IP addresses of the local system.
@ -119,85 +118,63 @@ func main() {
// multiple connections may be served concurrently. // multiple connections may be served concurrently.
go func(c hkex.Conn) (e error) { go func(c hkex.Conn) (e error) {
defer c.Close() defer c.Close()
var connOp *byte = nil
ch := make(chan []byte)
chN := 0
eCh := make(chan error)
// Start a goroutine to read from our net connection //We use io.ReadFull() here to guarantee we consume
go func(ch chan []byte, eCh chan error) { //just the data we want for the cmdSpec, and no more.
for { //Otherwise data will be sitting in the channel that isn't
// try to read the data //passed down to the command handlers.
data := make([]byte, 512) var rec cmdSpec
chN, err = c.Read(data) var len1, len2, len3, len4 uint32
if err != nil {
// send an error if it's encountered
eCh <- err
return
}
// send data if we read some.
ch <- data[0:chN]
}
}(ch, eCh)
ticker := time.Tick(time.Second / 100) n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4)
Term: if err != nil || n < 4 {
// continuously read from the connection fmt.Println("[Bad cmdSpec fmt]")
for { return err
select { }
// This case means we recieved data on the connection fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4)
case data := <-ch:
// Do something with the data rec.op = make([]byte, len1, len1)
fmt.Printf("Client sent %+v\n", data[0:chN]) _, err = io.ReadFull(c, rec.op)
if connOp == nil { if err != nil {
// Initial xmit - get op byte fmt.Println("[Bad cmdSpec.op]")
// Have op here and first block of data[] return err
connOp = new(byte) }
*connOp = data[0] rec.who = make([]byte, len2, len2)
fmt.Printf("[* connOp '%c']\n", *connOp) _, err = io.ReadFull(c, rec.who)
} if err != nil {
if len(data) > 1 { fmt.Println("[Bad cmdSpec.who]")
data = data[1:chN] return err
chN -= 1 }
}
rec.cmd = make([]byte, len3, len3)
if len(data) > 0 { _, err = io.ReadFull(c, rec.cmd)
// From here, one could pass all subsequent data if err != nil {
// between client/server attached to an exec.Cmd, fmt.Println("[Bad cmdSpec.cmd]")
// as data to/from a file, etc. return err
if connOp != nil && *connOp == 's' { }
fmt.Println("[Running shell]")
runCmdAs("larissa", "bash -l -i", conn) rec.authCookie = make([]byte, len4, len4)
// Returned hopefully via an EOF or exit/logout; _, err = io.ReadFull(c, rec.authCookie)
// Clear current op so user can enter next, or EOF if err != nil {
connOp = nil fmt.Println("[Bad cmdSpec.authCookie]")
fmt.Println("[Exiting shell]") return err
conn.Close() }
}
if strings.Trim(string(data), "\r\n") == "exit" { fmt.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:%s]\n",
conn.Close() rec.op[0], string(rec.who), string(rec.cmd), string(rec.authCookie))
}
} if rec.op[0] == 's' {
//fmt.Printf("Client sent %s\n", string(data)) fmt.Println("[Running shell]")
// This case means we got an error and the goroutine has finished runCmdAs("larissa", "bash -l -i", conn)
case err := <-eCh: // Returned hopefully via an EOF or exit/logout;
// handle our error then exit for loop // Clear current op so user can enter next, or EOF
if err.Error() == "EOF" { rec.op[0] = 0
fmt.Printf("[Client disconnected]\n") fmt.Println("[Exiting shell]")
} else { } else {
fmt.Printf("Error reading client data! (%+v)\n", err) fmt.Println("[Bad cmdSpec]")
}
break Term
// This will timeout on the read.
case <-ticker:
// do nothing? this is just so we can time out if we need to.
// you probably don't even need to have this here unless you want
// do something specifically on the timeout.
}
} }
// Shut down the connection.
//c.Close()
return return
}(conn) }(conn)
} } //endfor
fmt.Println("[Exiting]")
} }

View File

@ -8,6 +8,7 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"fmt" "fmt"
"log"
"math/big" "math/big"
"os" "os"
@ -29,16 +30,6 @@ const (
HmacNoneDisallowed HmacNoneDisallowed
) )
type ChanOp uint8
const (
ChanOpNop = '.'
ChanOpEcho = 'e' // For testing - echo client data to stderr
//ChanOpFileWrite = "w"
//ChanOpFileRead = "r"
//ChanOpRemoteCmd = "x"
)
/*TODO: HMAC derived from HKEx FA.*/ /*TODO: HMAC derived from HKEx FA.*/
/* Support functionality to set up encryption after a channel has /* Support functionality to set up encryption after a channel has
been negotiated via hkexnet.go been negotiated via hkexnet.go
@ -59,7 +50,7 @@ func (hc Conn) getStream(keymat *big.Int) (ret cipher.Stream) {
ivlen = aes.BlockSize ivlen = aes.BlockSize
iv := keymat.Bytes()[aes.BlockSize : aes.BlockSize+ivlen] iv := keymat.Bytes()[aes.BlockSize : aes.BlockSize+ivlen]
ret = cipher.NewOFB(block, iv) ret = cipher.NewOFB(block, iv)
fmt.Printf("[cipher AES_256 (%d)]\n", copts) log.Printf("[cipher AES_256 (%d)]\n", copts)
break break
case CAlgTwofish128: case CAlgTwofish128:
key = keymat.Bytes()[0:twofish.BlockSize] key = keymat.Bytes()[0:twofish.BlockSize]
@ -67,7 +58,7 @@ func (hc Conn) getStream(keymat *big.Int) (ret cipher.Stream) {
ivlen = twofish.BlockSize ivlen = twofish.BlockSize
iv := keymat.Bytes()[twofish.BlockSize : twofish.BlockSize+ivlen] iv := keymat.Bytes()[twofish.BlockSize : twofish.BlockSize+ivlen]
ret = cipher.NewOFB(block, iv) ret = cipher.NewOFB(block, iv)
fmt.Printf("[cipher TWOFISH_128 (%d)]\n", copts) log.Printf("[cipher TWOFISH_128 (%d)]\n", copts)
break break
case CAlgBlowfish64: case CAlgBlowfish64:
key = keymat.Bytes()[0:blowfish.BlockSize] key = keymat.Bytes()[0:blowfish.BlockSize]
@ -84,9 +75,10 @@ func (hc Conn) getStream(keymat *big.Int) (ret cipher.Stream) {
// copy what's needed whereas blowfish does no such check. // copy what's needed whereas blowfish does no such check.
iv := keymat.Bytes()[blowfish.BlockSize : blowfish.BlockSize+ivlen] iv := keymat.Bytes()[blowfish.BlockSize : blowfish.BlockSize+ivlen]
ret = cipher.NewOFB(block, iv) ret = cipher.NewOFB(block, iv)
fmt.Printf("[cipher BLOWFISH_64 (%d)]\n", copts) log.Printf("[cipher BLOWFISH_64 (%d)]\n", copts)
break break
default: default:
log.Printf("[invalid cipher (%d)]\n", copts)
fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts) fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts)
os.Exit(1) os.Exit(1)
} }
@ -94,9 +86,10 @@ func (hc Conn) getStream(keymat *big.Int) (ret cipher.Stream) {
hopts := (hc.cipheropts >> 8) & 0xFF hopts := (hc.cipheropts >> 8) & 0xFF
switch hopts { switch hopts {
case HmacSHA256: case HmacSHA256:
fmt.Printf("[nop HmacSHA256 (%d)]\n", hopts) log.Printf("[nop HmacSHA256 (%d)]\n", hopts)
break break
default: default:
log.Printf("[invalid hmac (%d)]\n", hopts)
fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts) fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts)
os.Exit(1) os.Exit(1)
} }

View File

@ -39,7 +39,6 @@ type Conn struct {
h *HerraduraKEx h *HerraduraKEx
cipheropts uint32 // post-KEx cipher/hmac options cipheropts uint32 // post-KEx cipher/hmac options
opts uint32 // post-KEx protocol options (caller-defined) opts uint32 // post-KEx protocol options (caller-defined)
op uint8 // post-KEx 'op' (caller-defined)
r cipher.Stream r cipher.Stream
w cipher.Stream w cipher.Stream
} }
@ -57,7 +56,7 @@ func (c Conn) ConnOpts() uint32 {
// peer as part of KEx but not part of the KEx itself. // peer as part of KEx but not part of the KEx itself.
// //
// opts - bitfields for cipher and hmac alg. to use after KEx // opts - bitfields for cipher and hmac alg. to use after KEx
func (c Conn) SetConnOpts(copts uint32) { func (c *Conn) SetConnOpts(copts uint32) {
c.cipheropts = copts c.cipheropts = copts
} }
@ -77,31 +76,11 @@ func (c Conn) Opts() uint32 {
// of the KEx of encryption info used by the connection. // of the KEx of encryption info used by the connection.
// //
// opts - a uint32, caller-defined // opts - a uint32, caller-defined
func (c Conn) SetOpts(opts uint32) { func (c *Conn) SetOpts(opts uint32) {
c.opts = opts c.opts = opts
} }
// Op returns the 'op' value, which is sent to the peer func (c *Conn) applyConnExtensions(extensions ...string) {
// but is not itself part of the KEx or connection (cipher/hmac) setup.
//
// Consumers of this lib may use this to indicate connection-specific
// operations not part of the KEx or encryption info used by the connection.
func (c Conn) Op() uint8 {
return c.op
}
// SetOp sets the 'op' value, which is sent to the peer
// but is not itself part of the KEx or connection (cipher/hmac) setup.
//
// Consumers of this lib may use this to indicate connection-specific
// operations not part of the KEx or encryption info used by the connection.
//
// op - a uint8, caller-defined
func (c Conn) SetOp(op uint8) {
c.op = op
}
func (c Conn) applyConnExtensions(extensions ...string) {
for _, s := range extensions { for _, s := range extensions {
switch s { switch s {
case "C_AES_256": case "C_AES_256":
@ -143,20 +122,20 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e
if err != nil { if err != nil {
return nil, err return nil, err
} }
hc = &Conn{c: c, h: New(0, 0), cipheropts: 0, opts: 0, op: 0, r: nil, w: nil} hc = &Conn{c: c, h: New(0, 0), cipheropts: 0, opts: 0, r: nil, w: nil}
hc.applyConnExtensions(extensions...) hc.applyConnExtensions(extensions...)
fmt.Fprintf(c, "0x%s\n%08x:%08x:%02x\n", hc.h.d.Text(16), fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16),
hc.cipheropts, hc.opts, hc.op) hc.cipheropts, hc.opts)
d := big.NewInt(0) d := big.NewInt(0)
_, err = fmt.Fscanln(c, d) _, err = fmt.Fscanln(c, d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = fmt.Fscanf(c, "%08x:%08x:%02x\n", _, err = fmt.Fscanf(c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts, &hc.op) &hc.cipheropts, &hc.opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,15 +244,15 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
} }
log.Println("[Accepted]") log.Println("[Accepted]")
hc = Conn{c: c, h: New(0, 0), cipheropts: 0, opts: 0, op: 0, r: nil, w: nil} hc = Conn{c: c, h: New(0, 0), cipheropts: 0, opts: 0, r: nil, w: nil}
d := big.NewInt(0) d := big.NewInt(0)
_, err = fmt.Fscanln(c, d) _, err = fmt.Fscanln(c, d)
if err != nil { if err != nil {
return hc, err return hc, err
} }
_, err = fmt.Fscanf(c, "%08x:%08x:%02x\n", _, err = fmt.Fscanf(c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts, &hc.op) &hc.cipheropts, &hc.opts)
if err != nil { if err != nil {
return hc, err return hc, err
} }
@ -283,8 +262,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
hc.h.FA() hc.h.FA()
log.Printf("**(s)** FA:%s\n", hc.h.fa) log.Printf("**(s)** FA:%s\n", hc.h.fa)
fmt.Fprintf(c, "0x%s\n%08x:%08x:%02x\n", hc.h.d.Text(16), fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16),
hc.cipheropts, hc.opts, hc.op) hc.cipheropts, hc.opts)
hc.r = hc.getStream(hc.h.fa) hc.r = hc.getStream(hc.h.fa)
hc.w = hc.getStream(hc.h.fa) hc.w = hc.getStream(hc.h.fa)