aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/net/wireguard/noise.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/noise.c')
-rw-r--r--drivers/net/wireguard/noise.c22
1 files changed, 9 insertions, 13 deletions
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 708dc61c974f..626433690abb 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer)
if (unlikely(!keypair))
return NULL;
+ spin_lock_init(&keypair->receiving_counter.lock);
keypair->internal_id = atomic64_inc_return(&keypair_counter);
keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
keypair->entry.peer = peer;
@@ -358,25 +359,16 @@ out:
memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
-static void symmetric_key_init(struct noise_symmetric_key *key)
-{
- spin_lock_init(&key->counter.receive.lock);
- atomic64_set(&key->counter.counter, 0);
- memset(key->counter.receive.backtrack, 0,
- sizeof(key->counter.receive.backtrack));
- key->birthdate = ktime_get_coarse_boottime_ns();
- key->is_valid = true;
-}
-
static void derive_keys(struct noise_symmetric_key *first_dst,
struct noise_symmetric_key *second_dst,
const u8 chaining_key[NOISE_HASH_LEN])
{
+ u64 birthdate = ktime_get_coarse_boottime_ns();
kdf(first_dst->key, second_dst->key, NULL, NULL,
NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
chaining_key);
- symmetric_key_init(first_dst);
- symmetric_key_init(second_dst);
+ first_dst->birthdate = second_dst->birthdate = birthdate;
+ first_dst->is_valid = second_dst->is_valid = true;
}
static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
@@ -715,6 +707,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
u8 e[NOISE_PUBLIC_KEY_LEN];
u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
u8 static_private[NOISE_PUBLIC_KEY_LEN];
+ u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
down_read(&wg->static_identity.lock);
@@ -733,6 +726,8 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
memcpy(ephemeral_private, handshake->ephemeral_private,
NOISE_PUBLIC_KEY_LEN);
+ memcpy(preshared_key, handshake->preshared_key,
+ NOISE_SYMMETRIC_KEY_LEN);
up_read(&handshake->lock);
if (state != HANDSHAKE_CREATED_INITIATION)
@@ -750,7 +745,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
goto fail;
/* psk */
- mix_psk(chaining_key, hash, key, handshake->preshared_key);
+ mix_psk(chaining_key, hash, key, preshared_key);
/* {} */
if (!message_decrypt(NULL, src->encrypted_nothing,
@@ -783,6 +778,7 @@ out:
memzero_explicit(chaining_key, NOISE_HASH_LEN);
memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
+ memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
up_read(&wg->static_identity.lock);
return ret_peer;
}