From 9652e931e73be7e54a9c40e9bcd4bbdafe92a406 Mon Sep 17 00:00:00 2001
From: Patrick McHardy <kaber@trash.net>
Date: Wed, 17 Apr 2013 06:47:02 +0000
Subject: netlink: add mmap'ed netlink helper functions

Add helper functions for looking up mmap'ed frame headers, reading and
writing their status, allocating skbs with mmap'ed data areas and a poll
function.

Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
---
 net/netlink/af_netlink.c | 185 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 183 insertions(+), 2 deletions(-)

(limited to 'net/netlink')

diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 1d3c7128e90e..6560635fd25c 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -56,6 +56,7 @@
 #include <linux/audit.h>
 #include <linux/mutex.h>
 #include <linux/vmalloc.h>
+#include <asm/cacheflush.h>
 
 #include <net/net_namespace.h>
 #include <net/sock.h>
@@ -89,6 +90,7 @@ EXPORT_SYMBOL_GPL(nl_table);
 static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait);
 
 static int netlink_dump(struct sock *sk);
+static void netlink_skb_destructor(struct sk_buff *skb);
 
 DEFINE_RWLOCK(nl_table_lock);
 EXPORT_SYMBOL_GPL(nl_table_lock);
@@ -109,6 +111,11 @@ static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u
 }
 
 #ifdef CONFIG_NETLINK_MMAP
+static bool netlink_skb_is_mmaped(const struct sk_buff *skb)
+{
+	return NETLINK_CB(skb).flags & NETLINK_SKB_MMAPED;
+}
+
 static __pure struct page *pgvec_to_page(const void *addr)
 {
 	if (is_vmalloc_addr(addr))
@@ -332,8 +339,154 @@ out:
 	mutex_unlock(&nlk->pg_vec_lock);
 	return 0;
 }
+
+static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr)
+{
+#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1
+	struct page *p_start, *p_end;
+
+	/* First page is flushed through netlink_{get,set}_status */
+	p_start = pgvec_to_page(hdr + PAGE_SIZE);
+	p_end   = pgvec_to_page((void *)hdr + NL_MMAP_MSG_HDRLEN + hdr->nm_len - 1);
+	while (p_start <= p_end) {
+		flush_dcache_page(p_start);
+		p_start++;
+	}
+#endif
+}
+
+static enum nl_mmap_status netlink_get_status(const struct nl_mmap_hdr *hdr)
+{
+	smp_rmb();
+	flush_dcache_page(pgvec_to_page(hdr));
+	return hdr->nm_status;
+}
+
+static void netlink_set_status(struct nl_mmap_hdr *hdr,
+			       enum nl_mmap_status status)
+{
+	hdr->nm_status = status;
+	flush_dcache_page(pgvec_to_page(hdr));
+	smp_wmb();
+}
+
+static struct nl_mmap_hdr *
+__netlink_lookup_frame(const struct netlink_ring *ring, unsigned int pos)
+{
+	unsigned int pg_vec_pos, frame_off;
+
+	pg_vec_pos = pos / ring->frames_per_block;
+	frame_off  = pos % ring->frames_per_block;
+
+	return ring->pg_vec[pg_vec_pos] + (frame_off * ring->frame_size);
+}
+
+static struct nl_mmap_hdr *
+netlink_lookup_frame(const struct netlink_ring *ring, unsigned int pos,
+		     enum nl_mmap_status status)
+{
+	struct nl_mmap_hdr *hdr;
+
+	hdr = __netlink_lookup_frame(ring, pos);
+	if (netlink_get_status(hdr) != status)
+		return NULL;
+
+	return hdr;
+}
+
+static struct nl_mmap_hdr *
+netlink_current_frame(const struct netlink_ring *ring,
+		      enum nl_mmap_status status)
+{
+	return netlink_lookup_frame(ring, ring->head, status);
+}
+
+static struct nl_mmap_hdr *
+netlink_previous_frame(const struct netlink_ring *ring,
+		       enum nl_mmap_status status)
+{
+	unsigned int prev;
+
+	prev = ring->head ? ring->head - 1 : ring->frame_max;
+	return netlink_lookup_frame(ring, prev, status);
+}
+
+static void netlink_increment_head(struct netlink_ring *ring)
+{
+	ring->head = ring->head != ring->frame_max ? ring->head + 1 : 0;
+}
+
+static void netlink_forward_ring(struct netlink_ring *ring)
+{
+	unsigned int head = ring->head, pos = head;
+	const struct nl_mmap_hdr *hdr;
+
+	do {
+		hdr = __netlink_lookup_frame(ring, pos);
+		if (hdr->nm_status == NL_MMAP_STATUS_UNUSED)
+			break;
+		if (hdr->nm_status != NL_MMAP_STATUS_SKIP)
+			break;
+		netlink_increment_head(ring);
+	} while (ring->head != head);
+}
+
+static unsigned int netlink_poll(struct file *file, struct socket *sock,
+				 poll_table *wait)
+{
+	struct sock *sk = sock->sk;
+	struct netlink_sock *nlk = nlk_sk(sk);
+	unsigned int mask;
+
+	mask = datagram_poll(file, sock, wait);
+
+	spin_lock_bh(&sk->sk_receive_queue.lock);
+	if (nlk->rx_ring.pg_vec) {
+		netlink_forward_ring(&nlk->rx_ring);
+		if (!netlink_previous_frame(&nlk->rx_ring, NL_MMAP_STATUS_UNUSED))
+			mask |= POLLIN | POLLRDNORM;
+	}
+	spin_unlock_bh(&sk->sk_receive_queue.lock);
+
+	spin_lock_bh(&sk->sk_write_queue.lock);
+	if (nlk->tx_ring.pg_vec) {
+		if (netlink_current_frame(&nlk->tx_ring, NL_MMAP_STATUS_UNUSED))
+			mask |= POLLOUT | POLLWRNORM;
+	}
+	spin_unlock_bh(&sk->sk_write_queue.lock);
+
+	return mask;
+}
+
+static struct nl_mmap_hdr *netlink_mmap_hdr(struct sk_buff *skb)
+{
+	return (struct nl_mmap_hdr *)(skb->head - NL_MMAP_HDRLEN);
+}
+
+static void netlink_ring_setup_skb(struct sk_buff *skb, struct sock *sk,
+				   struct netlink_ring *ring,
+				   struct nl_mmap_hdr *hdr)
+{
+	unsigned int size;
+	void *data;
+
+	size = ring->frame_size - NL_MMAP_HDRLEN;
+	data = (void *)hdr + NL_MMAP_HDRLEN;
+
+	skb->head	= data;
+	skb->data	= data;
+	skb_reset_tail_pointer(skb);
+	skb->end	= skb->tail + size;
+	skb->len	= 0;
+
+	skb->destructor	= netlink_skb_destructor;
+	NETLINK_CB(skb).flags |= NETLINK_SKB_MMAPED;
+	NETLINK_CB(skb).sk = sk;
+}
 #else /* CONFIG_NETLINK_MMAP */
