summaryrefslogtreecommitdiff
path: root/net/core
diff options
context:
space:
mode:
Diffstat (limited to 'net/core')
-rw-r--r--net/core/sock_map.c106
1 files changed, 101 insertions, 5 deletions
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index fafcbd22ecba..cb240d87e068 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -141,6 +141,51 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
}
}
+static int sock_map_init_proto(struct sock *sk)
+{
+ struct sk_psock *psock;
+ struct proto *prot;
+
+ sock_owned_by_me(sk);
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+
+ prot = tcp_bpf_get_proto(sk, psock);
+ if (IS_ERR(prot)) {
+ rcu_read_unlock();
+ return PTR_ERR(prot);
+ }
+
+ sk_psock_update_proto(sk, psock, prot);
+ rcu_read_unlock();
+ return 0;
+}
+
+static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (psock) {
+ if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
+ psock = ERR_PTR(-EBUSY);
+ goto out;
+ }
+
+ if (!refcount_inc_not_zero(&psock->refcnt))
+ psock = ERR_PTR(-EBUSY);
+ }
+out:
+ rcu_read_unlock();
+ return psock;
+}
+
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk)
{
@@ -172,7 +217,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
}
}
- psock = sk_psock_get_checked(sk);
+ psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) {
ret = PTR_ERR(psock);
goto out_progs;
@@ -196,7 +241,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
if (msg_parser)
psock_set_prog(&psock->progs.msg_parser, msg_parser);
- ret = tcp_bpf_init(sk);
+ ret = sock_map_init_proto(sk);
if (ret < 0)
goto out_drop;
@@ -231,7 +276,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
struct sk_psock *psock;
int ret;
- psock = sk_psock_get_checked(sk);
+ psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock))
return PTR_ERR(psock);
@@ -241,7 +286,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
return -ENOMEM;
}
- ret = tcp_bpf_init(sk);
+ ret = sock_map_init_proto(sk);
if (ret < 0)
sk_psock_put(sk, psock);
return ret;
@@ -1120,7 +1165,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
return 0;
}
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
switch (link->map->map_type) {
case BPF_MAP_TYPE_SOCKMAP:
@@ -1133,3 +1178,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
break;
}
}
+
+static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_link *link;
+
+ while ((link = sk_psock_link_pop(psock))) {
+ sock_map_unlink(sk, link);
+ sk_psock_free_link(link);
+ }
+}
+
+void sock_map_unhash(struct sock *sk)
+{
+ void (*saved_unhash)(struct sock *sk);
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
+ return;
+ }
+
+ saved_unhash = psock->saved_unhash;
+ sock_map_remove_links(sk, psock);
+ rcu_read_unlock();
+ saved_unhash(sk);
+}
+
+void sock_map_close(struct sock *sk, long timeout)
+{
+ void (*saved_close)(struct sock *sk, long timeout);
+ struct sk_psock *psock;
+
+ lock_sock(sk);
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ release_sock(sk);
+ return sk->sk_prot->close(sk, timeout);
+ }
+
+ saved_close = psock->saved_close;
+ sock_map_remove_links(sk, psock);
+ rcu_read_unlock();
+ release_sock(sk);
+ saved_close(sk, timeout);
+}