summaryrefslogtreecommitdiffstats
path: root/net/mptcp/token_test.c
blob: e1bd6f0a0676fa1f6ecec42dc9e24f97247b33fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// SPDX-License-Identifier: GPL-2.0
#include <kunit/test.h>

#include "protocol.h"

/*
 * Allocate a zeroed MPTCP subflow request socket via kunit_kzalloc() (so the
 * memory is tied to @test) and initialise its token state with
 * mptcp_token_init_request().
 *
 * The allocation is checked with an ASSERT, which aborts the current test
 * case on failure; callers therefore never receive a NULL pointer.
 */
static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req;

	req = kunit_kzalloc(test, sizeof(*req), GFP_USER);
	/* ASSERT, not EXPECT: an EXPECT failure would fall through to the
	 * NULL dereference in mptcp_token_init_request() below.
	 */
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, req);
	mptcp_token_init_request((struct request_sock *)req);
	return req;
}

/*
 * Request-side token allocation: mptcp_token_new_request() must succeed and
 * hand out a non-zero token, and looking that token up with
 * mptcp_token_get_sock() must yield NULL, since no mptcp_sock has taken it
 * over yet.
 */
static void mptcp_token_test_req_basic(struct kunit *test)
{
	struct mptcp_sock *null_msk = NULL;
	struct mptcp_subflow_request_sock *req;

	req = build_req_sock(test);

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));
	KUNIT_EXPECT_NE(test, 0, (int)req->token);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(req->token));

	/* release the token registered above */
	mptcp_token_destroy_request((struct request_sock *)req);
}

/*
 * Allocate a zeroed inet_connection_sock via kunit_kzalloc() (memory tied to
 * @test) to stand in for a subflow socket.
 *
 * Aborts the current test case if the allocation fails, so callers may
 * dereference the result unconditionally.
 */
static struct inet_connection_sock *build_icsk(struct kunit *test)
{
	struct inet_connection_sock *icsk;

	icsk = kunit_kzalloc(test, sizeof(*icsk), GFP_USER);
	/* ASSERT, not EXPECT: callers dereference the result immediately */
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, icsk);
	return icsk;
}

/*
 * Allocate a zeroed mptcp_subflow_context via kunit_kzalloc() (memory tied
 * to @test).
 *
 * Aborts the current test case if the allocation fails, so callers may
 * dereference the result unconditionally.
 */
static struct mptcp_subflow_context *build_ctx(struct kunit *test)
{
	struct mptcp_subflow_context *ctx;

	ctx = kunit_kzalloc(test, sizeof(*ctx), GFP_USER);
	/* ASSERT, not EXPECT: callers dereference the result immediately */
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ctx);
	return ctx;
}

/*
 * Allocate a zeroed mptcp_sock via kunit_kzalloc() (memory tied to @test)
 * and set its socket refcount to 1.
 *
 * The non-zero refcount matters: mptcp_token_test_destroyed shows that
 * mptcp_token_get_sock() returns NULL once the refcount is 0.
 *
 * Aborts the current test case if the allocation fails.
 */
static struct mptcp_sock *build_msk(struct kunit *test)
{
	struct mptcp_sock *msk;

	msk = kunit_kzalloc(test, sizeof(*msk), GFP_USER);
	/* ASSERT, not EXPECT: the refcount_set() below dereferences msk */
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, msk);
	refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
	return msk;
}

/*
 * Connect-side token allocation.
 *
 * The subflow's ULP data points at a context whose ->conn is @msk.
 * mptcp_token_new_connect() must then:
 *   - store a non-zero token in the context;
 *   - copy the same token into the msk;
 *   - make the token resolve back to the msk.
 *
 * build_msk() starts the refcount at 1; the test then expects 2. This
 * presumably reflects the reference taken by the successful lookup
 * (TODO confirm).
 *
 * After mptcp_token_destroy(), the token must no longer resolve.
 */
static void mptcp_token_test_msk_basic(struct kunit *test)
{
	struct inet_connection_sock *icsk = build_icsk(test);
	struct mptcp_subflow_context *ctx = build_ctx(test);
	struct mptcp_sock *msk = build_msk(test);
	struct sock *sk = (struct sock *)msk;
	struct mptcp_sock *null_msk = NULL;

	/* link subflow -> subflow context -> parent MPTCP socket */
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->conn = sk;

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_connect((struct sock *)icsk));
	KUNIT_EXPECT_NE(test, 0, (int)ctx->token);
	KUNIT_EXPECT_EQ(test, ctx->token, msk->token);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(ctx->token));
	KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt));

	/* once destroyed, the token must not resolve any more */
	mptcp_token_destroy(msk);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(ctx->token));
}

/*
 * Accept path.
 *
 * A token first registered for a request socket is transferred to an
 * mptcp_sock by mptcp_token_accept(). Afterwards the token must resolve to
 * that msk. Destroying the original request must not detach the token from
 * the msk.
 */
static void mptcp_token_test_accept(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req;
	struct mptcp_sock *msk;

	req = build_req_sock(test);
	msk = build_msk(test);

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));

	/* hand the request's token over to the msk */
	msk->token = req->token;
	mptcp_token_accept(req, msk);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));

	/* the request no longer owns the token, so this must not unhash msk */
	mptcp_token_destroy_request((struct request_sock *)req);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));

	/* release the token now held by msk */
	mptcp_token_destroy(msk);
}

/*
 * Lookup racing with socket teardown.
 *
 * The msk still holds an accepted token, but its refcount has dropped to
 * zero. mptcp_token_get_sock() must treat it as gone and return NULL rather
 * than hand out the dying socket.
 */
static void mptcp_token_test_destroyed(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req = build_req_sock(test);
	struct mptcp_sock *msk = build_msk(test);
	struct sock *sk = (struct sock *)msk;
	struct mptcp_sock *null_msk = NULL;

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));

	/* move the token from the request to the msk */
	msk->token = req->token;
	mptcp_token_accept(req, msk);

	/* pretend the last reference was dropped concurrently */
	refcount_set(&sk->sk_refcnt, 0);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(msk->token));

	/* unhash the token before the test ends */
	mptcp_token_destroy(msk);
}

/* Test cases of the "mptcp-token" suite; the zero-filled last entry ends the list. */
static struct kunit_case mptcp_token_test_cases[] = {
	KUNIT_CASE(mptcp_token_test_req_basic),
	KUNIT_CASE(mptcp_token_test_msk_basic),
	KUNIT_CASE(mptcp_token_test_accept),
	KUNIT_CASE(mptcp_token_test_destroyed),
	{ /* sentinel */ }
};

/* Suite descriptor handed to kunit_test_suite() below. */
static struct kunit_suite mptcp_token_suite = {
	.test_cases = mptcp_token_test_cases,
	.name = "mptcp-token",
};

kunit_test_suite(mptcp_token_suite);

MODULE_LICENSE("GPL");