diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index e678e07..35387cc 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -50,7 +50,8 @@ import ( ) /*---------------------------------------------------------------------*/ -const PAD_SZ = 32 +const PAD_SZ = 32 // max size of padding applied to each packet +const HMAC_CHK_SZ = 4 // leading bytes of HMAC to xmit for verification type ( WinSize struct { @@ -145,31 +146,24 @@ func getkexalgnum(extensions ...string) (k KEXAlg) { for _, s := range extensions { switch s { case "KEX_HERRADURA256": - log.Println("[extension arg = KEX_HERRADURA256]") k = KEX_HERRADURA256 break //out of for case "KEX_HERRADURA512": - log.Println("[extension arg = KEX_HERRADURA512]") k = KEX_HERRADURA512 break //out of for case "KEX_HERRADURA1024": - log.Println("[extension arg = KEX_HERRADURA1024]") k = KEX_HERRADURA1024 break //out of for case "KEX_HERRADURA2048": - log.Println("[extension arg = KEX_HERRADURA2048]") k = KEX_HERRADURA2048 break //out of for case "KEX_KYBER512": - log.Println("[extension arg = KEX_KYBER512]") k = KEX_KYBER512 break //out of for case "KEX_KYBER768": - log.Println("[extension arg = KEX_KYBER768]") k = KEX_KYBER768 break //out of for case "KEX_KYBER1024": - log.Println("[extension arg = KEX_KYBER1024]") k = KEX_KYBER1024 break //out of for } @@ -208,6 +202,7 @@ func _new(kexAlg KEXAlg, conn *net.Conn) (hc *Conn, e error) { case KEX_KYBER1024: log.Printf("[KEx alg %d accepted]\n", kexAlg) default: + // UNREACHABLE: _getkexalgnum() guarantees a valid KEX value hc.kex = KEX_HERRADURA256 log.Printf("[KEx alg %d ?? defaults to %d]\n", kexAlg, hc.kex) } @@ -665,7 +660,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { } var ctrlStatOp uint8 - var hmacIn [4]uint8 + var hmacIn [HMAC_CHK_SZ]uint8 var payloadLen uint32 // Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch) @@ -766,7 +761,7 @@ func (hc Conn) Read(b []byte) (n int, err error) { //log.Printf("hc.dBuf: %s\n", hex.Dump(hc.dBuf.Bytes())) } - hTmp := hc.rm.Sum(nil)[0:4] + hTmp := hc.rm.Sum(nil)[0:HMAC_CHK_SZ] log.Printf("<%04x) HMAC:(i)%s (c)%02x\r\n", decryptN, hex.EncodeToString([]byte(hmacIn[0:])), hTmp) if *hc.closeStat == CSETruncCSO { @@ -849,7 +844,7 @@ func (hc *Conn) WritePacket(b []byte, op byte) (n int, err error) { // Calculate hmac on payload hc.wm.Write(b[0:payloadLen]) - hmacOut = hc.wm.Sum(nil)[0:4] + hmacOut = hc.wm.Sum(nil)[0:HMAC_CHK_SZ] log.Printf(" (%04x> HMAC(o):%s\r\n", payloadLen, hex.EncodeToString(hmacOut))