diff options
Diffstat (limited to 'net/can/raw.c')
-rw-r--r-- | net/can/raw.c | 79 |
1 files changed, 44 insertions, 35 deletions
diff --git a/net/can/raw.c b/net/can/raw.c index 15c79b079184..d50c3f3d892f 100644 --- a/net/can/raw.c +++ b/net/can/raw.c @@ -84,6 +84,8 @@ struct raw_sock { struct sock sk; int bound; int ifindex; + struct net_device *dev; + netdevice_tracker dev_tracker; struct list_head notifier; int loopback; int recv_own_msgs; @@ -277,21 +279,24 @@ static void raw_notify(struct raw_sock *ro, unsigned long msg, if (!net_eq(dev_net(dev), sock_net(sk))) return; - if (ro->ifindex != dev->ifindex) + if (ro->dev != dev) return; switch (msg) { case NETDEV_UNREGISTER: lock_sock(sk); /* remove current filters & unregister */ - if (ro->bound) + if (ro->bound) { raw_disable_allfilters(dev_net(dev), dev, sk); + netdev_put(dev, &ro->dev_tracker); + } if (ro->count > 1) kfree(ro->filter); ro->ifindex = 0; ro->bound = 0; + ro->dev = NULL; ro->count = 0; release_sock(sk); @@ -337,6 +342,7 @@ static int raw_init(struct sock *sk) ro->bound = 0; ro->ifindex = 0; + ro->dev = NULL; /* set default filter to single entry dfilter */ ro->dfilter.can_id = 0; @@ -383,18 +389,14 @@ static int raw_release(struct socket *sock) list_del(&ro->notifier); spin_unlock(&raw_notifier_lock); + rtnl_lock(); lock_sock(sk); /* remove current filters & unregister */ if (ro->bound) { - if (ro->ifindex) { - struct net_device *dev; - - dev = dev_get_by_index(sock_net(sk), ro->ifindex); - if (dev) { - raw_disable_allfilters(dev_net(dev), dev, sk); - dev_put(dev); - } + if (ro->dev) { + raw_disable_allfilters(dev_net(ro->dev), ro->dev, sk); + netdev_put(ro->dev, &ro->dev_tracker); } else { raw_disable_allfilters(sock_net(sk), NULL, sk); } @@ -405,6 +407,7 @@ static int raw_release(struct socket *sock) ro->ifindex = 0; ro->bound = 0; + ro->dev = NULL; ro->count = 0; free_percpu(ro->uniq); @@ -412,6 +415,8 @@ static int raw_release(struct socket *sock) sock->sk = NULL; release_sock(sk); + rtnl_unlock(); + sock_put(sk); return 0; @@ -422,6 +427,7 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len) struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; struct sock *sk = sock->sk; struct raw_sock *ro = raw_sk(sk); + struct net_device *dev = NULL; int ifindex; int err = 0; int notify_enetdown = 0; @@ -431,24 +437,23 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len) if (addr->can_family != AF_CAN) return -EINVAL; + rtnl_lock(); lock_sock(sk); if (ro->bound && addr->can_ifindex == ro->ifindex) goto out; if (addr->can_ifindex) { - struct net_device *dev; - dev = dev_get_by_index(sock_net(sk), addr->can_ifindex); if (!dev) { err = -ENODEV; goto out; } if (dev->type != ARPHRD_CAN) { - dev_put(dev); err = -ENODEV; - goto out; + goto out_put_dev; } + if (!(dev->flags & IFF_UP)) notify_enetdown = 1; @@ -456,7 +461,9 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len) /* filters set by default/setsockopt */ err = raw_enable_allfilters(sock_net(sk), dev, sk); - dev_put(dev); + if (err) + goto out_put_dev; + } else { ifindex = 0; @@ -467,26 +474,30 @@ static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len) if (!err) { if (ro->bound) { /* unregister old filters */ - if (ro->ifindex) { - struct net_device *dev; - - dev = dev_get_by_index(sock_net(sk), - ro->ifindex); - if (dev) { - raw_disable_allfilters(dev_net(dev), - dev, sk); - dev_put(dev); - } + if (ro->dev) { + raw_disable_allfilters(dev_net(ro->dev), + ro->dev, sk); + /* drop reference to old ro->dev */ + netdev_put(ro->dev, &ro->dev_tracker); } else { raw_disable_allfilters(sock_net(sk), NULL, sk); } } ro->ifindex = ifindex; ro->bound = 1; + /* bind() ok -> hold a reference for new ro->dev */ + ro->dev = dev; + if (ro->dev) + netdev_hold(ro->dev, &ro->dev_tracker, GFP_KERNEL); } - out: +out_put_dev: + /* remove potential reference from dev_get_by_index() */ + if (dev) + dev_put(dev); +out: release_sock(sk); + rtnl_unlock(); if (notify_enetdown) { sk->sk_err = ENETDOWN; @@ -553,9 +564,9 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, rtnl_lock(); lock_sock(sk); - if (ro->bound && ro->ifindex) { - dev = dev_get_by_index(sock_net(sk), ro->ifindex); - if (!dev) { + dev = ro->dev; + if (ro->bound && dev) { + if (dev->reg_state != NETREG_REGISTERED) { if (count > 1) kfree(filter); err = -ENODEV; @@ -596,7 +607,6 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, ro->count = count; out_fil: - dev_put(dev); release_sock(sk); rtnl_unlock(); @@ -614,9 +624,9 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, rtnl_lock(); lock_sock(sk); - if (ro->bound && ro->ifindex) { - dev = dev_get_by_index(sock_net(sk), ro->ifindex); - if (!dev) { + dev = ro->dev; + if (ro->bound && dev) { + if (dev->reg_state != NETREG_REGISTERED) { err = -ENODEV; goto out_err; } @@ -640,7 +650,6 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, ro->err_mask = err_mask; out_err: - dev_put(dev); release_sock(sk); rtnl_unlock(); @@ -873,7 +882,7 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) skb->dev = dev; skb->priority = sk->sk_priority; - skb->mark = sk->sk_mark; + skb->mark = READ_ONCE(sk->sk_mark); skb->tstamp = sockc.transmit_time; skb_setup_tx_timestamp(skb, sockc.tsflags); |