diff --git a/src/liblsquic/lsquic_enc_sess_ietf.c b/src/liblsquic/lsquic_enc_sess_ietf.c index 21c2063..2e39941 100644 --- a/src/liblsquic/lsquic_enc_sess_ietf.c +++ b/src/liblsquic/lsquic_enc_sess_ietf.c @@ -115,11 +115,15 @@ struct header_prot const EVP_CIPHER *hp_cipher; gen_hp_mask_f hp_gen_mask; enum enc_level hp_enc_level; + enum { + HP_CAN_READ = 1 << 0, + HP_CAN_WRITE = 1 << 1, + } hp_flags; unsigned hp_sz; unsigned char hp_buf[2][EVP_MAX_KEY_LENGTH]; }; -#define header_prot_inited(hp_) ((hp_)->hp_sz > 0) +#define header_prot_inited(hp_, rw_) ((hp_)->hp_flags & (1 << (rw_))) struct crypto_ctx @@ -177,12 +181,11 @@ derive_hp_secrets (struct header_prot *hp, const EVP_MD *md, const unsigned char *client_secret, const unsigned char *server_secret) { hp->hp_sz = EVP_AEAD_key_length(aead); - if (client_secret) - lsquic_qhkdf_expand(md, client_secret, secret_sz, PN_LABEL, PN_LABEL_SZ, - hp->hp_buf[0], hp->hp_sz); - if (server_secret) - lsquic_qhkdf_expand(md, server_secret, secret_sz, PN_LABEL, PN_LABEL_SZ, - hp->hp_buf[1], hp->hp_sz); + hp->hp_flags = HP_CAN_READ | HP_CAN_WRITE; + lsquic_qhkdf_expand(md, client_secret, secret_sz, PN_LABEL, PN_LABEL_SZ, + hp->hp_buf[0], hp->hp_sz); + lsquic_qhkdf_expand(md, server_secret, secret_sz, PN_LABEL, PN_LABEL_SZ, + hp->hp_buf[1], hp->hp_sz); } @@ -1914,7 +1917,7 @@ iquic_esf_decrypt_packet (enc_session_t *enc_session_p, else hp = NULL; - if (UNLIKELY(!(hp && header_prot_inited(hp)))) + if (UNLIKELY(!(hp && header_prot_inited(hp, 0)))) { LSQ_DEBUG("header protection for level %u not initialized yet", enc_level); @@ -1944,7 +1947,11 @@ iquic_esf_decrypt_packet (enc_session_t *enc_session_p, key_phase = (dst[0] & 0x04) > 0; pair = &enc_sess->esi_pairs[ key_phase ]; if (key_phase == enc_sess->esi_key_phase) + { crypto_ctx = &pair->ykp_ctx[ 0 ]; + /* Checked by header_prot_inited() above */ + assert(crypto_ctx->yk_flags & YK_INITED); + } else if (!is_valid_packno( enc_sess->esi_pairs[enc_sess->esi_key_phase].ykp_thresh) || packet_in->pi_packno @@ -2526,6 +2533,7 @@ set_secret (SSL *ssl, enum ssl_encryption_level_t level, } lsquic_qhkdf_expand(crypa.md, secret, secret_len, PN_LABEL, PN_LABEL_SZ, hp->hp_buf[rw], hp->hp_sz); + hp->hp_flags |= 1 << rw; if (enc_sess->esi_flags & ESI_LOG_SECRETS) {