summaryrefslogtreecommitdiff
path: root/include/net/sock.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/net/sock.h')
-rw-r--r--include/net/sock.h46
1 files changed, 37 insertions, 9 deletions
diff --git a/include/net/sock.h b/include/net/sock.h
index b5cca7bae69b..6d84784d33fa 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1659,6 +1659,7 @@ void sock_rfree(struct sk_buff *skb);
void sock_efree(struct sk_buff *skb);
#ifdef CONFIG_INET
void sock_edemux(struct sk_buff *skb);
+void sock_pfree(struct sk_buff *skb);
#else
#define sock_edemux sock_efree
#endif
@@ -2526,16 +2527,14 @@ void sock_net_set(struct sock *sk, struct net *net)
write_pnet(&sk->sk_net, net);
}
-static inline struct sock *skb_steal_sock(struct sk_buff *skb)
+static inline bool
+skb_sk_is_prefetched(struct sk_buff *skb)
{
- if (skb->sk) {
- struct sock *sk = skb->sk;
-
- skb->destructor = NULL;
- skb->sk = NULL;
- return sk;
- }
- return NULL;
+#ifdef CONFIG_INET
+ return skb->destructor == sock_pfree;
+#else
+ return false;
+#endif /* CONFIG_INET */
}
/* This helper checks if a socket is a full socket,
@@ -2546,6 +2545,35 @@ static inline bool sk_fullsock(const struct sock *sk)
return (1 << sk->sk_state) & ~(TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV);
}
+static inline bool
+sk_is_refcounted(struct sock *sk)
+{
+ /* Only full sockets have sk->sk_flags. */
+ return !sk_fullsock(sk) || !sock_flag(sk, SOCK_RCU_FREE);
+}
+
+/**
+ * skb_steal_sock
+ * @skb to steal the socket from
+ * @refcounted is set to true if the socket is reference-counted
+ */
+static inline struct sock *
+skb_steal_sock(struct sk_buff *skb, bool *refcounted)
+{
+ if (skb->sk) {
+ struct sock *sk = skb->sk;
+
+ *refcounted = true;
+ if (skb_sk_is_prefetched(skb))
+ *refcounted = sk_is_refcounted(sk);
+ skb->destructor = NULL;
+ skb->sk = NULL;
+ return sk;
+ }
+ *refcounted = false;
+ return NULL;
+}
+
/* Checks if this SKB belongs to an HW offloaded socket
* and whether any SW fallbacks are required based on dev.
* Check decrypted mark in case skb_orphan() cleared socket.