+#define netlink_skb_is_mmaped(skb)	false
 #define netlink_mmap			sock_no_mmap
+#define netlink_poll			datagram_poll
 #endif /* CONFIG_NETLINK_MMAP */
 
 static void netlink_destroy_callback(struct netlink_callback *cb)
@@ -350,7 +503,35 @@ static void netlink_consume_callback(struct netlink_callback *cb)
 
 static void netlink_skb_destructor(struct sk_buff *skb)
 {
-	sock_rfree(skb);
+#ifdef CONFIG_NETLINK_MMAP
+	struct nl_mmap_hdr *hdr;
+	struct netlink_ring *ring;
+	struct sock *sk;
+
+	/* If a packet from the kernel to userspace was freed because of an
+	 * error without being delivered to userspace, the kernel must reset
+	 * the status. In the direction userspace to kernel, the status is
+	 * always reset here after the packet was processed and freed.
+	 */
+	if (netlink_skb_is_mmaped(skb)) {
+		hdr = netlink_mmap_hdr(skb);
+		sk = NETLINK_CB(skb).sk;
+
+		if (!(NETLINK_CB(skb).flags & NETLINK_SKB_DELIVERED)) {
+			hdr->nm_len = 0;
+			netlink_set_status(hdr, NL_MMAP_STATUS_VALID);
+		}
+		ring = &nlk_sk(sk)->rx_ring;
+
+		WARN_ON(atomic_read(&ring->pending) == 0);
+		atomic_dec(&ring->pending);
+		sock_put(sk);
+
+		skb->data = NULL;
+	}
+#endif
+	if (skb->sk != NULL)
+		sock_rfree(skb);
 }
 
 static void netlink_skb_set_owner_r(struct sk_buff *skb, struct sock *sk)
@@ -2349,7 +2530,7 @@ static const struct proto_ops netlink_ops = {
 	.socketpair =	sock_no_socketpair,
 	.accept =	sock_no_accept,
 	.getname =	netlink_getname,
-	.poll =		datagram_poll,
+	.poll =		netlink_poll,
 	.ioctl =	sock_no_ioctl,
 	.listen =	sock_no_listen,
 	.shutdown =	sock_no_shutdown,
-- 
cgit v1.2.3