aboutsummaryrefslogtreecommitdiffstats
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/messages.h2
-rw-r--r--drivers/net/wireguard/noise.c22
-rw-r--r--drivers/net/wireguard/noise.h14
-rw-r--r--drivers/net/wireguard/queueing.c4
-rw-r--r--drivers/net/wireguard/queueing.h10
-rw-r--r--drivers/net/wireguard/receive.c52
-rw-r--r--drivers/net/wireguard/selftest/counter.c17
-rw-r--r--drivers/net/wireguard/send.c23
-rw-r--r--drivers/net/wireguard/socket.c12
9 files changed, 81 insertions, 75 deletions
diff --git a/drivers/net/wireguard/messages.h b/drivers/net/wireguard/messages.h
index b8a7b9ce32ba..208da72673fc 100644
--- a/drivers/net/wireguard/messages.h
+++ b/drivers/net/wireguard/messages.h
@@ -32,7 +32,7 @@ enum cookie_values {
};
enum counter_values {
- COUNTER_BITS_TOTAL = 2048,
+ COUNTER_BITS_TOTAL = 8192,
COUNTER_REDUNDANT_BITS = BITS_PER_LONG,
COUNTER_WINDOW_SIZE = COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS
};
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 708dc61c974f..626433690abb 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer)
if (unlikely(!keypair))
return NULL;
+ spin_lock_init(&keypair->receiving_counter.lock);
keypair->internal_id = atomic64_inc_return(&keypair_counter);
keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
keypair->entry.peer = peer;
@@ -358,25 +359,16 @@ out:
memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
-static void symmetric_key_init(struct noise_symmetric_key *key)
-{
- spin_lock_init(&key->counter.receive.lock);
- atomic64_set(&key->counter.counter, 0);
- memset(key->counter.receive.backtrack, 0,
- sizeof(key->counter.receive.backtrack));
- key->birthdate = ktime_get_coarse_boottime_ns();
- key->is_valid = true;
-}
-
static void derive_keys(struct noise_symmetric_key *first_dst,
struct noise_symmetric_key *second_dst,
const u8 chaining_key[NOISE_HASH_LEN])
{
+ u64 birthdate = ktime_get_coarse_boottime_ns();
kdf(first_dst->key, second_dst->key, NULL, NULL,
NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
chaining_key);
- symmetric_key_init(first_dst);
- symmetric_key_init(second_dst);
+ first_dst->birthdate = second_dst->birthdate = birthdate;
+ first_dst->is_valid = second_dst->is_valid = true;
}
static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
@@ -715,6 +707,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
u8 e[NOISE_PUBLIC_KEY_LEN];
u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
u8 static_private[NOISE_PUBLIC_KEY_LEN];
+ u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
down_read(&wg->static_identity.lock);
@@ -733,6 +726,8 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
memcpy(ephemeral_private, handshake->ephemeral_private,
NOISE_PUBLIC_KEY_LEN);
+ memcpy(preshared_key, handshake->preshared_key,
+ NOISE_SYMMETRIC_KEY_LEN);
up_read(&handshake->lock);
if (state != HANDSHAKE_CREATED_INITIATION)
@@ -750,7 +745,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
goto fail;
/* psk */
- mix_psk(chaining_key, hash, key, handshake->preshared_key);
+ mix_psk(chaining_key, hash, key, preshared_key);
/* {} */
if (!message_decrypt(NULL, src->encrypted_nothing,
@@ -783,6 +778,7 @@ out:
memzero_explicit(chaining_key, NOISE_HASH_LEN);
memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
+ memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
up_read(&wg->static_identity.lock);
return ret_peer;
}
diff --git a/drivers/net/wireguard/noise.h b/drivers/net/wireguard/noise.h
index f532d59d3f19..c527253dba80 100644
--- a/drivers/net/wireguard/noise.h
+++ b/drivers/net/wireguard/noise.h
@@ -15,18 +15,14 @@
#include <linux/mutex.h>
#include <linux/kref.h>
-union noise_counter {
- struct {
- u64 counter;
- unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
- spinlock_t lock;
- } receive;
- atomic64_t counter;
+struct noise_replay_counter {
+ u64 counter;
+ spinlock_t lock;
+ unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
};
struct noise_symmetric_key {
u8 key[NOISE_SYMMETRIC_KEY_LEN];
- union noise_counter counter;
u64 birthdate;
bool is_valid;
};
@@ -34,7 +30,9 @@ struct noise_symmetric_key {
struct noise_keypair {
struct index_hashtable_entry entry;
struct noise_symmetric_key sending;
+ atomic64_t sending_counter;
struct noise_symmetric_key receiving;
+ struct noise_replay_counter receiving_counter;
__le32 remote_index;
bool i_am_the_initiator;
struct kref refcount;
diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c
index 5c964fcb994e..71b8e80b58e1 100644
--- a/drivers/net/wireguard/queueing.c
+++ b/drivers/net/wireguard/queueing.c
@@ -35,8 +35,10 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function,
if (multicore) {
queue->worker = wg_packet_percpu_multicore_worker_alloc(
function, queue);
- if (!queue->worker)
+ if (!queue->worker) {
+ ptr_ring_cleanup(&queue->ring, NULL);
return -ENOMEM;
+ }
} else {
INIT_WORK(&queue->work, function);
}
diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h
index 3432232afe06..c58df439dbbe 100644
--- a/drivers/net/wireguard/queueing.h
+++ b/drivers/net/wireguard/queueing.h
@@ -87,12 +87,20 @@ static inline bool wg_check_packet_protocol(struct sk_buff *skb)
return real_protocol && skb->protocol == real_protocol;
}
-static inline void wg_reset_packet(struct sk_buff *skb)
+static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
{
+ u8 l4_hash = skb->l4_hash;
+ u8 sw_hash = skb->sw_hash;
+ u32 hash = skb->hash;
skb_scrub_packet(skb, true);
memset(&skb->headers_start, 0,
offsetof(struct sk_buff, headers_end) -
offsetof(struct sk_buff, headers_start));
+ if (encapsulating) {
+ skb->l4_hash = l4_hash;
+ skb->sw_hash = sw_hash;
+ skb->hash = hash;
+ }
skb->queue_mapping = 0;
skb->nohdr = 0;
skb->peeked = 0;
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index da3b782ab7d3..474bb69f0e1b 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -246,20 +246,20 @@ static void keep_key_fresh(struct wg_peer *peer)
}
}
-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
+static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
struct scatterlist sg[MAX_SKB_FRAGS + 8];
struct sk_buff *trailer;
unsigned int offset;
int num_frags;
- if (unlikely(!key))
+ if (unlikely(!keypair))
return false;
- if (unlikely(!READ_ONCE(key->is_valid) ||
- wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
- key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
- WRITE_ONCE(key->is_valid, false);
+ if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
+ wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
+ keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+ WRITE_ONCE(keypair->receiving.is_valid, false);
return false;
}
@@ -284,7 +284,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
PACKET_CB(skb)->nonce,
- key->key))
+ keypair->receiving.key))
return false;
/* Another ugly situation of pushing and pulling the header so as to
@@ -299,41 +299,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
}
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static bool counter_validate(union noise_counter *counter, u64 their_counter)
+static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
{
unsigned long index, index_current, top, i;
bool ret = false;
- spin_lock_bh(&counter->receive.lock);
+ spin_lock_bh(&counter->lock);
- if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+ if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
their_counter >= REJECT_AFTER_MESSAGES))
goto out;
++their_counter;
if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
- counter->receive.counter))
+ counter->counter))
goto out;
index = their_counter >> ilog2(BITS_PER_LONG);
- if (likely(their_counter > counter->receive.counter)) {
- index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+ if (likely(their_counter > counter->counter)) {
+ index_current = counter->counter >> ilog2(BITS_PER_LONG);
top = min_t(unsigned long, index - index_current,
COUNTER_BITS_TOTAL / BITS_PER_LONG);
for (i = 1; i <= top; ++i)
- counter->receive.backtrack[(i + index_current) &
+ counter->backtrack[(i + index_current) &
((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
- counter->receive.counter = their_counter;
+ counter->counter = their_counter;
}
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
- &counter->receive.backtrack[index]);
+ &counter->backtrack[index]);
out:
- spin_unlock_bh(&counter->receive.lock);
+ spin_unlock_bh(&counter->lock);
return ret;
}
@@ -393,13 +393,11 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
len = ntohs(ip_hdr(skb)->tot_len);
if (unlikely(len < sizeof(struct iphdr)))
goto dishonest_packet_size;
- if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
- IP_ECN_set_ce(ip_hdr(skb));
+ INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
} else if (skb->protocol == htons(ETH_P_IPV6)) {
len = ntohs(ipv6_hdr(skb)->payload_len) +
sizeof(struct ipv6hdr);
- if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
- IP6_ECN_set_ce(skb, ipv6_hdr(skb));
+ INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
} else {
goto dishonest_packet_type;
}
@@ -475,19 +473,19 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
if (unlikely(state != PACKET_STATE_CRYPTED))
goto next;
- if (unlikely(!counter_validate(&keypair->receiving.counter,
+ if (unlikely(!counter_validate(&keypair->receiving_counter,
PACKET_CB(skb)->nonce))) {
net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
peer->device->dev->name,
PACKET_CB(skb)->nonce,
- keypair->receiving.counter.receive.counter);
+ keypair->receiving_counter.counter);
goto next;
}
if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
goto next;
- wg_reset_packet(skb);
+ wg_reset_packet(skb, false);
wg_packet_consume_data_done(peer, skb, &endpoint);
free = false;
@@ -514,10 +512,12 @@ void wg_packet_decrypt_worker(struct work_struct *work)
struct sk_buff *skb;
while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
- enum packet_state state = likely(decrypt_packet(skb,
- &PACKET_CB(skb)->keypair->receiving)) ?
+ enum packet_state state =
+ likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
wg_queue_enqueue_per_peer_napi(skb, state);
+ if (need_resched())
+ cond_resched();
}
}
diff --git a/drivers/net/wireguard/selftest/counter.c b/drivers/net/wireguard/selftest/counter.c
index f4fbb9072ed7..ec3c156bf91b 100644
--- a/drivers/net/wireguard/selftest/counter.c
+++ b/drivers/net/wireguard/selftest/counter.c
@@ -6,18 +6,24 @@
#ifdef DEBUG
bool __init wg_packet_counter_selftest(void)
{
+ struct noise_replay_counter *counter;
unsigned int test_num = 0, i;
- union noise_counter counter;
bool success = true;
-#define T_INIT do { \
- memset(&counter, 0, sizeof(union noise_counter)); \
- spin_lock_init(&counter.receive.lock); \
+ counter = kmalloc(sizeof(*counter), GFP_KERNEL);
+ if (unlikely(!counter)) {
+ pr_err("nonce counter self-test malloc: FAIL\n");
+ return false;
+ }
+
+#define T_INIT do { \
+ memset(counter, 0, sizeof(*counter)); \
+ spin_lock_init(&counter->lock); \
} while (0)
#define T_LIM (COUNTER_WINDOW_SIZE + 1)
#define T(n, v) do { \
++test_num; \
- if (counter_validate(&counter, n) != (v)) { \
+ if (counter_validate(counter, n) != (v)) { \
pr_err("nonce counter self-test %u: FAIL\n", \
test_num); \
success = false; \
@@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(void)
if (success)
pr_info("nonce counter self-tests: pass\n");
+ kfree(counter);
return success;
}
#endif
diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c
index 7348c10cbae3..485d5d7a217b 100644
--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_peer *peer)
rcu_read_lock_bh();
keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) &&
- (unlikely(atomic64_read(&keypair->sending.counter.counter) >
+ (unlikely(atomic64_read(&keypair->sending_counter) >
REKEY_AFTER_MESSAGES) ||
(keypair->i_am_the_initiator &&
unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
@@ -170,6 +170,11 @@ static bool encrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
struct sk_buff *trailer;
int num_frags;
+ /* Force hash calculation before encryption so that flow analysis is
+ * consistent over the inner packet.
+ */
+ skb_get_hash(skb);
+
/* Calculate lengths. */
padding_len = calculate_skb_padding(skb);
trailer_len = padding_len + noise_encrypted_len(0);
@@ -281,6 +286,8 @@ void wg_packet_tx_worker(struct work_struct *work)
wg_noise_keypair_put(keypair, false);
wg_peer_put(peer);
+ if (need_resched())
+ cond_resched();
}
}
@@ -296,7 +303,7 @@ void wg_packet_encrypt_worker(struct work_struct *work)
skb_list_walk_safe(first, skb, next) {
if (likely(encrypt_packet(skb,
PACKET_CB(first)->keypair))) {
- wg_reset_packet(skb);
+ wg_reset_packet(skb, true);
} else {
state = PACKET_STATE_DEAD;
break;
@@ -305,6 +312,8 @@ void wg_packet_encrypt_worker(struct work_struct *work)
wg_queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first,
state);
+ if (need_resched())
+ cond_resched();
}
}
@@ -344,7 +353,6 @@ void wg_packet_purge_staged_packets(struct wg_peer *peer)
void wg_packet_send_staged_packets(struct wg_peer *peer)
{
- struct noise_symmetric_key *key;
struct noise_keypair *keypair;
struct sk_buff_head packets;
struct sk_buff *skb;
@@ -364,10 +372,9 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
rcu_read_unlock_bh();
if (unlikely(!keypair))
goto out_nokey;
- key = &keypair->sending;
- if (unlikely(!READ_ONCE(key->is_valid)))
+ if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
goto out_nokey;
- if (unlikely(wg_birthdate_has_expired(key->birthdate,
+ if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
REJECT_AFTER_TIME)))
goto out_invalid;
@@ -382,7 +389,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
*/
PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
PACKET_CB(skb)->nonce =
- atomic64_inc_return(&key->counter.counter) - 1;
+ atomic64_inc_return(&keypair->sending_counter) - 1;
if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
goto out_invalid;
}
@@ -394,7 +401,7 @@ void wg_packet_send_staged_packets(struct wg_peer *peer)
return;
out_invalid:
- WRITE_ONCE(key->is_valid, false);
+ WRITE_ONCE(keypair->sending.is_valid, false);
out_nokey:
wg_noise_keypair_put(keypair, false);
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index b0d6541582d3..f9018027fc13 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -76,12 +76,6 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
wg->dev->name, &endpoint->addr, ret);
goto err;
- } else if (unlikely(rt->dst.dev == skb->dev)) {
- ip_rt_put(rt);
- ret = -ELOOP;
- net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
- wg->dev->name, &endpoint->addr);
- goto err;
}
if (cache)
dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
@@ -149,12 +143,6 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
wg->dev->name, &endpoint->addr, ret);
goto err;
- } else if (unlikely(dst->dev == skb->dev)) {
- dst_release(dst);
- ret = -ELOOP;
- net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
- wg->dev->name, &endpoint->addr);
- goto err;
}
if (cache)
dst_cache_set_ip6(cache, dst, &fl.saddr);