summaryrefslogtreecommitdiff
path: root/net/handshake/netlink.c
blob: 64a0046dd611c1300a5749455da62a30606783c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Generic netlink handshake service
 *
 * Author: Chuck Lever <chuck.lever@oracle.com>
 *
 * Copyright (c) 2023, Oracle and/or its affiliates.
 */

#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/mm.h>

#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>

#include <kunit/visibility.h>

#include <uapi/linux/handshake.h>
#include "handshake.h"
#include "genl.h"

#include <trace/events/handshake.h>

/**
 * handshake_genl_notify - Tell listening handlers that a request is queued
 * @net: target network namespace
 * @proto: handshake protocol
 * @flags: memory allocation control flags
 *
 * Builds a HANDSHAKE_CMD_READY message carrying the protocol's handler
 * class and multicasts it to that class's group in @net.
 *
 * Returns zero on success (or when notification is disabled for @proto),
 * -ESRCH if nobody is listening on the handler class, -ENOMEM if the
 * message buffer cannot be allocated, -EMSGSIZE if the message cannot
 * be constructed, or whatever genlmsg_multicast_netns() returns.
 */
int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
			  gfp_t flags)
{
	struct sk_buff *skb;
	void *genl_hdr;
	int rc;

	/* Notification is switched off per-protocol during unit testing */
	if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
		return 0;

	/* No point building a message that has no audience */
	if (!genl_has_listeners(&handshake_nl_family, net,
				proto->hp_handler_class))
		return -ESRCH;

	skb = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
	if (!skb)
		return -ENOMEM;

	genl_hdr = genlmsg_put(skb, 0, 0, &handshake_nl_family, 0,
			       HANDSHAKE_CMD_READY);
	if (!genl_hdr)
		goto out_free;

	rc = nla_put_u32(skb, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
			 proto->hp_handler_class);
	if (rc < 0) {
		genlmsg_cancel(skb, genl_hdr);
		goto out_free;
	}

	genlmsg_end(skb, genl_hdr);
	return genlmsg_multicast_netns(&handshake_nl_family, net, skb,
				       0, proto->hp_handler_class, flags);

out_free:
	/* Both construction failures are reported as -EMSGSIZE */
	nlmsg_free(skb);
	return -EMSGSIZE;
}

/**
 * handshake_genl_put - Create a generic netlink message header
 * @msg: buffer in which to create the header
 * @info: generic netlink message context
 *
 * Returns a ready-to-use header, or NULL.
 */
struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
				    struct genl_info *info)
{
	return genlmsg_put(msg, info->snd_portid, info->snd_seq,
			   &handshake_nl_family, 0, info->genlhdr->cmd);
}
EXPORT_SYMBOL(handshake_genl_put);

/**
 * handshake_nl_accept_doit - Hand the next pending request to a handler
 * @skb: netlink request buffer; its socket selects the network namespace
 * @info: generic netlink message context
 *
 * Dequeues the next pending handshake request for the handler class
 * given in HANDSHAKE_A_ACCEPT_HANDLER_CLASS, reserves a file descriptor
 * in the calling process ("current") for the request's socket, and lets
 * the request's protocol build the reply through its hp_accept callback.
 * The kernel socket must have an instantiated struct file.
 *
 * The descriptor is only *reserved* while hp_accept() runs; the socket's
 * file is installed into it after hp_accept() has succeeded. Installing
 * earlier would hand ownership of a file reference to the fd table, and
 * an fput() on the error path would then drop a reference the still-open
 * descriptor also expects to drop when it is closed.
 *
 * Returns zero on success. Otherwise returns -EOPNOTSUPP if the per-net
 * handshake state is unavailable, -EINVAL if the handler class attribute
 * check fails, -EAGAIN if no request is pending for the class, the
 * negative errno from get_unused_fd_flags(), or the error returned by
 * hp_accept(). When a request was dequeued but could not be handed over,
 * it is completed with -EIO.
 */
