summaryrefslogtreecommitdiff
path: root/net/netlink/af_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r--net/netlink/af_netlink.c128
1 files changed, 62 insertions, 66 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 383631873748..642b9d382fb4 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -84,7 +84,7 @@ struct listeners {
static inline int netlink_is_kernel(struct sock *sk)
{
- return nlk_sk(sk)->flags & NETLINK_F_KERNEL_SOCKET;
+ return nlk_test_bit(KERNEL_SOCKET, sk);
}
struct netlink_table *nl_table __read_mostly;
@@ -349,9 +349,7 @@ static void netlink_deliver_tap_kernel(struct sock *dst, struct sock *src,
static void netlink_overrun(struct sock *sk)
{
- struct netlink_sock *nlk = nlk_sk(sk);
-
- if (!(nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)) {
+ if (!nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
if (!test_and_set_bit(NETLINK_S_CONGESTED,
&nlk_sk(sk)->state)) {
sk->sk_err = ENOBUFS;
@@ -677,6 +675,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
struct netlink_sock *nlk;
int (*bind)(struct net *net, int group);
void (*unbind)(struct net *net, int group);
+ void (*release)(struct sock *sock, unsigned long *groups);
int err = 0;
sock->state = SS_UNCONNECTED;
@@ -704,6 +703,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
cb_mutex = nl_table[protocol].cb_mutex;
bind = nl_table[protocol].bind;
unbind = nl_table[protocol].unbind;
+ release = nl_table[protocol].release;
netlink_unlock_table();
if (err < 0)
@@ -719,6 +719,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
nlk->module = module;
nlk->netlink_bind = bind;
nlk->netlink_unbind = unbind;
+ nlk->netlink_release = release;
out:
return err;
@@ -763,6 +764,8 @@ static int netlink_release(struct socket *sock)
* OK. Socket is unlinked, any packets that arrive now
* will be purged.
*/
+ if (nlk->netlink_release)
+ nlk->netlink_release(sk, nlk->groups);
/* must not acquire netlink_table_lock in any way again before unbind
* and notifying genetlink is done as otherwise it might deadlock
@@ -1402,9 +1405,7 @@ EXPORT_SYMBOL_GPL(netlink_has_listeners);
bool netlink_strict_get_check(struct sk_buff *skb)
{
- const struct netlink_sock *nlk = nlk_sk(NETLINK_CB(skb).sk);
-
- return nlk->flags & NETLINK_F_STRICT_CHK;
+ return nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
}
EXPORT_SYMBOL_GPL(netlink_strict_get_check);
@@ -1432,6 +1433,8 @@ struct netlink_broadcast_data {
int delivered;
gfp_t allocation;
struct sk_buff *skb, *skb2;
+ int (*tx_filter)(struct sock *dsk, struct sk_buff *skb, void *data);
+ void *tx_data;
};
static void do_one_broadcast(struct sock *sk,
@@ -1448,7 +1451,7 @@ static void do_one_broadcast(struct sock *sk,
return;
if (!net_eq(sock_net(sk), p->net)) {
- if (!(nlk->flags & NETLINK_F_LISTEN_ALL_NSID))
+ if (!nlk_test_bit(LISTEN_ALL_NSID, sk))
return;
if (!peernet_has_id(sock_net(sk), p->net))
@@ -1481,10 +1484,17 @@ static void do_one_broadcast(struct sock *sk,
netlink_overrun(sk);
/* Clone failed. Notify ALL listeners. */
p->failure = 1;
- if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
+ if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
p->delivery_failure = 1;
goto out;
}
+
+ if (p->tx_filter && p->tx_filter(sk, p->skb2, p->tx_data)) {
+ kfree_skb(p->skb2);
+ p->skb2 = NULL;
+ goto out;
+ }
+
if (sk_filter(sk, p->skb2)) {
kfree_skb(p->skb2);
p->skb2 = NULL;
@@ -1496,7 +1506,7 @@ static void do_one_broadcast(struct sock *sk,
val = netlink_broadcast_deliver(sk, p->skb2);
if (val < 0) {
netlink_overrun(sk);
- if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
+ if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
p->delivery_failure = 1;
} else {
p->congested |= val;
@@ -1507,8 +1517,12 @@ out:
sock_put(sk);
}
-int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 portid,
- u32 group, gfp_t allocation)
+int netlink_broadcast_filtered(struct sock *ssk, struct sk_buff *skb,
+ u32 portid,
+ u32 group, gfp_t allocation,
+ int (*filter)(struct sock *dsk,
+ struct sk_buff *skb, void *data),
+ void *filter_data)
{
struct net *net = sock_net(ssk);
struct netlink_broadcast_data info;
@@ -1527,6 +1541,8 @@ int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 portid,
info.allocation = allocation;
info.skb = skb;
info.skb2 = NULL;
+ info.tx_filter = filter;
+ info.tx_data = filter_data;
/* While we sleep in clone, do not allow to change socket list */
@@ -1552,6 +1568,14 @@ int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 portid,
}
return -ESRCH;
}
+EXPORT_SYMBOL(netlink_broadcast_filtered);
+
+int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 portid,
+ u32 group, gfp_t allocation)
+{
+ return netlink_broadcast_filtered(ssk, skb, portid, group, allocation,
+ NULL, NULL);
+}
EXPORT_SYMBOL(netlink_broadcast);
struct netlink_set_err_data {
@@ -1576,7 +1600,7 @@ static int do_one_set_err(struct sock *sk, struct netlink_set_err_data *p)
!test_bit(p->group - 1, nlk->groups))
goto out;
- if (p->code == ENOBUFS && nlk->flags & NETLINK_F_RECV_NO_ENOBUFS) {
+ if (p->code == ENOBUFS && nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
ret = 1;
goto out;
}
@@ -1629,10 +1653,7 @@ static void netlink_update_socket_mc(struct netlink_sock *nlk,
old = test_bit(group - 1, nlk->groups);
subscriptions = nlk->subscriptions - old + new;
- if (new)
- __set_bit(group - 1, nlk->groups);
- else
- __clear_bit(group - 1, nlk->groups);
+ __assign_bit(group - 1, nlk->groups, new);
netlink_update_subscriptions(&nlk->sk, subscriptions);
netlink_update_listeners(&nlk->sk);
}
@@ -1643,7 +1664,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
struct sock *sk = sock->sk;
struct netlink_sock *nlk = nlk_sk(sk);
unsigned int val = 0;
- int err;
+ int nr = -1;
if (level != SOL_NETLINK)
return -ENOPROTOOPT;
@@ -1654,14 +1675,12 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
switch (optname) {
case NETLINK_PKTINFO:
- if (val)
- nlk->flags |= NETLINK_F_RECV_PKTINFO;
- else
- nlk->flags &= ~NETLINK_F_RECV_PKTINFO;
- err = 0;
+ nr = NETLINK_F_RECV_PKTINFO;
break;
case NETLINK_ADD_MEMBERSHIP:
case NETLINK_DROP_MEMBERSHIP: {
+ int err;
+
if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV))
return -EPERM;
err = netlink_realloc_groups(sk);
@@ -1681,61 +1700,38 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
nlk->netlink_unbind(sock_net(sk), val);
- err = 0;
break;
}
case NETLINK_BROADCAST_ERROR:
- if (val)
- nlk->flags |= NETLINK_F_BROADCAST_SEND_ERROR;
- else
- nlk->flags &= ~NETLINK_F_BROADCAST_SEND_ERROR;
- err = 0;
+ nr = NETLINK_F_BROADCAST_SEND_ERROR;
break;
case NETLINK_NO_ENOBUFS:
+ assign_bit(NETLINK_F_RECV_NO_ENOBUFS, &nlk->flags, val);
if (val) {
- nlk->flags |= NETLINK_F_RECV_NO_ENOBUFS;
clear_bit(NETLINK_S_CONGESTED, &nlk->state);
wake_up_interruptible(&nlk->wait);
- } else {
- nlk->flags &= ~NETLINK_F_RECV_NO_ENOBUFS;
}
- err = 0;
break;
case NETLINK_LISTEN_ALL_NSID:
if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_BROADCAST))
return -EPERM;
-
- if (val)
- nlk->flags |= NETLINK_F_LISTEN_ALL_NSID;
- else
- nlk->flags &= ~NETLINK_F_LISTEN_ALL_NSID;
- err = 0;
+ nr = NETLINK_F_LISTEN_ALL_NSID;
break;
case NETLINK_CAP_ACK:
- if (val)
- nlk->flags |= NETLINK_F_CAP_ACK;
- else
- nlk->flags &= ~NETLINK_F_CAP_ACK;
- err = 0;
+ nr = NETLINK_F_CAP_ACK;
break;
case NETLINK_EXT_ACK:
- if (val)
- nlk->flags |= NETLINK_F_EXT_ACK;
- else
- nlk->flags &= ~NETLINK_F_EXT_ACK;
- err = 0;
+ nr = NETLINK_F_EXT_ACK;
break;
case NETLINK_GET_STRICT_CHK:
- if (val)
- nlk->flags |= NETLINK_F_STRICT_CHK;
- else
- nlk->flags &= ~NETLINK_F_STRICT_CHK;
- err = 0;
+ nr = NETLINK_F_STRICT_CHK;
break;
default:
- err = -ENOPROTOOPT;
+ return -ENOPROTOOPT;
}
- return err;
+ if (nr >= 0)
+ assign_bit(nr, &nlk->flags, val);
+ return 0;
}
static int netlink_getsockopt(struct socket *sock, int level, int optname,
@@ -1802,7 +1798,7 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
return -EINVAL;
len = sizeof(int);
- val = nlk->flags & flag ? 1 : 0;
+ val = test_bit(flag, &nlk->flags);
if (put_user(len, optlen) ||
copy_to_user(optval, &val, len))
@@ -1979,9 +1975,9 @@ static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
msg->msg_namelen = sizeof(*addr);
}
- if (nlk->flags & NETLINK_F_RECV_PKTINFO)
+ if (nlk_test_bit(RECV_PKTINFO, sk))
netlink_cmsg_recv_pktinfo(msg, skb);
- if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID)
+ if (nlk_test_bit(LISTEN_ALL_NSID, sk))
netlink_cmsg_listen_all_nsid(sk, msg, skb);
memset(&scm, 0, sizeof(scm));
@@ -2058,7 +2054,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
goto out_sock_release;
nlk = nlk_sk(sk);
- nlk->flags |= NETLINK_F_KERNEL_SOCKET;
+ set_bit(NETLINK_F_KERNEL_SOCKET, &nlk->flags);
netlink_table_grab();
if (!nl_table[unit].registered) {
@@ -2069,6 +2065,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
if (cfg) {
nl_table[unit].bind = cfg->bind;
nl_table[unit].unbind = cfg->unbind;
+ nl_table[unit].release = cfg->release;
nl_table[unit].flags = cfg->flags;
}
nl_table[unit].registered = 1;
@@ -2192,7 +2189,7 @@ static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb,
nl_dump_check_consistent(cb, nlh);
memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno));
- if (extack->_msg && nlk->flags & NETLINK_F_EXT_ACK) {
+ if (extack->_msg && test_bit(NETLINK_F_EXT_ACK, &nlk->flags)) {
nlh->nlmsg_flags |= NLM_F_ACK_TLVS;
if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg))
nlmsg_end(skb, nlh);
@@ -2321,8 +2318,8 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
const struct nlmsghdr *nlh,
struct netlink_dump_control *control)
{
- struct netlink_sock *nlk, *nlk2;
struct netlink_callback *cb;
+ struct netlink_sock *nlk;
struct sock *sk;
int ret;
@@ -2357,8 +2354,7 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
cb->min_dump_alloc = control->min_dump_alloc;
cb->skb = skb;
- nlk2 = nlk_sk(NETLINK_CB(skb).sk);
- cb->strict_check = !!(nlk2->flags & NETLINK_F_STRICT_CHK);
+ cb->strict_check = nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
if (control->start) {
cb->extack = control->extack;
@@ -2402,7 +2398,7 @@ netlink_ack_tlv_len(struct netlink_sock *nlk, int err,
{
size_t tlvlen;
- if (!extack || !(nlk->flags & NETLINK_F_EXT_ACK))
+ if (!extack || !test_bit(NETLINK_F_EXT_ACK, &nlk->flags))
return 0;
tlvlen = 0;
@@ -2474,7 +2470,7 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
* requests to cap the error message, and get extra error data if
* requested.
*/
- if (err && !(nlk->flags & NETLINK_F_CAP_ACK))
+ if (err && !test_bit(NETLINK_F_CAP_ACK, &nlk->flags))
payload += nlmsg_len(nlh);
else
flags |= NLM_F_CAPPED;