int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
{
	struct net *net = sock_net(skb->sk);
	struct handshake_net *hn = handshake_pernet(net);
	struct handshake_req *req = NULL;
	struct socket *sock;
	int class, fd, err;

	err = -EOPNOTSUPP;
	if (!hn)
		goto out_status;

	err = -EINVAL;
	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
		goto out_status;
	class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);

	err = -EAGAIN;
	req = handshake_req_next(hn, class);
	if (!req)
		goto out_status;

	sock = req->hr_sk->sk_socket;

	/* Reserve a descriptor number only; nothing is visible to user space yet */
	fd = get_unused_fd_flags(O_CLOEXEC);
	if (fd < 0) {
		err = fd;
		goto out_complete;
	}

	err = req->hr_proto->hp_accept(req, info, fd);
	if (err) {
		/* No file was attached, so just give the number back */
		put_unused_fd(fd);
		goto out_complete;
	}

	/*
	 * Point of no return: take a file reference on behalf of the
	 * descriptor and publish it in current's fd table.
	 */
	fd_install(fd, get_file(sock->file));

	trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
	return 0;

out_complete:
	handshake_complete(req, -EIO, NULL);
out_status:
	trace_handshake_cmd_accept_err(net, req, NULL, err);
	return err;
}

/**
 * handshake_nl_done_doit - Handler reports that a handshake has finished
 * @skb: netlink request buffer; its socket selects the network namespace
 * @info: generic netlink message context
 *
 * Looks up the socket behind the descriptor in HANDSHAKE_A_DONE_SOCKFD,
 * finds the handshake request hashed on that socket, and completes it
 * with the value of HANDSHAKE_A_DONE_STATUS, or with -EIO when that
 * attribute is absent.
 *
 * Returns zero on success, -EINVAL if the descriptor attribute check
 * fails, the error from sockfd_lookup() if the descriptor cannot be
 * resolved, or -EBUSY if no request is associated with the socket.
 */
int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
{
	struct net *net = sock_net(skb->sk);
	struct handshake_req *req;
	struct socket *sock;
	int fd, status, err;

	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
		return -EINVAL;
	fd = nla_get_s32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);

	/* On success this holds a reference on sock->file until out_put */
	sock = sockfd_lookup(fd, &err);
	if (!sock)
		return err;

	req = handshake_req_hash_lookup(sock->sk);
	if (!req) {
		err = -EBUSY;
		trace_handshake_cmd_done_err(net, req, sock->sk, err);
		goto out_put;
	}

	trace_handshake_cmd_done(net, req, sock->sk, fd);

	/* A handler that supplies no status is treated as an I/O failure */
	if (info->attrs[HANDSHAKE_A_DONE_STATUS])
		status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
	else
		status = -EIO;

	handshake_complete(req, status, info);
	err = 0;

out_put:
	fput(sock->file);
	return err;
}

/*
 * Per-net generic id for this module, registered through
 * handshake_genl_net_ops.id. handshake_pernet() treats a value of zero
 * as "not initialized", and handshake_exit() resets it to zero.
 */
static unsigned int handshake_net_id;

/*
 * handshake_net_init - Set up the handshake state of a new namespace
 * @net: network namespace being created
 *
 * Initializes the lock, the pending-request list and the counters, and
 * computes the cap on pending requests for this namespace.
 *
 * Always returns zero.
 */
static int __net_init handshake_net_init(struct net *net)
{
	struct handshake_net *hn = net_generic(net, handshake_net_id);
	struct sysinfo meminfo;
	unsigned long limit;

	spin_lock_init(&hn->hn_lock);
	INIT_LIST_HEAD(&hn->hn_requests);
	hn->hn_flags = 0;
	hn->hn_pending = 0;

	/*
	 * Arbitrary limit so that handshakes which never make progress
	 * cannot clog up the system. The cap is derived from the amount
	 * of physical memory and is always kept within [3, 50].
	 */
	si_meminfo(&meminfo);
	limit = meminfo.totalram / (25 * meminfo.mem_unit);
	hn->hn_pending_max = clamp(limit, 3UL, 50UL);

	return 0;
}

/*
 * handshake_net_exit - Tear down the handshake state of a namespace
 * @net: network namespace being destroyed
 *
 * Marks the namespace as draining and completes every request that is
 * still on its pending list with -ETIMEDOUT.
 */
static void __net_exit handshake_net_exit(struct net *net)
{
	struct handshake_net *hn = net_generic(net, handshake_net_id);
	struct handshake_req *req;
	LIST_HEAD(requests);

	/*
	 * Drain the net's pending list. Requests that have been
	 * accepted and are in progress will be destroyed when
	 * the socket is closed.
	 *
	 * list_splice_init(list, head) moves the entries of @list onto
	 * @head and leaves @list empty, so the per-net list is the
	 * source and the on-stack list is the destination. This lets
	 * the requests be completed below without holding hn_lock.
	 */
	spin_lock(&hn->hn_lock);
	set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
	list_splice_init(&hn->hn_requests, &requests);
	spin_unlock(&hn->hn_lock);

	while (!list_empty(&requests)) {
		req = list_first_entry(&requests, struct handshake_req, hr_list);
		list_del(&req->hr_list);

		/*
		 * Requests on this list have not yet been
		 * accepted, so they do not have an fd to put.
		 */

		handshake_complete(req, -ETIMEDOUT, NULL);
	}
}

/*
 * Per-network-namespace registration for the handshake service:
 * handshake_net_init() and handshake_net_exit() are the init and exit
 * callbacks, and the private area of sizeof(struct handshake_net) bytes
 * is looked up with net_generic() using handshake_net_id.
 */
static struct pernet_operations handshake_genl_net_ops = {
	.init		= handshake_net_init,
	.exit		= handshake_net_exit,
	.id		= &handshake_net_id,
	.size		= sizeof(struct handshake_net),
};

/**
 * handshake_pernet - Look up the handshake module's per-net state
 * @net: network namespace
 *
 * A zero handshake_net_id means the module has no per-net area to
 * hand out, in which case no lookup is attempted.
 *
 * Returns a pointer to the net's private per-net structure for the
 * handshake module, or NULL if handshake_init() failed.
 */
struct handshake_net *handshake_pernet(struct net *net)
{
	if (!handshake_net_id)
		return NULL;

	return net_generic(net, handshake_net_id);
}
EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);

/*
 * handshake_init - Module initialization
 *
 * Brings up, in order, the request hash table, the generic netlink
 * family and the per-net subsystem. A failure at any step logs a
 * warning and undoes the steps that already succeeded, in reverse.
 *
 * Returns zero on success or the negative errno of the failing step.
 */
static int __init handshake_init(void)
{
	int ret;

	ret = handshake_req_hash_init();
	if (ret) {
		pr_warn("handshake: hash initialization failed (%d)\n", ret);
		goto out;
	}

	ret = genl_register_family(&handshake_nl_family);
	if (ret) {
		pr_warn("handshake: netlink registration failed (%d)\n", ret);
		goto out_hash;
	}

	/*
	 * ORDER: register_pernet_subsys must be done last.
	 *
	 *	Unless pernet registration succeeds, handshake_net_id
	 *	remains 0, and handshake_pernet() returns NULL while it
	 *	is 0. That keeps callers from dereferencing a per-net
	 *	area that has not been allocated.
	 */
	ret = register_pernet_subsys(&handshake_genl_net_ops);
	if (ret) {
		pr_warn("handshake: pernet registration failed (%d)\n", ret);
		goto out_genl;
	}

	return 0;

out_genl:
	genl_unregister_family(&handshake_nl_family);
out_hash:
	handshake_req_hash_destroy();
out:
	return ret;
}

/*
 * handshake_exit - Module teardown
 *
 * Undoes handshake_init() in the reverse order of setup: per-net
 * subsystem, then the generic netlink family, then the request hash
 * table.
 */
static void __exit handshake_exit(void)
{
	unregister_pernet_subsys(&handshake_genl_net_ops);
	/* From here on handshake_pernet() returns NULL */
	handshake_net_id = 0;

	/*
	 * ORDER: the netlink family must be gone before the hash table.
	 *
	 *	handshake_nl_done_doit() looks requests up in the hash
	 *	table without consulting handshake_pernet(), so the
	 *	family has to be unregistered before the table it
	 *	depends on is destroyed. This is also the order used by
	 *	the error unwinding in handshake_init().
	 */
	genl_unregister_family(&handshake_nl_family);
	handshake_req_hash_destroy();
}

/* Module entry and exit points */
module_init(handshake_init);
module_exit(handshake_exit);