diff options
Diffstat (limited to 'net')
554 files changed, 20076 insertions, 12908 deletions
diff --git a/net/802/garp.c b/net/802/garp.c index fc9eb02a912f..ab24b21fbb49 100644 --- a/net/802/garp.c +++ b/net/802/garp.c @@ -407,7 +407,7 @@ static void garp_join_timer_arm(struct garp_applicant *app) { unsigned long delay; - delay = prandom_u32_max(msecs_to_jiffies(garp_join_time)); + delay = get_random_u32_below(msecs_to_jiffies(garp_join_time)); mod_timer(&app->join_timer, jiffies + delay); } @@ -618,7 +618,7 @@ void garp_uninit_applicant(struct net_device *dev, struct garp_application *appl /* Delete timer and generate a final TRANSMIT_PDU event to flush out * all pending messages before the applicant is gone. */ - del_timer_sync(&app->join_timer); + timer_shutdown_sync(&app->join_timer); spin_lock_bh(&app->lock); garp_gid_event(app, GARP_EVENT_TRANSMIT_PDU); diff --git a/net/802/mrp.c b/net/802/mrp.c index 155f74d8b14f..eafc21ecc287 100644 --- a/net/802/mrp.c +++ b/net/802/mrp.c @@ -592,7 +592,7 @@ static void mrp_join_timer_arm(struct mrp_applicant *app) { unsigned long delay; - delay = prandom_u32_max(msecs_to_jiffies(mrp_join_time)); + delay = get_random_u32_below(msecs_to_jiffies(mrp_join_time)); mod_timer(&app->join_timer, jiffies + delay); } @@ -606,7 +606,10 @@ static void mrp_join_timer(struct timer_list *t) spin_unlock(&app->lock); mrp_queue_xmit(app); - mrp_join_timer_arm(app); + spin_lock(&app->lock); + if (likely(app->active)) + mrp_join_timer_arm(app); + spin_unlock(&app->lock); } static void mrp_periodic_timer_arm(struct mrp_applicant *app) @@ -620,11 +623,12 @@ static void mrp_periodic_timer(struct timer_list *t) struct mrp_applicant *app = from_timer(app, t, periodic_timer); spin_lock(&app->lock); - mrp_mad_event(app, MRP_EVENT_PERIODIC); - mrp_pdu_queue(app); + if (likely(app->active)) { + mrp_mad_event(app, MRP_EVENT_PERIODIC); + mrp_pdu_queue(app); + mrp_periodic_timer_arm(app); + } spin_unlock(&app->lock); - - mrp_periodic_timer_arm(app); } static int mrp_pdu_parse_end_mark(struct sk_buff *skb, int *offset) @@ -872,6 +876,7 @@ int mrp_init_applicant(struct net_device *dev, struct mrp_application *appl) app->dev = dev; app->app = appl; app->mad = RB_ROOT; + app->active = true; spin_lock_init(&app->lock); skb_queue_head_init(&app->queue); rcu_assign_pointer(dev->mrp_port->applicants[appl->type], app); @@ -900,11 +905,14 @@ void mrp_uninit_applicant(struct net_device *dev, struct mrp_application *appl) RCU_INIT_POINTER(port->applicants[appl->type], NULL); + spin_lock_bh(&app->lock); + app->active = false; + spin_unlock_bh(&app->lock); /* Delete timer and generate a final TX event to flush out * all pending messages before the applicant is gone. */ - del_timer_sync(&app->join_timer); - del_timer_sync(&app->periodic_timer); + timer_shutdown_sync(&app->join_timer); + timer_shutdown_sync(&app->periodic_timer); spin_lock_bh(&app->lock); mrp_mad_event(app, MRP_EVENT_TX); diff --git a/net/8021q/vlan_dev.c b/net/8021q/vlan_dev.c index e1bb41a443c4..296d0145932f 100644 --- a/net/8021q/vlan_dev.c +++ b/net/8021q/vlan_dev.c @@ -712,13 +712,13 @@ static void vlan_dev_get_stats64(struct net_device *dev, p = per_cpu_ptr(vlan_dev_priv(dev)->vlan_pcpu_stats, i); do { - start = u64_stats_fetch_begin_irq(&p->syncp); + start = u64_stats_fetch_begin(&p->syncp); rxpackets = u64_stats_read(&p->rx_packets); rxbytes = u64_stats_read(&p->rx_bytes); rxmulticast = u64_stats_read(&p->rx_multicast); txpackets = u64_stats_read(&p->tx_packets); txbytes = u64_stats_read(&p->tx_bytes); - } while (u64_stats_fetch_retry_irq(&p->syncp, start)); + } while (u64_stats_fetch_retry(&p->syncp, start)); stats->rx_packets += rxpackets; stats->rx_bytes += rxbytes; diff --git a/net/9p/client.c b/net/9p/client.c index aaa37b07e30a..622ec6a586ee 100644 --- a/net/9p/client.c +++ b/net/9p/client.c @@ -297,6 +297,11 @@ p9_tag_alloc(struct p9_client *c, int8_t type, uint t_size, uint r_size, p9pdu_reset(&req->rc); req->t_err = 0; req->status = REQ_STATUS_ALLOC; + /* refcount needs to be set to 0 before inserting into the idr + * so p9_tag_lookup does not accept a request that is not fully + * initialized. refcount_set to 2 below will mark request ready. + */ + refcount_set(&req->refcount, 0); init_waitqueue_head(&req->wq); INIT_LIST_HEAD(&req->req_list); @@ -438,7 +443,7 @@ void p9_client_cb(struct p9_client *c, struct p9_req_t *req, int status) * the status change is visible to another thread */ smp_wmb(); - req->status = status; + WRITE_ONCE(req->status, status); wake_up(&req->wq); p9_debug(P9_DEBUG_MUX, "wakeup: %d\n", req->tc.tag); @@ -514,10 +519,9 @@ static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) int ecode; err = p9_parse_header(&req->rc, NULL, &type, NULL, 0); - if (req->rc.size >= c->msize) { - p9_debug(P9_DEBUG_ERROR, - "requested packet size too big: %d\n", - req->rc.size); + if (req->rc.size > req->rc.capacity && !req->rc.zc) { + pr_err("requested packet size too big: %d does not fit %zu (type=%d)\n", + req->rc.size, req->rc.capacity, req->rc.id); return -EIO; } /* dump the response from server @@ -600,7 +604,7 @@ static int p9_client_flush(struct p9_client *c, struct p9_req_t *oldreq) /* if we haven't received a response for oldreq, * remove it from the list */ - if (oldreq->status == REQ_STATUS_SENT) { + if (READ_ONCE(oldreq->status) == REQ_STATUS_SENT) { if (c->trans_mod->cancelled) c->trans_mod->cancelled(c, oldreq); } @@ -680,6 +684,9 @@ p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...) if (IS_ERR(req)) return req; + req->tc.zc = false; + req->rc.zc = false; + if (signal_pending(current)) { sigpending = 1; clear_thread_flag(TIF_SIGPENDING); @@ -697,7 +704,8 @@ p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...) } again: /* Wait for the response */ - err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD); + err = wait_event_killable(req->wq, + READ_ONCE(req->status) >= REQ_STATUS_RCVD); /* Make sure our req is coherent with regard to updates in other * threads - echoes to wmb() in the callback @@ -711,7 +719,7 @@ again: goto again; } - if (req->status == REQ_STATUS_ERROR) { + if (READ_ONCE(req->status) == REQ_STATUS_ERROR) { p9_debug(P9_DEBUG_ERROR, "req_status error %d\n", req->t_err); err = req->t_err; } @@ -724,7 +732,7 @@ again: p9_client_flush(c, req); /* if we received the response anyway, don't signal error */ - if (req->status == REQ_STATUS_RCVD) + if (READ_ONCE(req->status) == REQ_STATUS_RCVD) err = 0; } recalc_sigpending: @@ -778,6 +786,9 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, if (IS_ERR(req)) return req; + req->tc.zc = true; + req->rc.zc = true; + if (signal_pending(current)) { sigpending = 1; clear_thread_flag(TIF_SIGPENDING); @@ -793,7 +804,7 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, if (err != -ERESTARTSYS) goto recalc_sigpending; } - if (req->status == REQ_STATUS_ERROR) { + if (READ_ONCE(req->status) == REQ_STATUS_ERROR) { p9_debug(P9_DEBUG_ERROR, "req_status error %d\n", req->t_err); err = req->t_err; } @@ -806,7 +817,7 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, p9_client_flush(c, req); /* if we received the response anyway, don't signal error */ - if (req->status == REQ_STATUS_RCVD) + if (READ_ONCE(req->status) == REQ_STATUS_RCVD) err = 0; } recalc_sigpending: @@ -2043,7 +2054,7 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) struct kvec kv = {.iov_base = data, .iov_len = count}; struct iov_iter to; - iov_iter_kvec(&to, READ, &kv, 1, count); + iov_iter_kvec(&to, ITER_DEST, &kv, 1, count); p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n", fid->fid, offset, count); diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c index 56a186768750..00b684616e8d 100644 --- a/net/9p/trans_fd.c +++ b/net/9p/trans_fd.c @@ -20,7 +20,6 @@ #include <linux/un.h> #include <linux/uaccess.h> #include <linux/inet.h> -#include <linux/idr.h> #include <linux/file.h> #include <linux/parser.h> #include <linux/slab.h> @@ -120,7 +119,7 @@ struct p9_conn { struct list_head unsent_req_list; struct p9_req_t *rreq; struct p9_req_t *wreq; - char tmp_buf[7]; + char tmp_buf[P9_HDRSZ]; struct p9_fcall rc; int wpos; int wsize; @@ -202,9 +201,11 @@ static void p9_conn_cancel(struct p9_conn *m, int err) list_for_each_entry_safe(req, rtmp, &m->req_list, req_list) { list_move(&req->req_list, &cancel_list); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); } list_for_each_entry_safe(req, rtmp, &m->unsent_req_list, req_list) { list_move(&req->req_list, &cancel_list); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); } spin_unlock(&m->req_lock); @@ -291,7 +292,7 @@ static void p9_read_work(struct work_struct *work) if (!m->rc.sdata) { m->rc.sdata = m->tmp_buf; m->rc.offset = 0; - m->rc.capacity = 7; /* start by reading header */ + m->rc.capacity = P9_HDRSZ; /* start by reading header */ } clear_bit(Rpending, &m->wsched); @@ -314,7 +315,7 @@ static void p9_read_work(struct work_struct *work) p9_debug(P9_DEBUG_TRANS, "got new header\n"); /* Header size */ - m->rc.size = 7; + m->rc.size = P9_HDRSZ; err = p9_parse_header(&m->rc, &m->rc.size, NULL, NULL, 0); if (err) { p9_debug(P9_DEBUG_ERROR, @@ -322,14 +323,6 @@ static void p9_read_work(struct work_struct *work) goto error; } - if (m->rc.size >= m->client->msize) { - p9_debug(P9_DEBUG_ERROR, - "requested packet size too big: %d\n", - m->rc.size); - err = -EIO; - goto error; - } - p9_debug(P9_DEBUG_TRANS, "mux %p pkt: size: %d bytes tag: %d\n", m, m->rc.size, m->rc.tag); @@ -342,6 +335,14 @@ static void p9_read_work(struct work_struct *work) goto error; } + if (m->rc.size > m->rreq->rc.capacity) { + p9_debug(P9_DEBUG_ERROR, + "requested packet size too big: %d for tag %d with capacity %zd\n", + m->rc.size, m->rc.tag, m->rreq->rc.capacity); + err = -EIO; + goto error; + } + if (!m->rreq->rc.sdata) { p9_debug(P9_DEBUG_ERROR, "No recv fcall for tag %d (req %p), disconnecting!\n", @@ -465,7 +466,7 @@ static void p9_write_work(struct work_struct *work) req = list_entry(m->unsent_req_list.next, struct p9_req_t, req_list); - req->status = REQ_STATUS_SENT; + WRITE_ONCE(req->status, REQ_STATUS_SENT); p9_debug(P9_DEBUG_TRANS, "move req %p\n", req); list_move_tail(&req->req_list, &m->req_list); @@ -674,7 +675,7 @@ static int p9_fd_request(struct p9_client *client, struct p9_req_t *req) return m->err; spin_lock(&m->req_lock); - req->status = REQ_STATUS_UNSENT; + WRITE_ONCE(req->status, REQ_STATUS_UNSENT); list_add_tail(&req->req_list, &m->unsent_req_list); spin_unlock(&m->req_lock); @@ -701,7 +702,7 @@ static int p9_fd_cancel(struct p9_client *client, struct p9_req_t *req) if (req->status == REQ_STATUS_UNSENT) { list_del(&req->req_list); - req->status = REQ_STATUS_FLSHD; + WRITE_ONCE(req->status, REQ_STATUS_FLSHD); p9_req_put(client, req); ret = 0; } @@ -730,7 +731,7 @@ static int p9_fd_cancelled(struct p9_client *client, struct p9_req_t *req) * remove it from the list. */ list_del(&req->req_list); - req->status = REQ_STATUS_FLSHD; + WRITE_ONCE(req->status, REQ_STATUS_FLSHD); spin_unlock(&m->req_lock); p9_req_put(client, req); @@ -860,10 +861,13 @@ static int p9_socket_open(struct p9_client *client, struct socket *csocket) struct file *file; p = kzalloc(sizeof(struct p9_trans_fd), GFP_KERNEL); - if (!p) + if (!p) { + sock_release(csocket); return -ENOMEM; + } csocket->sk->sk_allocation = GFP_NOIO; + csocket->sk->sk_use_task_frag = false; file = sock_alloc_file(csocket, 0, NULL); if (IS_ERR(file)) { pr_err("%s (%d): failed to map fd\n", diff --git a/net/9p/trans_rdma.c b/net/9p/trans_rdma.c index 6ff706760676..83f9100d46bf 100644 --- a/net/9p/trans_rdma.c +++ b/net/9p/trans_rdma.c @@ -21,7 +21,6 @@ #include <linux/un.h> #include <linux/uaccess.h> #include <linux/inet.h> -#include <linux/idr.h> #include <linux/file.h> #include <linux/parser.h> #include <linux/semaphore.h> @@ -507,7 +506,7 @@ dont_need_post_recv: * because doing if after could erase the REQ_STATUS_RCVD * status in case of a very fast reply. */ - req->status = REQ_STATUS_SENT; + WRITE_ONCE(req->status, REQ_STATUS_SENT); err = ib_post_send(rdma->qp, &wr, NULL); if (err) goto send_error; @@ -517,7 +516,7 @@ dont_need_post_recv: /* Handle errors that happened during or while preparing the send: */ send_error: - req->status = REQ_STATUS_ERROR; + WRITE_ONCE(req->status, REQ_STATUS_ERROR); kfree(c); p9_debug(P9_DEBUG_ERROR, "Error %d in rdma_request()\n", err); diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c index e757f0601304..3c27ffb781e3 100644 --- a/net/9p/trans_virtio.c +++ b/net/9p/trans_virtio.c @@ -22,7 +22,6 @@ #include <linux/un.h> #include <linux/uaccess.h> #include <linux/inet.h> -#include <linux/idr.h> #include <linux/file.h> #include <linux/highmem.h> #include <linux/slab.h> @@ -263,7 +262,7 @@ p9_virtio_request(struct p9_client *client, struct p9_req_t *req) p9_debug(P9_DEBUG_TRANS, "9p debug: virtio request\n"); - req->status = REQ_STATUS_SENT; + WRITE_ONCE(req->status, REQ_STATUS_SENT); req_retry: spin_lock_irqsave(&chan->lock, flags); @@ -469,7 +468,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, inlen = n; } } - req->status = REQ_STATUS_SENT; + WRITE_ONCE(req->status, REQ_STATUS_SENT); req_retry_pinned: spin_lock_irqsave(&chan->lock, flags); @@ -532,9 +531,10 @@ req_retry_pinned: spin_unlock_irqrestore(&chan->lock, flags); kicked = 1; p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n"); - err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD); + err = wait_event_killable(req->wq, + READ_ONCE(req->status) >= REQ_STATUS_RCVD); // RERROR needs reply (== error string) in static data - if (req->status == REQ_STATUS_RCVD && + if (READ_ONCE(req->status) == REQ_STATUS_RCVD && unlikely(req->rc.sdata[4] == P9_RERROR)) handle_rerror(req, in_hdr_len, offs, in_pages); diff --git a/net/9p/trans_xen.c b/net/9p/trans_xen.c index b15c64128c3e..82c7005ede65 100644 --- a/net/9p/trans_xen.c +++ b/net/9p/trans_xen.c @@ -157,7 +157,7 @@ again: &masked_prod, masked_cons, XEN_9PFS_RING_SIZE(ring)); - p9_req->status = REQ_STATUS_SENT; + WRITE_ONCE(p9_req->status, REQ_STATUS_SENT); virt_wmb(); /* write ring before updating pointer */ prod += size; ring->intf->out_prod = prod; @@ -208,7 +208,17 @@ static void p9_xen_response(struct work_struct *work) continue; } - memcpy(&req->rc, &h, sizeof(h)); + if (h.size > req->rc.capacity) { + dev_warn(&priv->dev->dev, + "requested packet size too big: %d for tag %d with capacity %zd\n", + h.size, h.tag, req->rc.capacity); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); + goto recv_error; + } + + req->rc.size = h.size; + req->rc.id = h.id; + req->rc.tag = h.tag; req->rc.offset = 0; masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE(ring)); @@ -217,6 +227,7 @@ static void p9_xen_response(struct work_struct *work) masked_prod, &masked_cons, XEN_9PFS_RING_SIZE(ring)); +recv_error: virt_mb(); cons += h.size; ring->intf->in_cons = cons; @@ -294,13 +305,12 @@ static void xen_9pfs_front_free(struct xen_9pfs_front_priv *priv) kfree(priv); } -static int xen_9pfs_front_remove(struct xenbus_device *dev) +static void xen_9pfs_front_remove(struct xenbus_device *dev) { struct xen_9pfs_front_priv *priv = dev_get_drvdata(&dev->dev); dev_set_drvdata(&dev->dev, NULL); xen_9pfs_front_free(priv); - return 0; } static int xen_9pfs_front_alloc_dataring(struct xenbus_device *dev, diff --git a/net/atm/atm_sysfs.c b/net/atm/atm_sysfs.c index 0fdbdfd19474..466353b3dde4 100644 --- a/net/atm/atm_sysfs.c +++ b/net/atm/atm_sysfs.c @@ -108,9 +108,9 @@ static struct device_attribute *atm_attrs[] = { }; -static int atm_uevent(struct device *cdev, struct kobj_uevent_env *env) +static int atm_uevent(const struct device *cdev, struct kobj_uevent_env *env) { - struct atm_dev *adev; + const struct atm_dev *adev; if (!cdev) return -ENODEV; diff --git a/net/ax25/af_ax25.c b/net/ax25/af_ax25.c index 6b4c25a92377..d8da400cb4de 100644 --- a/net/ax25/af_ax25.c +++ b/net/ax25/af_ax25.c @@ -723,7 +723,7 @@ static int ax25_getsockopt(struct socket *sock, int level, int optname, if (maxlen < 1) return -EFAULT; - valptr = (void *) &val; + valptr = &val; length = min_t(unsigned int, maxlen, sizeof(int)); lock_sock(sk); @@ -785,7 +785,7 @@ static int ax25_getsockopt(struct socket *sock, int level, int optname, length = 1; } - valptr = (void *) devname; + valptr = devname; break; default: diff --git a/net/batman-adv/bat_iv_ogm.c b/net/batman-adv/bat_iv_ogm.c index 7f6a7c96ac92..114ee5da261f 100644 --- a/net/batman-adv/bat_iv_ogm.c +++ b/net/batman-adv/bat_iv_ogm.c @@ -280,7 +280,7 @@ batadv_iv_ogm_emit_send_time(const struct batadv_priv *bat_priv) unsigned int msecs; msecs = atomic_read(&bat_priv->orig_interval) - BATADV_JITTER; - msecs += prandom_u32_max(2 * BATADV_JITTER); + msecs += get_random_u32_below(2 * BATADV_JITTER); return jiffies + msecs_to_jiffies(msecs); } @@ -288,7 +288,7 @@ batadv_iv_ogm_emit_send_time(const struct batadv_priv *bat_priv) /* when do we schedule a ogm packet to be sent */ static unsigned long batadv_iv_ogm_fwd_send_time(void) { - return jiffies + msecs_to_jiffies(prandom_u32_max(BATADV_JITTER / 2)); + return jiffies + msecs_to_jiffies(get_random_u32_below(BATADV_JITTER / 2)); } /* apply hop penalty for a normal link */ diff --git a/net/batman-adv/bat_v_elp.c b/net/batman-adv/bat_v_elp.c index f1741fbfb617..f9a58fb5442e 100644 --- a/net/batman-adv/bat_v_elp.c +++ b/net/batman-adv/bat_v_elp.c @@ -51,7 +51,7 @@ static void batadv_v_elp_start_timer(struct batadv_hard_iface *hard_iface) unsigned int msecs; msecs = atomic_read(&hard_iface->bat_v.elp_interval) - BATADV_JITTER; - msecs += prandom_u32_max(2 * BATADV_JITTER); + msecs += get_random_u32_below(2 * BATADV_JITTER); queue_delayed_work(batadv_event_workqueue, &hard_iface->bat_v.elp_wq, msecs_to_jiffies(msecs)); diff --git a/net/batman-adv/bat_v_ogm.c b/net/batman-adv/bat_v_ogm.c index 033639df96d8..addfd8c4fe95 100644 --- a/net/batman-adv/bat_v_ogm.c +++ b/net/batman-adv/bat_v_ogm.c @@ -90,7 +90,7 @@ static void batadv_v_ogm_start_queue_timer(struct batadv_hard_iface *hard_iface) unsigned int msecs = BATADV_MAX_AGGREGATION_MS * 1000; /* msecs * [0.9, 1.1] */ - msecs += prandom_u32_max(msecs / 5) - (msecs / 10); + msecs += get_random_u32_below(msecs / 5) - (msecs / 10); queue_delayed_work(batadv_event_workqueue, &hard_iface->bat_v.aggr_wq, msecs_to_jiffies(msecs / 1000)); } @@ -109,7 +109,7 @@ static void batadv_v_ogm_start_timer(struct batadv_priv *bat_priv) return; msecs = atomic_read(&bat_priv->orig_interval) - BATADV_JITTER; - msecs += prandom_u32_max(2 * BATADV_JITTER); + msecs += get_random_u32_below(2 * BATADV_JITTER); queue_delayed_work(batadv_event_workqueue, &bat_priv->bat_v.ogm_wq, msecs_to_jiffies(msecs)); } diff --git a/net/batman-adv/netlink.c b/net/batman-adv/netlink.c index a5e4a4e976cf..ad5714f737be 100644 --- a/net/batman-adv/netlink.c +++ b/net/batman-adv/netlink.c @@ -1267,7 +1267,8 @@ batadv_get_vlan_from_info(struct batadv_priv *bat_priv, struct net *net, * * Return: 0 on success or negative error number in case of failure */ -static int batadv_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, +static int batadv_pre_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); @@ -1332,7 +1333,8 @@ err_put_softif: * @skb: Netlink message with request data * @info: receiver information */ -static void batadv_post_doit(const struct genl_ops *ops, struct sk_buff *skb, +static void batadv_post_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { struct batadv_hard_iface *hard_iface; diff --git a/net/batman-adv/network-coding.c b/net/batman-adv/network-coding.c index 5f4aeeb60dc4..bf29fba4dde5 100644 --- a/net/batman-adv/network-coding.c +++ b/net/batman-adv/network-coding.c @@ -1009,7 +1009,7 @@ static struct batadv_nc_path *batadv_nc_get_path(struct batadv_priv *bat_priv, static u8 batadv_nc_random_weight_tq(u8 tq) { /* randomize the estimated packet loss (max TQ - estimated TQ) */ - u8 rand_tq = prandom_u32_max(BATADV_TQ_MAX_VALUE + 1 - tq); + u8 rand_tq = get_random_u32_below(BATADV_TQ_MAX_VALUE + 1 - tq); /* convert to (randomized) estimated tq again */ return BATADV_TQ_MAX_VALUE - rand_tq; diff --git a/net/bluetooth/6lowpan.c b/net/bluetooth/6lowpan.c index 215af9b3b589..4eb1b3ced0d2 100644 --- a/net/bluetooth/6lowpan.c +++ b/net/bluetooth/6lowpan.c @@ -441,7 +441,7 @@ static int send_pkt(struct l2cap_chan *chan, struct sk_buff *skb, iv.iov_len = skb->len; memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE, &iv, 1, skb->len); + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iv, 1, skb->len); err = l2cap_chan_send(chan, &msg, skb->len); if (err > 0) { @@ -972,6 +972,7 @@ static int get_l2cap_conn(char *buf, bdaddr_t *addr, u8 *addr_type, hci_dev_lock(hdev); hcon = hci_conn_hash_lookup_le(hdev, addr, *addr_type); hci_dev_unlock(hdev); + hci_dev_put(hdev); if (!hcon) return -ENOENT; diff --git a/net/bluetooth/Kconfig b/net/bluetooth/Kconfig index ae3bdc6dfc92..da7cac0a1b71 100644 --- a/net/bluetooth/Kconfig +++ b/net/bluetooth/Kconfig @@ -78,6 +78,17 @@ config BT_LE Bluetooth Low Energy includes support low-energy physical layer available with Bluetooth version 4.0 or later. +config BT_LE_L2CAP_ECRED + bool "Bluetooth L2CAP Enhanced Credit Flow Control" + depends on BT_LE + default y + help + Bluetooth Low Energy L2CAP Enhanced Credit Flow Control available with + Bluetooth version 5.2 or later. + + This can be overridden by passing bluetooth.enable_ecred=[1|0] + on the kernel commandline. + config BT_6LOWPAN tristate "Bluetooth 6LoWPAN support" depends on BT_LE && 6LOWPAN diff --git a/net/bluetooth/a2mp.c b/net/bluetooth/a2mp.c index 1fcc482397c3..e7adb8a98cf9 100644 --- a/net/bluetooth/a2mp.c +++ b/net/bluetooth/a2mp.c @@ -56,7 +56,7 @@ static void a2mp_send(struct amp_mgr *mgr, u8 code, u8 ident, u16 len, void *dat memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE, &iv, 1, total_len); + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iv, 1, total_len); l2cap_chan_send(chan, &msg, total_len); diff --git a/net/bluetooth/af_bluetooth.c b/net/bluetooth/af_bluetooth.c index dc65974f5adb..1c3c7ff5c3c6 100644 --- a/net/bluetooth/af_bluetooth.c +++ b/net/bluetooth/af_bluetooth.c @@ -737,7 +737,7 @@ static int __init bt_init(void) err = bt_sysfs_init(); if (err < 0) - return err; + goto cleanup_led; err = sock_register(&bt_sock_family_ops); if (err) @@ -773,6 +773,8 @@ unregister_socket: sock_unregister(PF_BLUETOOTH); cleanup_sysfs: bt_sysfs_cleanup(); +cleanup_led: + bt_leds_cleanup(); return err; } diff --git a/net/bluetooth/hci_codec.c b/net/bluetooth/hci_codec.c index 38201532f58e..3cc135bb1d30 100644 --- a/net/bluetooth/hci_codec.c +++ b/net/bluetooth/hci_codec.c @@ -72,9 +72,8 @@ static void hci_read_codec_capabilities(struct hci_dev *hdev, __u8 transport, continue; } - skb = __hci_cmd_sync(hdev, HCI_OP_READ_LOCAL_CODEC_CAPS, - sizeof(*cmd), cmd, - HCI_CMD_TIMEOUT); + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODEC_CAPS, + sizeof(*cmd), cmd, 0, HCI_CMD_TIMEOUT, NULL); if (IS_ERR(skb)) { bt_dev_err(hdev, "Failed to read codec capabilities (%ld)", PTR_ERR(skb)); @@ -127,8 +126,8 @@ void hci_read_supported_codecs(struct hci_dev *hdev) struct hci_op_read_local_codec_caps caps; __u8 i; - skb = __hci_cmd_sync(hdev, HCI_OP_READ_LOCAL_CODECS, 0, NULL, - HCI_CMD_TIMEOUT); + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODECS, 0, NULL, + 0, HCI_CMD_TIMEOUT, NULL); if (IS_ERR(skb)) { bt_dev_err(hdev, "Failed to read local supported codecs (%ld)", @@ -158,7 +157,8 @@ void hci_read_supported_codecs(struct hci_dev *hdev) for (i = 0; i < std_codecs->num; i++) { caps.id = std_codecs->codec[i]; caps.direction = 0x00; - hci_read_codec_capabilities(hdev, LOCAL_CODEC_ACL_MASK, &caps); + hci_read_codec_capabilities(hdev, + LOCAL_CODEC_ACL_MASK | LOCAL_CODEC_SCO_MASK, &caps); } skb_pull(skb, flex_array_size(std_codecs, codec, std_codecs->num) @@ -178,7 +178,8 @@ void hci_read_supported_codecs(struct hci_dev *hdev) caps.cid = vnd_codecs->codec[i].cid; caps.vid = vnd_codecs->codec[i].vid; caps.direction = 0x00; - hci_read_codec_capabilities(hdev, LOCAL_CODEC_ACL_MASK, &caps); + hci_read_codec_capabilities(hdev, + LOCAL_CODEC_ACL_MASK | LOCAL_CODEC_SCO_MASK, &caps); } error: @@ -194,8 +195,8 @@ void hci_read_supported_codecs_v2(struct hci_dev *hdev) struct hci_op_read_local_codec_caps caps; __u8 i; - skb = __hci_cmd_sync(hdev, HCI_OP_READ_LOCAL_CODECS_V2, 0, NULL, - HCI_CMD_TIMEOUT); + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODECS_V2, 0, NULL, + 0, HCI_CMD_TIMEOUT, NULL); if (IS_ERR(skb)) { bt_dev_err(hdev, "Failed to read local supported codecs (%ld)", diff --git a/net/bluetooth/hci_conn.c b/net/bluetooth/hci_conn.c index a6c12863a253..acf563fbdfd9 100644 --- a/net/bluetooth/hci_conn.c +++ b/net/bluetooth/hci_conn.c @@ -821,19 +821,23 @@ static void terminate_big_destroy(struct hci_dev *hdev, void *data, int err) static int hci_le_terminate_big(struct hci_dev *hdev, u8 big, u8 bis) { struct iso_list_data *d; + int ret; bt_dev_dbg(hdev, "big 0x%2.2x bis 0x%2.2x", big, bis); - d = kmalloc(sizeof(*d), GFP_KERNEL); + d = kzalloc(sizeof(*d), GFP_KERNEL); if (!d) return -ENOMEM; - memset(d, 0, sizeof(*d)); d->big = big; d->bis = bis; - return hci_cmd_sync_queue(hdev, terminate_big_sync, d, - terminate_big_destroy); + ret = hci_cmd_sync_queue(hdev, terminate_big_sync, d, + terminate_big_destroy); + if (ret) + kfree(d); + + return ret; } static int big_terminate_sync(struct hci_dev *hdev, void *data) @@ -858,19 +862,23 @@ static int big_terminate_sync(struct hci_dev *hdev, void *data) static int hci_le_big_terminate(struct hci_dev *hdev, u8 big, u16 sync_handle) { struct iso_list_data *d; + int ret; bt_dev_dbg(hdev, "big 0x%2.2x sync_handle 0x%4.4x", big, sync_handle); - d = kmalloc(sizeof(*d), GFP_KERNEL); + d = kzalloc(sizeof(*d), GFP_KERNEL); if (!d) return -ENOMEM; - memset(d, 0, sizeof(*d)); d->big = big; d->sync_handle = sync_handle; - return hci_cmd_sync_queue(hdev, big_terminate_sync, d, - terminate_big_destroy); + ret = hci_cmd_sync_queue(hdev, big_terminate_sync, d, + terminate_big_destroy); + if (ret) + kfree(d); + + return ret; } /* Cleanup BIS connection @@ -1881,7 +1889,7 @@ static int hci_create_cis_sync(struct hci_dev *hdev, void *data) continue; /* Check if all CIS(s) belonging to a CIG are ready */ - if (conn->link->state != BT_CONNECTED || + if (!conn->link || conn->link->state != BT_CONNECTED || conn->state != BT_CONNECT) { cmd.cp.num_cis = 0; break; @@ -2046,19 +2054,12 @@ int hci_pa_create_sync(struct hci_dev *hdev, bdaddr_t *dst, __u8 dst_type, if (hci_dev_test_and_set_flag(hdev, HCI_PA_SYNC)) return -EBUSY; - cp = kmalloc(sizeof(*cp), GFP_KERNEL); + cp = kzalloc(sizeof(*cp), GFP_KERNEL); if (!cp) { hci_dev_clear_flag(hdev, HCI_PA_SYNC); return -ENOMEM; } - /* Convert from ISO socket address type to HCI address type */ - if (dst_type == BDADDR_LE_PUBLIC) - dst_type = ADDR_LE_DEV_PUBLIC; - else - dst_type = ADDR_LE_DEV_RANDOM; - - memset(cp, 0, sizeof(*cp)); cp->sid = sid; cp->addr_type = dst_type; bacpy(&cp->addr, dst); diff --git a/net/bluetooth/hci_core.c b/net/bluetooth/hci_core.c index 0540555b3704..b65c3aabcd53 100644 --- a/net/bluetooth/hci_core.c +++ b/net/bluetooth/hci_core.c @@ -2660,7 +2660,7 @@ int hci_register_dev(struct hci_dev *hdev) error = hci_register_suspend_notifier(hdev); if (error) - goto err_wqueue; + BT_WARN("register suspend notifier failed error:%d\n", error); queue_work(hdev->req_workqueue, &hdev->power_on); @@ -2764,7 +2764,8 @@ int hci_register_suspend_notifier(struct hci_dev *hdev) { int ret = 0; - if (!test_bit(HCI_QUIRK_NO_SUSPEND_NOTIFIER, &hdev->quirks)) { + if (!hdev->suspend_notifier.notifier_call && + !test_bit(HCI_QUIRK_NO_SUSPEND_NOTIFIER, &hdev->quirks)) { hdev->suspend_notifier.notifier_call = hci_suspend_notifier; ret = register_pm_notifier(&hdev->suspend_notifier); } @@ -2776,8 +2777,11 @@ int hci_unregister_suspend_notifier(struct hci_dev *hdev) { int ret = 0; - if (!test_bit(HCI_QUIRK_NO_SUSPEND_NOTIFIER, &hdev->quirks)) + if (hdev->suspend_notifier.notifier_call) { ret = unregister_pm_notifier(&hdev->suspend_notifier); + if (!ret) + hdev->suspend_notifier.notifier_call = NULL; + } return ret; } @@ -3981,7 +3985,7 @@ void hci_req_cmd_complete(struct hci_dev *hdev, u16 opcode, u8 status, *req_complete_skb = bt_cb(skb)->hci.req_complete_skb; else *req_complete = bt_cb(skb)->hci.req_complete; - kfree_skb(skb); + dev_kfree_skb_irq(skb); } spin_unlock_irqrestore(&hdev->cmd_q.lock, flags); } diff --git a/net/bluetooth/hci_debugfs.c b/net/bluetooth/hci_debugfs.c index 3f401ec5bb0c..b7f682922a16 100644 --- a/net/bluetooth/hci_debugfs.c +++ b/net/bluetooth/hci_debugfs.c @@ -757,7 +757,7 @@ static ssize_t force_static_address_write(struct file *file, bool enable; int err; - if (test_bit(HCI_UP, &hdev->flags)) + if (hdev_is_powered(hdev)) return -EBUSY; err = kstrtobool_from_user(user_buf, count, &enable); diff --git a/net/bluetooth/hci_event.c b/net/bluetooth/hci_event.c index faca701bce2a..ad92a4be5851 100644 --- a/net/bluetooth/hci_event.c +++ b/net/bluetooth/hci_event.c @@ -801,9 +801,6 @@ static u8 hci_cc_write_auth_payload_timeout(struct hci_dev *hdev, void *data, bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); - if (rp->status) - return rp->status; - sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_AUTH_PAYLOAD_TO); if (!sent) return rp->status; @@ -811,9 +808,17 @@ static u8 hci_cc_write_auth_payload_timeout(struct hci_dev *hdev, void *data, hci_dev_lock(hdev); conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); - if (conn) + if (!conn) { + rp->status = 0xff; + goto unlock; + } + + if (!rp->status) conn->auth_payload_timeout = get_unaligned_le16(sent + 2); + hci_encrypt_cfm(conn, 0); + +unlock: hci_dev_unlock(hdev); return rp->status; @@ -3680,8 +3685,13 @@ static void hci_encrypt_change_evt(struct hci_dev *hdev, void *data, cp.handle = cpu_to_le16(conn->handle); cp.timeout = cpu_to_le16(hdev->auth_payload_timeout); - hci_send_cmd(conn->hdev, HCI_OP_WRITE_AUTH_PAYLOAD_TO, - sizeof(cp), &cp); + if (hci_send_cmd(conn->hdev, HCI_OP_WRITE_AUTH_PAYLOAD_TO, + sizeof(cp), &cp)) { + bt_dev_err(hdev, "write auth payload timeout failed"); + goto notify; + } + + goto unlock; } notify: @@ -3838,8 +3848,11 @@ static u8 hci_cc_le_set_cig_params(struct hci_dev *hdev, void *data, conn->handle, conn->link); /* Create CIS if LE is already connected */ - if (conn->link && conn->link->state == BT_CONNECTED) + if (conn->link && conn->link->state == BT_CONNECTED) { + rcu_read_unlock(); hci_le_create_cis(conn->link); + rcu_read_lock(); + } if (i == rp->num_handles) break; @@ -6494,7 +6507,7 @@ static void hci_le_ext_adv_report_evt(struct hci_dev *hdev, void *data, info->length)) break; - evt_type = __le16_to_cpu(info->type); + evt_type = __le16_to_cpu(info->type) & LE_EXT_ADV_EVT_TYPE_MASK; legacy_evt_type = ext_evt_type_to_legacy(hdev, evt_type); if (legacy_evt_type != LE_ADV_INVALID) { process_adv_report(hdev, legacy_evt_type, &info->bdaddr, diff --git a/net/bluetooth/hci_request.c b/net/bluetooth/hci_request.c index 5a0296a4352e..f7e006a36382 100644 --- a/net/bluetooth/hci_request.c +++ b/net/bluetooth/hci_request.c @@ -269,7 +269,7 @@ void hci_req_add_ev(struct hci_request *req, u16 opcode, u32 plen, void hci_req_add(struct hci_request *req, u16 opcode, u32 plen, const void *param) { - bt_dev_err(req->hdev, "HCI_REQ-0x%4.4x", opcode); + bt_dev_dbg(req->hdev, "HCI_REQ-0x%4.4x", opcode); hci_req_add_ev(req, opcode, plen, param, 0); } diff --git a/net/bluetooth/hci_sync.c b/net/bluetooth/hci_sync.c index 76c3107c9f91..117eedb6f709 100644 --- a/net/bluetooth/hci_sync.c +++ b/net/bluetooth/hci_sync.c @@ -12,6 +12,7 @@ #include <net/bluetooth/mgmt.h> #include "hci_request.h" +#include "hci_codec.h" #include "hci_debugfs.h" #include "smp.h" #include "eir.h" @@ -3054,6 +3055,7 @@ int hci_update_name_sync(struct hci_dev *hdev) * Enable Authentication * lmp_bredr_capable(Set Fast Connectable -> Set Scan Type -> Set Class -> * Set Name -> Set EIR) + * HCI_FORCE_STATIC_ADDR | BDADDR_ANY && !HCI_BREDR_ENABLED (Set Static Address) */ int hci_powered_update_sync(struct hci_dev *hdev) { @@ -3093,6 +3095,23 @@ int hci_powered_update_sync(struct hci_dev *hdev) hci_update_eir_sync(hdev); } + /* If forcing static address is in use or there is no public + * address use the static address as random address (but skip + * the HCI command if the current random address is already the + * static one. + * + * In case BR/EDR has been disabled on a dual-mode controller + * and a static address has been configured, then use that + * address instead of the public BR/EDR address. + */ + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + (!bacmp(&hdev->bdaddr, BDADDR_ANY) && + !hci_dev_test_flag(hdev, HCI_BREDR_ENABLED))) { + if (bacmp(&hdev->static_addr, BDADDR_ANY)) + return hci_set_random_addr_sync(hdev, + &hdev->static_addr); + } + return 0; } @@ -3553,7 +3572,7 @@ static const struct hci_init_stage hci_init2[] = { static int hci_le_read_buffer_size_sync(struct hci_dev *hdev) { /* Use Read LE Buffer Size V2 if supported */ - if (hdev->commands[41] & 0x20) + if (iso_capable(hdev) && hdev->commands[41] & 0x20) return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_BUFFER_SIZE_V2, 0, NULL, HCI_CMD_TIMEOUT); @@ -3578,10 +3597,10 @@ static int hci_le_read_supported_states_sync(struct hci_dev *hdev) /* LE Controller init stage 2 command sequence */ static const struct hci_init_stage le_init2[] = { - /* HCI_OP_LE_READ_BUFFER_SIZE */ - HCI_INIT(hci_le_read_buffer_size_sync), /* HCI_OP_LE_READ_LOCAL_FEATURES */ HCI_INIT(hci_le_read_local_features_sync), + /* HCI_OP_LE_READ_BUFFER_SIZE */ + HCI_INIT(hci_le_read_buffer_size_sync), /* HCI_OP_LE_READ_SUPPORTED_STATES */ HCI_INIT(hci_le_read_supported_states_sync), {} @@ -3780,7 +3799,8 @@ static int hci_read_page_scan_activity_sync(struct hci_dev *hdev) static int hci_read_def_err_data_reporting_sync(struct hci_dev *hdev) { if (!(hdev->commands[18] & 0x04) || - !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING)) + !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING) || + test_bit(HCI_QUIRK_BROKEN_ERR_DATA_REPORTING, &hdev->quirks)) return 0; return __hci_cmd_sync_status(hdev, HCI_OP_READ_DEF_ERR_DATA_REPORTING, @@ -4238,11 +4258,12 @@ static int hci_set_event_mask_page_2_sync(struct hci_dev *hdev) /* Read local codec list if the HCI command is supported */ static int hci_read_local_codecs_sync(struct hci_dev *hdev) { - if (!(hdev->commands[29] & 0x20)) - return 0; + if (hdev->commands[45] & 0x04) + hci_read_supported_codecs_v2(hdev); + else if (hdev->commands[29] & 0x20) + hci_read_supported_codecs(hdev); - return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_CODECS, 0, NULL, - HCI_CMD_TIMEOUT); + return 0; } /* Read local pairing options if the HCI command is supported */ @@ -4258,7 +4279,7 @@ static int hci_read_local_pairing_opts_sync(struct hci_dev *hdev) /* Get MWS transport configuration if the HCI command is supported */ static int hci_get_mws_transport_config_sync(struct hci_dev *hdev) { - if (!(hdev->commands[30] & 0x08)) + if (!mws_transport_config_capable(hdev)) return 0; return __hci_cmd_sync_status(hdev, HCI_OP_GET_MWS_TRANSPORT_CONFIG, @@ -4298,7 +4319,8 @@ static int hci_set_err_data_report_sync(struct hci_dev *hdev) bool enabled = hci_dev_test_flag(hdev, HCI_WIDEBAND_SPEECH_ENABLED); if (!(hdev->commands[18] & 0x08) || - !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING)) + !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING) || + test_bit(HCI_QUIRK_BROKEN_ERR_DATA_REPORTING, &hdev->quirks)) return 0; if (enabled == hdev->err_data_reporting) @@ -4457,6 +4479,9 @@ static const struct { HCI_QUIRK_BROKEN(STORED_LINK_KEY, "HCI Delete Stored Link Key command is advertised, " "but not supported."), + HCI_QUIRK_BROKEN(ERR_DATA_REPORTING, + "HCI Read Default Erroneous Data Reporting command is " + "advertised, but not supported."), HCI_QUIRK_BROKEN(READ_TRANSMIT_POWER, "HCI Read Transmit Power Level command is advertised, " "but not supported."), @@ -4696,6 +4721,7 @@ int hci_dev_open_sync(struct hci_dev *hdev) hdev->flush(hdev); if (hdev->sent_cmd) { + cancel_delayed_work_sync(&hdev->cmd_timer); kfree_skb(hdev->sent_cmd); hdev->sent_cmd = NULL; } @@ -6161,20 +6187,13 @@ int hci_get_random_address(struct hci_dev *hdev, bool require_privacy, static int _update_adv_data_sync(struct hci_dev *hdev, void *data) { - u8 instance = *(u8 *)data; - - kfree(data); + u8 instance = PTR_ERR(data); return hci_update_adv_data_sync(hdev, instance); } int hci_update_adv_data(struct hci_dev *hdev, u8 instance) { - u8 *inst_ptr = kmalloc(1, GFP_KERNEL); - - if (!inst_ptr) - return -ENOMEM; - - *inst_ptr = instance; - return hci_cmd_sync_queue(hdev, _update_adv_data_sync, inst_ptr, NULL); + return hci_cmd_sync_queue(hdev, _update_adv_data_sync, + ERR_PTR(instance), NULL); } diff --git a/net/bluetooth/hidp/core.c b/net/bluetooth/hidp/core.c index cc20e706c639..bed1a7b9205c 100644 --- a/net/bluetooth/hidp/core.c +++ b/net/bluetooth/hidp/core.c @@ -739,7 +739,7 @@ static void hidp_stop(struct hid_device *hid) hid->claimed = 0; } -struct hid_ll_driver hidp_hid_driver = { +static const struct hid_ll_driver hidp_hid_driver = { .parse = hidp_parse, .start = hidp_start, .stop = hidp_stop, @@ -748,7 +748,6 @@ struct hid_ll_driver hidp_hid_driver = { .raw_request = hidp_raw_request, .output_report = hidp_output_report, }; -EXPORT_SYMBOL_GPL(hidp_hid_driver); /* This function sets up the hid device. It does not add it to the HID system. That is done in hidp_add_connection(). */ diff --git a/net/bluetooth/iso.c b/net/bluetooth/iso.c index f825857db6d0..24444b502e58 100644 --- a/net/bluetooth/iso.c +++ b/net/bluetooth/iso.c @@ -261,36 +261,42 @@ static int iso_connect_bis(struct sock *sk) if (!bis_capable(hdev)) { err = -EOPNOTSUPP; - goto done; + goto unlock; } /* Fail if out PHYs are marked as disabled */ if (!iso_pi(sk)->qos.out.phy) { err = -EINVAL; - goto done; + goto unlock; } - hcon = hci_connect_bis(hdev, &iso_pi(sk)->dst, iso_pi(sk)->dst_type, + hcon = hci_connect_bis(hdev, &iso_pi(sk)->dst, + le_addr_type(iso_pi(sk)->dst_type), &iso_pi(sk)->qos, iso_pi(sk)->base_len, iso_pi(sk)->base); if (IS_ERR(hcon)) { err = PTR_ERR(hcon); - goto done; + goto unlock; } conn = iso_conn_add(hcon); if (!conn) { hci_conn_drop(hcon); err = -ENOMEM; - goto done; + goto unlock; } - /* Update source addr of the socket */ - bacpy(&iso_pi(sk)->src, &hcon->src); + hci_dev_unlock(hdev); + hci_dev_put(hdev); err = iso_chan_add(conn, sk, NULL); if (err) - goto done; + return err; + + lock_sock(sk); + + /* Update source addr of the socket */ + bacpy(&iso_pi(sk)->src, &hcon->src); if (hcon->state == BT_CONNECTED) { iso_sock_clear_timer(sk); @@ -300,7 +306,10 @@ static int iso_connect_bis(struct sock *sk) iso_sock_set_timer(sk, sk->sk_sndtimeo); } -done: + release_sock(sk); + return err; + +unlock: hci_dev_unlock(hdev); hci_dev_put(hdev); return err; @@ -324,13 +333,13 @@ static int iso_connect_cis(struct sock *sk) if (!cis_central_capable(hdev)) { err = -EOPNOTSUPP; - goto done; + goto unlock; } /* Fail if either PHYs are marked as disabled */ if (!iso_pi(sk)->qos.in.phy && !iso_pi(sk)->qos.out.phy) { err = -EINVAL; - goto done; + goto unlock; } /* Just bind if DEFER_SETUP has been set */ @@ -340,7 +349,7 @@ static int iso_connect_cis(struct sock *sk) &iso_pi(sk)->qos); if (IS_ERR(hcon)) { err = PTR_ERR(hcon); - goto done; + goto unlock; } } else { hcon = hci_connect_cis(hdev, &iso_pi(sk)->dst, @@ -348,7 +357,7 @@ static int iso_connect_cis(struct sock *sk) &iso_pi(sk)->qos); if (IS_ERR(hcon)) { err = PTR_ERR(hcon); - goto done; + goto unlock; } } @@ -356,15 +365,20 @@ static int iso_connect_cis(struct sock *sk) if (!conn) { hci_conn_drop(hcon); err = -ENOMEM; - goto done; + goto unlock; } - /* Update source addr of the socket */ - bacpy(&iso_pi(sk)->src, &hcon->src); + hci_dev_unlock(hdev); + hci_dev_put(hdev); err = iso_chan_add(conn, sk, NULL); if (err) - goto done; + return err; + + lock_sock(sk); + + /* Update source addr of the socket */ + bacpy(&iso_pi(sk)->src, &hcon->src); if (hcon->state == BT_CONNECTED) { iso_sock_clear_timer(sk); @@ -377,7 +391,10 @@ static int iso_connect_cis(struct sock *sk) iso_sock_set_timer(sk, sk->sk_sndtimeo); } -done: + release_sock(sk); + return err; + +unlock: hci_dev_unlock(hdev); hci_dev_put(hdev); return err; @@ -831,20 +848,23 @@ static int iso_sock_connect(struct socket *sock, struct sockaddr *addr, bacpy(&iso_pi(sk)->dst, &sa->iso_bdaddr); iso_pi(sk)->dst_type = sa->iso_bdaddr_type; + release_sock(sk); + if (bacmp(&iso_pi(sk)->dst, BDADDR_ANY)) err = iso_connect_cis(sk); else err = iso_connect_bis(sk); if (err) - goto done; + return err; + + lock_sock(sk); if (!test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { err = bt_sock_wait_state(sk, BT_CONNECTED, sock_sndtimeo(sk, flags & O_NONBLOCK)); } -done: release_sock(sk); return err; } @@ -873,12 +893,11 @@ static int iso_listen_bis(struct sock *sk) if (!hdev) return -EHOSTUNREACH; - hci_dev_lock(hdev); - - err = hci_pa_create_sync(hdev, &iso_pi(sk)->dst, iso_pi(sk)->dst_type, + err = hci_pa_create_sync(hdev, &iso_pi(sk)->dst, + le_addr_type(iso_pi(sk)->dst_type), iso_pi(sk)->bc_sid); - hci_dev_unlock(hdev); + hci_dev_put(hdev); return err; } @@ -1098,28 +1117,22 @@ static int iso_sock_recvmsg(struct socket *sock, struct msghdr *msg, { struct sock *sk = sock->sk; struct iso_pinfo *pi = iso_pi(sk); - int err; BT_DBG("sk %p", sk); - lock_sock(sk); - if (test_and_clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { switch (sk->sk_state) { case BT_CONNECT2: + lock_sock(sk); iso_conn_defer_accept(pi->conn->hcon); sk->sk_state = BT_CONFIG; release_sock(sk); return 0; case BT_CONNECT: - err = iso_connect_cis(sk); - release_sock(sk); - return err; + return iso_connect_cis(sk); } } - release_sock(sk); - return bt_sock_recvmsg(sock, msg, len, flags); } @@ -1414,33 +1427,29 @@ static void iso_conn_ready(struct iso_conn *conn) struct sock *parent; struct sock *sk = conn->sk; struct hci_ev_le_big_sync_estabilished *ev; + struct hci_conn *hcon; BT_DBG("conn %p", conn); if (sk) { iso_sock_ready(conn->sk); } else { - iso_conn_lock(conn); - - if (!conn->hcon) { - iso_conn_unlock(conn); + hcon = conn->hcon; + if (!hcon) return; - } - ev = hci_recv_event_data(conn->hcon->hdev, + ev = hci_recv_event_data(hcon->hdev, HCI_EVT_LE_BIG_SYNC_ESTABILISHED); if (ev) - parent = iso_get_sock_listen(&conn->hcon->src, - &conn->hcon->dst, + parent = iso_get_sock_listen(&hcon->src, + &hcon->dst, iso_match_big, ev); else - parent = iso_get_sock_listen(&conn->hcon->src, + parent = iso_get_sock_listen(&hcon->src, BDADDR_ANY, NULL, NULL); - if (!parent) { - iso_conn_unlock(conn); + if (!parent) return; - } lock_sock(parent); @@ -1448,30 +1457,29 @@ static void iso_conn_ready(struct iso_conn *conn) BTPROTO_ISO, GFP_ATOMIC, 0); if (!sk) { release_sock(parent); - iso_conn_unlock(conn); return; } iso_sock_init(sk, parent); - bacpy(&iso_pi(sk)->src, &conn->hcon->src); - iso_pi(sk)->src_type = conn->hcon->src_type; + bacpy(&iso_pi(sk)->src, &hcon->src); + iso_pi(sk)->src_type = hcon->src_type; /* If hcon has no destination address (BDADDR_ANY) it means it * was created by HCI_EV_LE_BIG_SYNC_ESTABILISHED so we need to * initialize using the parent socket destination address. */ - if (!bacmp(&conn->hcon->dst, BDADDR_ANY)) { - bacpy(&conn->hcon->dst, &iso_pi(parent)->dst); - conn->hcon->dst_type = iso_pi(parent)->dst_type; - conn->hcon->sync_handle = iso_pi(parent)->sync_handle; + if (!bacmp(&hcon->dst, BDADDR_ANY)) { + bacpy(&hcon->dst, &iso_pi(parent)->dst); + hcon->dst_type = iso_pi(parent)->dst_type; + hcon->sync_handle = iso_pi(parent)->sync_handle; } - bacpy(&iso_pi(sk)->dst, &conn->hcon->dst); - iso_pi(sk)->dst_type = conn->hcon->dst_type; + bacpy(&iso_pi(sk)->dst, &hcon->dst); + iso_pi(sk)->dst_type = hcon->dst_type; - hci_conn_hold(conn->hcon); - __iso_chan_add(conn, sk, parent); + hci_conn_hold(hcon); + iso_chan_add(conn, sk, parent); if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags)) sk->sk_state = BT_CONNECT2; @@ -1482,8 +1490,6 @@ static void iso_conn_ready(struct iso_conn *conn) parent->sk_data_ready(parent); release_sock(parent); - - iso_conn_unlock(conn); } } diff --git a/net/bluetooth/l2cap_core.c b/net/bluetooth/l2cap_core.c index 9c24947aa41e..a3e0dc6a6e73 100644 --- a/net/bluetooth/l2cap_core.c +++ b/net/bluetooth/l2cap_core.c @@ -45,7 +45,7 @@ #define LE_FLOWCTL_MAX_CREDITS 65535 bool disable_ertm; -bool enable_ecred; +bool enable_ecred = IS_ENABLED(CONFIG_BT_LE_L2CAP_ECRED); static u32 l2cap_feat_mask = L2CAP_FEAT_FIXED_CHAN | L2CAP_FEAT_UCD; @@ -4453,7 +4453,8 @@ static inline int l2cap_config_req(struct l2cap_conn *conn, chan->ident = cmd->ident; l2cap_send_cmd(conn, cmd->ident, L2CAP_CONF_RSP, len, rsp); - chan->num_conf_rsp++; + if (chan->num_conf_rsp < L2CAP_CONF_MAX_CONF_RSP) + chan->num_conf_rsp++; /* Reset config buffer. */ chan->conf_len = 0; diff --git a/net/bluetooth/lib.c b/net/bluetooth/lib.c index 469a0c95b6e8..53a796ac078c 100644 --- a/net/bluetooth/lib.c +++ b/net/bluetooth/lib.c @@ -170,7 +170,7 @@ __u8 bt_status(int err) case -EMLINK: return 0x09; - case EALREADY: + case -EALREADY: return 0x0b; case -EBUSY: @@ -191,7 +191,7 @@ __u8 bt_status(int err) case -ECONNABORTED: return 0x16; - case ELOOP: + case -ELOOP: return 0x17; case -EPROTONOSUPPORT: diff --git a/net/bluetooth/mgmt.c b/net/bluetooth/mgmt.c index a92e7e485feb..d2ea8e19aa1b 100644 --- a/net/bluetooth/mgmt.c +++ b/net/bluetooth/mgmt.c @@ -7373,9 +7373,8 @@ static int get_conn_info(struct sock *sk, struct hci_dev *hdev, void *data, /* To avoid client trying to guess when to poll again for information we * calculate conn info age as random value between min/max set in hdev. */ - conn_info_age = hdev->conn_info_min_age + - prandom_u32_max(hdev->conn_info_max_age - - hdev->conn_info_min_age); + conn_info_age = get_random_u32_inclusive(hdev->conn_info_min_age, + hdev->conn_info_max_age - 1); /* Query controller to refresh cached values if they are too old or were * never read. @@ -8859,7 +8858,7 @@ static int add_ext_adv_params(struct sock *sk, struct hci_dev *hdev, * extra parameters we don't know about will be ignored in this request. */ if (data_len < MGMT_ADD_EXT_ADV_PARAMS_MIN_SIZE) - return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, MGMT_STATUS_INVALID_PARAMS); flags = __le32_to_cpu(cp->flags); diff --git a/net/bluetooth/mgmt_util.h b/net/bluetooth/mgmt_util.h index 6a8b7e84293d..bdf978605d5a 100644 --- a/net/bluetooth/mgmt_util.h +++ b/net/bluetooth/mgmt_util.h @@ -27,7 +27,7 @@ struct mgmt_mesh_tx { struct sock *sk; u8 handle; u8 instance; - u8 param[sizeof(struct mgmt_cp_mesh_send) + 29]; + u8 param[sizeof(struct mgmt_cp_mesh_send) + 31]; }; struct mgmt_pending_cmd { diff --git a/net/bluetooth/rfcomm/core.c b/net/bluetooth/rfcomm/core.c index 7324764384b6..8d6fce9005bd 100644 --- a/net/bluetooth/rfcomm/core.c +++ b/net/bluetooth/rfcomm/core.c @@ -590,7 +590,7 @@ int rfcomm_dlc_send(struct rfcomm_dlc *d, struct sk_buff *skb) ret = rfcomm_dlc_send_frag(d, frag); if (ret < 0) { - kfree_skb(frag); + dev_kfree_skb_irq(frag); goto unlock; } diff --git a/net/bluetooth/rfcomm/sock.c b/net/bluetooth/rfcomm/sock.c index 21e24da4847f..4397e14ff560 100644 --- a/net/bluetooth/rfcomm/sock.c +++ b/net/bluetooth/rfcomm/sock.c @@ -391,6 +391,7 @@ static int rfcomm_sock_connect(struct socket *sock, struct sockaddr *addr, int a addr->sa_family != AF_BLUETOOTH) return -EINVAL; + sock_hold(sk); lock_sock(sk); if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND) { @@ -410,14 +411,18 @@ static int rfcomm_sock_connect(struct socket *sock, struct sockaddr *addr, int a d->sec_level = rfcomm_pi(sk)->sec_level; d->role_switch = rfcomm_pi(sk)->role_switch; + /* Drop sock lock to avoid potential deadlock with the RFCOMM lock */ + release_sock(sk); err = rfcomm_dlc_open(d, &rfcomm_pi(sk)->src, &sa->rc_bdaddr, sa->rc_channel); - if (!err) + lock_sock(sk); + if (!err && !sock_flag(sk, SOCK_ZAPPED)) err = bt_sock_wait_state(sk, BT_CONNECTED, sock_sndtimeo(sk, flags & O_NONBLOCK)); done: release_sock(sk); + sock_put(sk); return err; } diff --git a/net/bluetooth/smp.c b/net/bluetooth/smp.c index 11f853d0500f..70663229b3cc 100644 --- a/net/bluetooth/smp.c +++ b/net/bluetooth/smp.c @@ -605,7 +605,7 @@ static void smp_send_cmd(struct l2cap_conn *conn, u8 code, u16 len, void *data) memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE, iv, 2, 1 + len); + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, iv, 2, 1 + len); l2cap_chan_send(chan, &msg, 1 + len); diff --git a/net/bpf/bpf_dummy_struct_ops.c b/net/bpf/bpf_dummy_struct_ops.c index e78dadfc5829..1ac4467928a9 100644 --- a/net/bpf/bpf_dummy_struct_ops.c +++ b/net/bpf/bpf_dummy_struct_ops.c @@ -124,8 +124,7 @@ int bpf_struct_ops_test_run(struct bpf_prog *prog, const union bpf_attr *kattr, if (err < 0) goto out; - set_memory_ro((long)image, 1); - set_memory_x((long)image, 1); + set_memory_rox((long)image, 1); prog_ret = dummy_ops_call_op(image, args); err = dummy_ops_copy_args(args); @@ -156,29 +155,29 @@ static bool bpf_dummy_ops_is_valid_access(int off, int size, } static int bpf_dummy_ops_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, u32 *next_btf_id, enum bpf_type_flag *flag) { const struct btf_type *state; + const struct btf_type *t; s32 type_id; int err; - type_id = btf_find_by_name_kind(btf, "bpf_dummy_ops_state", + type_id = btf_find_by_name_kind(reg->btf, "bpf_dummy_ops_state", BTF_KIND_STRUCT); if (type_id < 0) return -EINVAL; - state = btf_type_by_id(btf, type_id); + t = btf_type_by_id(reg->btf, reg->btf_id); + state = btf_type_by_id(reg->btf, type_id); if (t != state) { bpf_log(log, "only access to bpf_dummy_ops_state is supported\n"); return -EACCES; } - err = btf_struct_access(log, btf, t, off, size, atype, next_btf_id, - flag); + err = btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); if (err < 0) return err; diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index 5079fb089b5f..2723623429ac 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -781,6 +781,7 @@ static void *bpf_test_init(const union bpf_attr *kattr, u32 user_size, if (user_size > size) return ERR_PTR(-EMSGSIZE); + size = SKB_DATA_ALIGN(size); data = kzalloc(size + headroom + tailroom, GFP_USER); if (!data) return ERR_PTR(-ENOMEM); @@ -986,9 +987,6 @@ static int convert___skb_to_skb(struct sk_buff *skb, struct __sk_buff *__skb) { struct qdisc_skb_cb *cb = (struct qdisc_skb_cb *)skb->cb; - if (!skb->len) - return -EINVAL; - if (!__skb) return 0; @@ -1137,7 +1135,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, } sock_init_data(NULL, sk); - skb = build_skb(data, 0); + skb = slab_build_skb(data); if (!skb) { kfree(data); kfree(ctx); diff --git a/net/bridge/br.c b/net/bridge/br.c index 96e91d69a9a8..4f5098d33a46 100644 --- a/net/bridge/br.c +++ b/net/bridge/br.c @@ -166,13 +166,14 @@ static int br_switchdev_event(struct notifier_block *unused, case SWITCHDEV_FDB_ADD_TO_BRIDGE: fdb_info = ptr; err = br_fdb_external_learn_add(br, p, fdb_info->addr, - fdb_info->vid, false); + fdb_info->vid, + fdb_info->locked, false); if (err) { err = notifier_from_errno(err); break; } br_fdb_offloaded_set(br, p, fdb_info->addr, - fdb_info->vid, true); + fdb_info->vid, fdb_info->offloaded); break; case SWITCHDEV_FDB_DEL_TO_BRIDGE: fdb_info = ptr; diff --git a/net/bridge/br_fdb.c b/net/bridge/br_fdb.c index e7f4fccb6adb..e69a872bfc1d 100644 --- a/net/bridge/br_fdb.c +++ b/net/bridge/br_fdb.c @@ -105,6 +105,7 @@ static int fdb_fill_info(struct sk_buff *skb, const struct net_bridge *br, struct nda_cacheinfo ci; struct nlmsghdr *nlh; struct ndmsg *ndm; + u32 ext_flags = 0; nlh = nlmsg_put(skb, portid, seq, type, sizeof(*ndm), flags); if (nlh == NULL) @@ -125,11 +126,16 @@ static int fdb_fill_info(struct sk_buff *skb, const struct net_bridge *br, ndm->ndm_flags |= NTF_EXT_LEARNED; if (test_bit(BR_FDB_STICKY, &fdb->flags)) ndm->ndm_flags |= NTF_STICKY; + if (test_bit(BR_FDB_LOCKED, &fdb->flags)) + ext_flags |= NTF_EXT_LOCKED; if (nla_put(skb, NDA_LLADDR, ETH_ALEN, &fdb->key.addr)) goto nla_put_failure; if (nla_put_u32(skb, NDA_MASTER, br->dev->ifindex)) goto nla_put_failure; + if (nla_put_u32(skb, NDA_FLAGS_EXT, ext_flags)) + goto nla_put_failure; + ci.ndm_used = jiffies_to_clock_t(now - fdb->used); ci.ndm_confirmed = 0; ci.ndm_updated = jiffies_to_clock_t(now - fdb->updated); @@ -171,6 +177,7 @@ static inline size_t fdb_nlmsg_size(void) return NLMSG_ALIGN(sizeof(struct ndmsg)) + nla_total_size(ETH_ALEN) /* NDA_LLADDR */ + nla_total_size(sizeof(u32)) /* NDA_MASTER */ + + nla_total_size(sizeof(u32)) /* NDA_FLAGS_EXT */ + nla_total_size(sizeof(u16)) /* NDA_VLAN */ + nla_total_size(sizeof(struct nda_cacheinfo)) + nla_total_size(0) /* NDA_FDB_EXT_ATTRS */ @@ -879,6 +886,11 @@ void br_fdb_update(struct net_bridge *br, struct net_bridge_port *source, &fdb->flags))) clear_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags); + /* Clear locked flag when roaming to an + * unlocked port. + */ + if (unlikely(test_bit(BR_FDB_LOCKED, &fdb->flags))) + clear_bit(BR_FDB_LOCKED, &fdb->flags); } if (unlikely(test_bit(BR_FDB_ADDED_BY_USER, &flags))) @@ -1082,6 +1094,9 @@ static int fdb_add_entry(struct net_bridge *br, struct net_bridge_port *source, modified = true; } + if (test_and_clear_bit(BR_FDB_LOCKED, &fdb->flags)) + modified = true; + if (fdb_handle_notify(fdb, notify)) modified = true; @@ -1124,7 +1139,7 @@ static int __br_fdb_add(struct ndmsg *ndm, struct net_bridge *br, "FDB entry towards bridge must be permanent"); return -EINVAL; } - err = br_fdb_external_learn_add(br, p, addr, vid, true); + err = br_fdb_external_learn_add(br, p, addr, vid, false, true); } else { spin_lock_bh(&br->hash_lock); err = fdb_add_entry(br, p, addr, ndm, nlh_flags, vid, nfea_tb); @@ -1150,6 +1165,7 @@ int br_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], struct net_bridge_port *p = NULL; struct net_bridge_vlan *v; struct net_bridge *br = NULL; + u32 ext_flags = 0; int err = 0; trace_br_fdb_add(ndm, dev, addr, vid, nlh_flags); @@ -1178,6 +1194,14 @@ int br_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], vg = nbp_vlan_group(p); } + if (tb[NDA_FLAGS_EXT]) + ext_flags = nla_get_u32(tb[NDA_FLAGS_EXT]); + + if (ext_flags & NTF_EXT_LOCKED) { + NL_SET_ERR_MSG_MOD(extack, "Cannot add FDB entry with \"locked\" flag set"); + return -EINVAL; + } + if (tb[NDA_FDB_EXT_ATTRS]) { attr = tb[NDA_FDB_EXT_ATTRS]; err = nla_parse_nested(nfea_tb, NFEA_MAX, attr, @@ -1353,7 +1377,7 @@ void br_fdb_unsync_static(struct net_bridge *br, struct net_bridge_port *p) } int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, - const unsigned char *addr, u16 vid, + const unsigned char *addr, u16 vid, bool locked, bool swdev_notify) { struct net_bridge_fdb_entry *fdb; @@ -1362,6 +1386,9 @@ int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, trace_br_fdb_external_learn_add(br, p, addr, vid); + if (locked && (!p || !(p->flags & BR_PORT_MAB))) + return -EINVAL; + spin_lock_bh(&br->hash_lock); fdb = br_fdb_find(br, addr, vid); @@ -1374,6 +1401,9 @@ int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, if (!p) flags |= BIT(BR_FDB_LOCAL); + if (locked) + flags |= BIT(BR_FDB_LOCKED); + fdb = fdb_create(br, p, addr, vid, flags); if (!fdb) { err = -ENOMEM; @@ -1381,6 +1411,13 @@ int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, } fdb_notify(br, fdb, RTM_NEWNEIGH, swdev_notify); } else { + if (locked && + (!test_bit(BR_FDB_LOCKED, &fdb->flags) || + READ_ONCE(fdb->dst) != p)) { + err = -EINVAL; + goto err_unlock; + } + fdb->updated = jiffies; if (READ_ONCE(fdb->dst) != p) { @@ -1397,6 +1434,11 @@ int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, modified = true; } + if (locked != test_bit(BR_FDB_LOCKED, &fdb->flags)) { + change_bit(BR_FDB_LOCKED, &fdb->flags); + modified = true; + } + if (swdev_notify) set_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); diff --git a/net/bridge/br_if.c b/net/bridge/br_if.c index 228fd5b20f10..ad13b48e3e08 100644 --- a/net/bridge/br_if.c +++ b/net/bridge/br_if.c @@ -262,7 +262,7 @@ static void release_nbp(struct kobject *kobj) kfree(p); } -static void brport_get_ownership(struct kobject *kobj, kuid_t *uid, kgid_t *gid) +static void brport_get_ownership(const struct kobject *kobj, kuid_t *uid, kgid_t *gid) { struct net_bridge_port *p = kobj_to_brport(kobj); diff --git a/net/bridge/br_input.c b/net/bridge/br_input.c index 68b3e850bcb9..3027e8f6be15 100644 --- a/net/bridge/br_input.c +++ b/net/bridge/br_input.c @@ -109,9 +109,26 @@ int br_handle_frame_finish(struct net *net, struct sock *sk, struct sk_buff *skb struct net_bridge_fdb_entry *fdb_src = br_fdb_find_rcu(br, eth_hdr(skb)->h_source, vid); - if (!fdb_src || READ_ONCE(fdb_src->dst) != p || - test_bit(BR_FDB_LOCAL, &fdb_src->flags)) + if (!fdb_src) { + /* FDB miss. Create locked FDB entry if MAB is enabled + * and drop the packet. + */ + if (p->flags & BR_PORT_MAB) + br_fdb_update(br, p, eth_hdr(skb)->h_source, + vid, BIT(BR_FDB_LOCKED)); goto drop; + } else if (READ_ONCE(fdb_src->dst) != p || + test_bit(BR_FDB_LOCAL, &fdb_src->flags)) { + /* FDB mismatch. Drop the packet without roaming. */ + goto drop; + } else if (test_bit(BR_FDB_LOCKED, &fdb_src->flags)) { + /* FDB match, but entry is locked. Refresh it and drop + * the packet. + */ + br_fdb_update(br, p, eth_hdr(skb)->h_source, vid, + BIT(BR_FDB_LOCKED)); + goto drop; + } } nbp_switchdev_frame_mark(p, skb); diff --git a/net/bridge/br_mdb.c b/net/bridge/br_mdb.c index 589ff497d50c..00e5743647b0 100644 --- a/net/bridge/br_mdb.c +++ b/net/bridge/br_mdb.c @@ -663,6 +663,28 @@ errout: rtnl_set_sk_err(net, RTNLGRP_MDB, err); } +static const struct nla_policy +br_mdbe_src_list_entry_pol[MDBE_SRCATTR_MAX + 1] = { + [MDBE_SRCATTR_ADDRESS] = NLA_POLICY_RANGE(NLA_BINARY, + sizeof(struct in_addr), + sizeof(struct in6_addr)), +}; + +static const struct nla_policy +br_mdbe_src_list_pol[MDBE_SRC_LIST_MAX + 1] = { + [MDBE_SRC_LIST_ENTRY] = NLA_POLICY_NESTED(br_mdbe_src_list_entry_pol), +}; + +static const struct nla_policy br_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = { + [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY, + sizeof(struct in_addr), + sizeof(struct in6_addr)), + [MDBE_ATTR_GROUP_MODE] = NLA_POLICY_RANGE(NLA_U8, MCAST_EXCLUDE, + MCAST_INCLUDE), + [MDBE_ATTR_SRC_LIST] = NLA_POLICY_NESTED(br_mdbe_src_list_pol), + [MDBE_ATTR_RTPROT] = NLA_POLICY_MIN(NLA_U8, RTPROT_STATIC), +}; + static bool is_valid_mdb_entry(struct br_mdb_entry *entry, struct netlink_ext_ack *extack) { @@ -748,79 +770,6 @@ static bool is_valid_mdb_source(struct nlattr *attr, __be16 proto, return true; } -static const struct nla_policy br_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = { - [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY, - sizeof(struct in_addr), - sizeof(struct in6_addr)), -}; - -static int br_mdb_parse(struct sk_buff *skb, struct nlmsghdr *nlh, - struct net_device **pdev, struct br_mdb_entry **pentry, - struct nlattr **mdb_attrs, struct netlink_ext_ack *extack) -{ - struct net *net = sock_net(skb->sk); - struct br_mdb_entry *entry; - struct br_port_msg *bpm; - struct nlattr *tb[MDBA_SET_ENTRY_MAX+1]; - struct net_device *dev; - int err; - - err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb, - MDBA_SET_ENTRY_MAX, NULL, NULL); - if (err < 0) - return err; - - bpm = nlmsg_data(nlh); - if (bpm->ifindex == 0) { - NL_SET_ERR_MSG_MOD(extack, "Invalid bridge ifindex"); - return -EINVAL; - } - - dev = __dev_get_by_index(net, bpm->ifindex); - if (dev == NULL) { - NL_SET_ERR_MSG_MOD(extack, "Bridge device doesn't exist"); - return -ENODEV; - } - - if (!netif_is_bridge_master(dev)) { - NL_SET_ERR_MSG_MOD(extack, "Device is not a bridge"); - return -EOPNOTSUPP; - } - - *pdev = dev; - - if (!tb[MDBA_SET_ENTRY]) { - NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY attribute"); - return -EINVAL; - } - if (nla_len(tb[MDBA_SET_ENTRY]) != sizeof(struct br_mdb_entry)) { - NL_SET_ERR_MSG_MOD(extack, "Invalid MDBA_SET_ENTRY attribute length"); - return -EINVAL; - } - - entry = nla_data(tb[MDBA_SET_ENTRY]); - if (!is_valid_mdb_entry(entry, extack)) - return -EINVAL; - *pentry = entry; - - if (tb[MDBA_SET_ENTRY_ATTRS]) { - err = nla_parse_nested(mdb_attrs, MDBE_ATTR_MAX, - tb[MDBA_SET_ENTRY_ATTRS], - br_mdbe_attrs_pol, extack); - if (err) - return err; - if (mdb_attrs[MDBE_ATTR_SOURCE] && - !is_valid_mdb_source(mdb_attrs[MDBE_ATTR_SOURCE], - entry->addr.proto, extack)) - return -EINVAL; - } else { - memset(mdb_attrs, 0, - sizeof(struct nlattr *) * (MDBE_ATTR_MAX + 1)); - } - - return 0; -} - static struct net_bridge_mcast * __br_mdb_choose_context(struct net_bridge *br, const struct br_mdb_entry *entry, @@ -853,218 +802,669 @@ out: return brmctx; } -static int br_mdb_add_group(struct net_bridge *br, struct net_bridge_port *port, - struct br_mdb_entry *entry, - struct nlattr **mdb_attrs, - struct netlink_ext_ack *extack) +static int br_mdb_replace_group_sg(const struct br_mdb_config *cfg, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + struct net_bridge_mcast *brmctx, + unsigned char flags) +{ + unsigned long now = jiffies; + + pg->flags = flags; + pg->rt_protocol = cfg->rt_protocol; + if (!(flags & MDB_PG_FLAGS_PERMANENT) && !cfg->src_entry) + mod_timer(&pg->timer, + now + brmctx->multicast_membership_interval); + else + del_timer(&pg->timer); + + br_mdb_notify(cfg->br->dev, mp, pg, RTM_NEWMDB); + + return 0; +} + +static int br_mdb_add_group_sg(const struct br_mdb_config *cfg, + struct net_bridge_mdb_entry *mp, + struct net_bridge_mcast *brmctx, + unsigned char flags, + struct netlink_ext_ack *extack) { - struct net_bridge_mdb_entry *mp, *star_mp; struct net_bridge_port_group __rcu **pp; struct net_bridge_port_group *p; - struct net_bridge_mcast *brmctx; - struct br_ip group, star_group; unsigned long now = jiffies; - unsigned char flags = 0; - u8 filter_mode; - int err; - __mdb_entry_to_br_ip(entry, &group, mdb_attrs); + for (pp = &mp->ports; + (p = mlock_dereference(*pp, cfg->br)) != NULL; + pp = &p->next) { + if (p->key.port == cfg->p) { + if (!(cfg->nlflags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG_MOD(extack, "(S, G) group is already joined by port"); + return -EEXIST; + } + return br_mdb_replace_group_sg(cfg, mp, p, brmctx, + flags); + } + if ((unsigned long)p->key.port < (unsigned long)cfg->p) + break; + } - brmctx = __br_mdb_choose_context(br, entry, extack); - if (!brmctx) - return -EINVAL; + p = br_multicast_new_port_group(cfg->p, &cfg->group, *pp, flags, NULL, + MCAST_INCLUDE, cfg->rt_protocol); + if (unlikely(!p)) { + NL_SET_ERR_MSG_MOD(extack, "Couldn't allocate new (S, G) port group"); + return -ENOMEM; + } + rcu_assign_pointer(*pp, p); + if (!(flags & MDB_PG_FLAGS_PERMANENT) && !cfg->src_entry) + mod_timer(&p->timer, + now + brmctx->multicast_membership_interval); + br_mdb_notify(cfg->br->dev, mp, p, RTM_NEWMDB); - /* host join errors which can happen before creating the group */ - if (!port && !br_group_is_l2(&group)) { - /* don't allow any flags for host-joined IP groups */ - if (entry->state) { - NL_SET_ERR_MSG_MOD(extack, "Flags are not allowed for host groups"); - return -EINVAL; - } - if (!br_multicast_is_star_g(&group)) { - NL_SET_ERR_MSG_MOD(extack, "Groups with sources cannot be manually host joined"); - return -EINVAL; - } + /* All of (*, G) EXCLUDE ports need to be added to the new (S, G) for + * proper replication. + */ + if (br_multicast_should_handle_mode(brmctx, cfg->group.proto)) { + struct net_bridge_mdb_entry *star_mp; + struct br_ip star_group; + + star_group = p->key.addr; + memset(&star_group.src, 0, sizeof(star_group.src)); + star_mp = br_mdb_ip_get(cfg->br, &star_group); + if (star_mp) + br_multicast_sg_add_exclude_ports(star_mp, p); } - if (br_group_is_l2(&group) && entry->state != MDB_PERMANENT) { - NL_SET_ERR_MSG_MOD(extack, "Only permanent L2 entries allowed"); - return -EINVAL; + return 0; +} + +static int br_mdb_add_group_src_fwd(const struct br_mdb_config *cfg, + struct br_ip *src_ip, + struct net_bridge_mcast *brmctx, + struct netlink_ext_ack *extack) +{ + struct net_bridge_mdb_entry *sgmp; + struct br_mdb_config sg_cfg; + struct br_ip sg_ip; + u8 flags = 0; + + sg_ip = cfg->group; + sg_ip.src = src_ip->src; + sgmp = br_multicast_new_group(cfg->br, &sg_ip); + if (IS_ERR(sgmp)) { + NL_SET_ERR_MSG_MOD(extack, "Failed to add (S, G) MDB entry"); + return PTR_ERR(sgmp); } - mp = br_mdb_ip_get(br, &group); - if (!mp) { - mp = br_multicast_new_group(br, &group); - err = PTR_ERR_OR_ZERO(mp); + if (cfg->entry->state == MDB_PERMANENT) + flags |= MDB_PG_FLAGS_PERMANENT; + if (cfg->filter_mode == MCAST_EXCLUDE) + flags |= MDB_PG_FLAGS_BLOCKED; + + memset(&sg_cfg, 0, sizeof(sg_cfg)); + sg_cfg.br = cfg->br; + sg_cfg.p = cfg->p; + sg_cfg.entry = cfg->entry; + sg_cfg.group = sg_ip; + sg_cfg.src_entry = true; + sg_cfg.filter_mode = MCAST_INCLUDE; + sg_cfg.rt_protocol = cfg->rt_protocol; + sg_cfg.nlflags = cfg->nlflags; + return br_mdb_add_group_sg(&sg_cfg, sgmp, brmctx, flags, extack); +} + +static int br_mdb_add_group_src(const struct br_mdb_config *cfg, + struct net_bridge_port_group *pg, + struct net_bridge_mcast *brmctx, + struct br_mdb_src_entry *src, + struct netlink_ext_ack *extack) +{ + struct net_bridge_group_src *ent; + unsigned long now = jiffies; + int err; + + ent = br_multicast_find_group_src(pg, &src->addr); + if (!ent) { + ent = br_multicast_new_group_src(pg, &src->addr); + if (!ent) { + NL_SET_ERR_MSG_MOD(extack, "Failed to add new source entry"); + return -ENOSPC; + } + } else if (!(cfg->nlflags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG_MOD(extack, "Source entry already exists"); + return -EEXIST; + } + + if (cfg->filter_mode == MCAST_INCLUDE && + cfg->entry->state == MDB_TEMPORARY) + mod_timer(&ent->timer, now + br_multicast_gmi(brmctx)); + else + del_timer(&ent->timer); + + /* Install a (S, G) forwarding entry for the source. */ + err = br_mdb_add_group_src_fwd(cfg, &src->addr, brmctx, extack); + if (err) + goto err_del_sg; + + ent->flags = BR_SGRP_F_INSTALLED | BR_SGRP_F_USER_ADDED; + + return 0; + +err_del_sg: + __br_multicast_del_group_src(ent); + return err; +} + +static void br_mdb_del_group_src(struct net_bridge_port_group *pg, + struct br_mdb_src_entry *src) +{ + struct net_bridge_group_src *ent; + + ent = br_multicast_find_group_src(pg, &src->addr); + if (WARN_ON_ONCE(!ent)) + return; + br_multicast_del_group_src(ent, false); +} + +static int br_mdb_add_group_srcs(const struct br_mdb_config *cfg, + struct net_bridge_port_group *pg, + struct net_bridge_mcast *brmctx, + struct netlink_ext_ack *extack) +{ + int i, err; + + for (i = 0; i < cfg->num_src_entries; i++) { + err = br_mdb_add_group_src(cfg, pg, brmctx, + &cfg->src_entries[i], extack); if (err) - return err; + goto err_del_group_srcs; } - /* host join */ - if (!port) { - if (mp->host_joined) { - NL_SET_ERR_MSG_MOD(extack, "Group is already joined by host"); - return -EEXIST; - } + return 0; - br_multicast_host_join(brmctx, mp, false); - br_mdb_notify(br->dev, mp, NULL, RTM_NEWMDB); +err_del_group_srcs: + for (i--; i >= 0; i--) + br_mdb_del_group_src(pg, &cfg->src_entries[i]); + return err; +} - return 0; +static int br_mdb_replace_group_srcs(const struct br_mdb_config *cfg, + struct net_bridge_port_group *pg, + struct net_bridge_mcast *brmctx, + struct netlink_ext_ack *extack) +{ + struct net_bridge_group_src *ent; + struct hlist_node *tmp; + int err; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags |= BR_SGRP_F_DELETE; + + err = br_mdb_add_group_srcs(cfg, pg, brmctx, extack); + if (err) + goto err_clear_delete; + + hlist_for_each_entry_safe(ent, tmp, &pg->src_list, node) { + if (ent->flags & BR_SGRP_F_DELETE) + br_multicast_del_group_src(ent, false); } + return 0; + +err_clear_delete: + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags &= ~BR_SGRP_F_DELETE; + return err; +} + +static int br_mdb_replace_group_star_g(const struct br_mdb_config *cfg, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + struct net_bridge_mcast *brmctx, + unsigned char flags, + struct netlink_ext_ack *extack) +{ + unsigned long now = jiffies; + int err; + + err = br_mdb_replace_group_srcs(cfg, pg, brmctx, extack); + if (err) + return err; + + pg->flags = flags; + pg->filter_mode = cfg->filter_mode; + pg->rt_protocol = cfg->rt_protocol; + if (!(flags & MDB_PG_FLAGS_PERMANENT) && + cfg->filter_mode == MCAST_EXCLUDE) + mod_timer(&pg->timer, + now + brmctx->multicast_membership_interval); + else + del_timer(&pg->timer); + + br_mdb_notify(cfg->br->dev, mp, pg, RTM_NEWMDB); + + if (br_multicast_should_handle_mode(brmctx, cfg->group.proto)) + br_multicast_star_g_handle_mode(pg, cfg->filter_mode); + + return 0; +} + +static int br_mdb_add_group_star_g(const struct br_mdb_config *cfg, + struct net_bridge_mdb_entry *mp, + struct net_bridge_mcast *brmctx, + unsigned char flags, + struct netlink_ext_ack *extack) +{ + struct net_bridge_port_group __rcu **pp; + struct net_bridge_port_group *p; + unsigned long now = jiffies; + int err; + for (pp = &mp->ports; - (p = mlock_dereference(*pp, br)) != NULL; + (p = mlock_dereference(*pp, cfg->br)) != NULL; pp = &p->next) { - if (p->key.port == port) { - NL_SET_ERR_MSG_MOD(extack, "Group is already joined by port"); - return -EEXIST; + if (p->key.port == cfg->p) { + if (!(cfg->nlflags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG_MOD(extack, "(*, G) group is already joined by port"); + return -EEXIST; + } + return br_mdb_replace_group_star_g(cfg, mp, p, brmctx, + flags, extack); } - if ((unsigned long)p->key.port < (unsigned long)port) + if ((unsigned long)p->key.port < (unsigned long)cfg->p) break; } - filter_mode = br_multicast_is_star_g(&group) ? MCAST_EXCLUDE : - MCAST_INCLUDE; - - if (entry->state == MDB_PERMANENT) - flags |= MDB_PG_FLAGS_PERMANENT; - - p = br_multicast_new_port_group(port, &group, *pp, flags, NULL, - filter_mode, RTPROT_STATIC); + p = br_multicast_new_port_group(cfg->p, &cfg->group, *pp, flags, NULL, + cfg->filter_mode, cfg->rt_protocol); if (unlikely(!p)) { - NL_SET_ERR_MSG_MOD(extack, "Couldn't allocate new port group"); + NL_SET_ERR_MSG_MOD(extack, "Couldn't allocate new (*, G) port group"); return -ENOMEM; } + + err = br_mdb_add_group_srcs(cfg, p, brmctx, extack); + if (err) + goto err_del_port_group; + rcu_assign_pointer(*pp, p); - if (entry->state == MDB_TEMPORARY) + if (!(flags & MDB_PG_FLAGS_PERMANENT) && + cfg->filter_mode == MCAST_EXCLUDE) mod_timer(&p->timer, now + brmctx->multicast_membership_interval); - br_mdb_notify(br->dev, mp, p, RTM_NEWMDB); - /* if we are adding a new EXCLUDE port group (*,G) it needs to be also - * added to all S,G entries for proper replication, if we are adding - * a new INCLUDE port (S,G) then all of *,G EXCLUDE ports need to be - * added to it for proper replication + br_mdb_notify(cfg->br->dev, mp, p, RTM_NEWMDB); + /* If we are adding a new EXCLUDE port group (*, G), it needs to be + * also added to all (S, G) entries for proper replication. */ - if (br_multicast_should_handle_mode(brmctx, group.proto)) { - switch (filter_mode) { - case MCAST_EXCLUDE: - br_multicast_star_g_handle_mode(p, MCAST_EXCLUDE); - break; - case MCAST_INCLUDE: - star_group = p->key.addr; - memset(&star_group.src, 0, sizeof(star_group.src)); - star_mp = br_mdb_ip_get(br, &star_group); - if (star_mp) - br_multicast_sg_add_exclude_ports(star_mp, p); - break; + if (br_multicast_should_handle_mode(brmctx, cfg->group.proto) && + cfg->filter_mode == MCAST_EXCLUDE) + br_multicast_star_g_handle_mode(p, MCAST_EXCLUDE); + + return 0; + +err_del_port_group: + hlist_del_init(&p->mglist); + kfree(p); + return err; +} + +static int br_mdb_add_group(const struct br_mdb_config *cfg, + struct netlink_ext_ack *extack) +{ + struct br_mdb_entry *entry = cfg->entry; + struct net_bridge_port *port = cfg->p; + struct net_bridge_mdb_entry *mp; + struct net_bridge *br = cfg->br; + struct net_bridge_mcast *brmctx; + struct br_ip group = cfg->group; + unsigned char flags = 0; + + brmctx = __br_mdb_choose_context(br, entry, extack); + if (!brmctx) + return -EINVAL; + + mp = br_multicast_new_group(br, &group); + if (IS_ERR(mp)) + return PTR_ERR(mp); + + /* host join */ + if (!port) { + if (mp->host_joined) { + NL_SET_ERR_MSG_MOD(extack, "Group is already joined by host"); + return -EEXIST; } + + br_multicast_host_join(brmctx, mp, false); + br_mdb_notify(br->dev, mp, NULL, RTM_NEWMDB); + + return 0; } - return 0; + if (entry->state == MDB_PERMANENT) + flags |= MDB_PG_FLAGS_PERMANENT; + + if (br_multicast_is_star_g(&group)) + return br_mdb_add_group_star_g(cfg, mp, brmctx, flags, extack); + else + return br_mdb_add_group_sg(cfg, mp, brmctx, flags, extack); } -static int __br_mdb_add(struct net *net, struct net_bridge *br, - struct net_bridge_port *p, - struct br_mdb_entry *entry, - struct nlattr **mdb_attrs, +static int __br_mdb_add(const struct br_mdb_config *cfg, struct netlink_ext_ack *extack) { int ret; - spin_lock_bh(&br->multicast_lock); - ret = br_mdb_add_group(br, p, entry, mdb_attrs, extack); - spin_unlock_bh(&br->multicast_lock); + spin_lock_bh(&cfg->br->multicast_lock); + ret = br_mdb_add_group(cfg, extack); + spin_unlock_bh(&cfg->br->multicast_lock); return ret; } -static int br_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, - struct netlink_ext_ack *extack) +static int br_mdb_config_src_entry_init(struct nlattr *src_entry, + struct br_mdb_src_entry *src, + __be16 proto, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[MDBE_SRCATTR_MAX + 1]; + int err; + + err = nla_parse_nested(tb, MDBE_SRCATTR_MAX, src_entry, + br_mdbe_src_list_entry_pol, extack); + if (err) + return err; + + if (NL_REQ_ATTR_CHECK(extack, src_entry, tb, MDBE_SRCATTR_ADDRESS)) + return -EINVAL; + + if (!is_valid_mdb_source(tb[MDBE_SRCATTR_ADDRESS], proto, extack)) + return -EINVAL; + + src->addr.proto = proto; + nla_memcpy(&src->addr.src, tb[MDBE_SRCATTR_ADDRESS], + nla_len(tb[MDBE_SRCATTR_ADDRESS])); + + return 0; +} + +static int br_mdb_config_src_list_init(struct nlattr *src_list, + struct br_mdb_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nlattr *src_entry; + int rem, err; + int i = 0; + + nla_for_each_nested(src_entry, src_list, rem) + cfg->num_src_entries++; + + if (cfg->num_src_entries >= PG_SRC_ENT_LIMIT) { + NL_SET_ERR_MSG_FMT_MOD(extack, "Exceeded maximum number of source entries (%u)", + PG_SRC_ENT_LIMIT - 1); + return -EINVAL; + } + + cfg->src_entries = kcalloc(cfg->num_src_entries, + sizeof(struct br_mdb_src_entry), GFP_KERNEL); + if (!cfg->src_entries) + return -ENOMEM; + + nla_for_each_nested(src_entry, src_list, rem) { + err = br_mdb_config_src_entry_init(src_entry, + &cfg->src_entries[i], + cfg->entry->addr.proto, + extack); + if (err) + goto err_src_entry_init; + i++; + } + + return 0; + +err_src_entry_init: + kfree(cfg->src_entries); + return err; +} + +static void br_mdb_config_src_list_fini(struct br_mdb_config *cfg) +{ + kfree(cfg->src_entries); +} + +static int br_mdb_config_attrs_init(struct nlattr *set_attrs, + struct br_mdb_config *cfg, + struct netlink_ext_ack *extack) { struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1]; - struct net *net = sock_net(skb->sk); - struct net_bridge_vlan_group *vg; - struct net_bridge_port *p = NULL; - struct net_device *dev, *pdev; - struct br_mdb_entry *entry; - struct net_bridge_vlan *v; - struct net_bridge *br; int err; - err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack); - if (err < 0) + err = nla_parse_nested(mdb_attrs, MDBE_ATTR_MAX, set_attrs, + br_mdbe_attrs_pol, extack); + if (err) + return err; + + if (mdb_attrs[MDBE_ATTR_SOURCE] && + !is_valid_mdb_source(mdb_attrs[MDBE_ATTR_SOURCE], + cfg->entry->addr.proto, extack)) + return -EINVAL; + + __mdb_entry_to_br_ip(cfg->entry, &cfg->group, mdb_attrs); + + if (mdb_attrs[MDBE_ATTR_GROUP_MODE]) { + if (!cfg->p) { + NL_SET_ERR_MSG_MOD(extack, "Filter mode cannot be set for host groups"); + return -EINVAL; + } + if (!br_multicast_is_star_g(&cfg->group)) { + NL_SET_ERR_MSG_MOD(extack, "Filter mode can only be set for (*, G) entries"); + return -EINVAL; + } + cfg->filter_mode = nla_get_u8(mdb_attrs[MDBE_ATTR_GROUP_MODE]); + } else { + cfg->filter_mode = MCAST_EXCLUDE; + } + + if (mdb_attrs[MDBE_ATTR_SRC_LIST]) { + if (!cfg->p) { + NL_SET_ERR_MSG_MOD(extack, "Source list cannot be set for host groups"); + return -EINVAL; + } + if (!br_multicast_is_star_g(&cfg->group)) { + NL_SET_ERR_MSG_MOD(extack, "Source list can only be set for (*, G) entries"); + return -EINVAL; + } + if (!mdb_attrs[MDBE_ATTR_GROUP_MODE]) { + NL_SET_ERR_MSG_MOD(extack, "Source list cannot be set without filter mode"); + return -EINVAL; + } + err = br_mdb_config_src_list_init(mdb_attrs[MDBE_ATTR_SRC_LIST], + cfg, extack); + if (err) + return err; + } + + if (!cfg->num_src_entries && cfg->filter_mode == MCAST_INCLUDE) { + NL_SET_ERR_MSG_MOD(extack, "Cannot add (*, G) INCLUDE with an empty source list"); + return -EINVAL; + } + + if (mdb_attrs[MDBE_ATTR_RTPROT]) { + if (!cfg->p) { + NL_SET_ERR_MSG_MOD(extack, "Protocol cannot be set for host groups"); + return -EINVAL; + } + cfg->rt_protocol = nla_get_u8(mdb_attrs[MDBE_ATTR_RTPROT]); + } + + return 0; +} + +static int br_mdb_config_init(struct net *net, const struct nlmsghdr *nlh, + struct br_mdb_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[MDBA_SET_ENTRY_MAX + 1]; + struct br_port_msg *bpm; + struct net_device *dev; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb, + MDBA_SET_ENTRY_MAX, NULL, extack); + if (err) return err; - br = netdev_priv(dev); + memset(cfg, 0, sizeof(*cfg)); + cfg->filter_mode = MCAST_EXCLUDE; + cfg->rt_protocol = RTPROT_STATIC; + cfg->nlflags = nlh->nlmsg_flags; + + bpm = nlmsg_data(nlh); + if (!bpm->ifindex) { + NL_SET_ERR_MSG_MOD(extack, "Invalid bridge ifindex"); + return -EINVAL; + } - if (!netif_running(br->dev)) { + dev = __dev_get_by_index(net, bpm->ifindex); + if (!dev) { + NL_SET_ERR_MSG_MOD(extack, "Bridge device doesn't exist"); + return -ENODEV; + } + + if (!netif_is_bridge_master(dev)) { + NL_SET_ERR_MSG_MOD(extack, "Device is not a bridge"); + return -EOPNOTSUPP; + } + + cfg->br = netdev_priv(dev); + + if (!netif_running(cfg->br->dev)) { NL_SET_ERR_MSG_MOD(extack, "Bridge device is not running"); return -EINVAL; } - if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) { + if (!br_opt_get(cfg->br, BROPT_MULTICAST_ENABLED)) { NL_SET_ERR_MSG_MOD(extack, "Bridge's multicast processing is disabled"); return -EINVAL; } - if (entry->ifindex != br->dev->ifindex) { - pdev = __dev_get_by_index(net, entry->ifindex); + if (NL_REQ_ATTR_CHECK(extack, NULL, tb, MDBA_SET_ENTRY)) { + NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY attribute"); + return -EINVAL; + } + if (nla_len(tb[MDBA_SET_ENTRY]) != sizeof(struct br_mdb_entry)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid MDBA_SET_ENTRY attribute length"); + return -EINVAL; + } + + cfg->entry = nla_data(tb[MDBA_SET_ENTRY]); + if (!is_valid_mdb_entry(cfg->entry, extack)) + return -EINVAL; + + if (cfg->entry->ifindex != cfg->br->dev->ifindex) { + struct net_device *pdev; + + pdev = __dev_get_by_index(net, cfg->entry->ifindex); if (!pdev) { NL_SET_ERR_MSG_MOD(extack, "Port net device doesn't exist"); return -ENODEV; } - p = br_port_get_rtnl(pdev); - if (!p) { + cfg->p = br_port_get_rtnl(pdev); + if (!cfg->p) { NL_SET_ERR_MSG_MOD(extack, "Net device is not a bridge port"); return -EINVAL; } - if (p->br != br) { + if (cfg->p->br != cfg->br) { NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device"); return -EINVAL; } - if (p->state == BR_STATE_DISABLED && entry->state != MDB_PERMANENT) { + } + + if (tb[MDBA_SET_ENTRY_ATTRS]) + return br_mdb_config_attrs_init(tb[MDBA_SET_ENTRY_ATTRS], cfg, + extack); + else + __mdb_entry_to_br_ip(cfg->entry, &cfg->group, NULL); + + return 0; +} + +static void br_mdb_config_fini(struct br_mdb_config *cfg) +{ + br_mdb_config_src_list_fini(cfg); +} + +static int br_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + struct br_mdb_config cfg; + int err; + + err = br_mdb_config_init(net, nlh, &cfg, extack); + if (err) + return err; + + err = -EINVAL; + /* host join errors which can happen before creating the group */ + if (!cfg.p && !br_group_is_l2(&cfg.group)) { + /* don't allow any flags for host-joined IP groups */ + if (cfg.entry->state) { + NL_SET_ERR_MSG_MOD(extack, "Flags are not allowed for host groups"); + goto out; + } + if (!br_multicast_is_star_g(&cfg.group)) { + NL_SET_ERR_MSG_MOD(extack, "Groups with sources cannot be manually host joined"); + goto out; + } + } + + if (br_group_is_l2(&cfg.group) && cfg.entry->state != MDB_PERMANENT) { + NL_SET_ERR_MSG_MOD(extack, "Only permanent L2 entries allowed"); + goto out; + } + + if (cfg.p) { + if (cfg.p->state == BR_STATE_DISABLED && cfg.entry->state != MDB_PERMANENT) { NL_SET_ERR_MSG_MOD(extack, "Port is in disabled state and entry is not permanent"); - return -EINVAL; + goto out; } - vg = nbp_vlan_group(p); + vg = nbp_vlan_group(cfg.p); } else { - vg = br_vlan_group(br); + vg = br_vlan_group(cfg.br); } /* If vlan filtering is enabled and VLAN is not specified * install mdb entry on all vlans configured on the port. */ - if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) { + if (br_vlan_enabled(cfg.br->dev) && vg && cfg.entry->vid == 0) { list_for_each_entry(v, &vg->vlan_list, vlist) { - entry->vid = v->vid; - err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack); + cfg.entry->vid = v->vid; + cfg.group.vid = v->vid; + err = __br_mdb_add(&cfg, extack); if (err) break; } } else { - err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack); + err = __br_mdb_add(&cfg, extack); } +out: + br_mdb_config_fini(&cfg); return err; } -static int __br_mdb_del(struct net_bridge *br, struct br_mdb_entry *entry, - struct nlattr **mdb_attrs) +static int __br_mdb_del(const struct br_mdb_config *cfg) { + struct br_mdb_entry *entry = cfg->entry; + struct net_bridge *br = cfg->br; struct net_bridge_mdb_entry *mp; struct net_bridge_port_group *p; struct net_bridge_port_group __rcu **pp; - struct br_ip ip; + struct br_ip ip = cfg->group; int err = -EINVAL; - if (!netif_running(br->dev) || !br_opt_get(br, BROPT_MULTICAST_ENABLED)) - return -EINVAL; - - __mdb_entry_to_br_ip(entry, &ip, mdb_attrs); - spin_lock_bh(&br->multicast_lock); mp = br_mdb_ip_get(br, &ip); if (!mp) @@ -1099,53 +1499,35 @@ unlock: static int br_mdb_del(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { - struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1]; struct net *net = sock_net(skb->sk); struct net_bridge_vlan_group *vg; - struct net_bridge_port *p = NULL; - struct net_device *dev, *pdev; - struct br_mdb_entry *entry; struct net_bridge_vlan *v; - struct net_bridge *br; + struct br_mdb_config cfg; int err; - err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack); - if (err < 0) + err = br_mdb_config_init(net, nlh, &cfg, extack); + if (err) return err; - br = netdev_priv(dev); - - if (entry->ifindex != br->dev->ifindex) { - pdev = __dev_get_by_index(net, entry->ifindex); - if (!pdev) - return -ENODEV; - - p = br_port_get_rtnl(pdev); - if (!p) { - NL_SET_ERR_MSG_MOD(extack, "Net device is not a bridge port"); - return -EINVAL; - } - if (p->br != br) { - NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device"); - return -EINVAL; - } - vg = nbp_vlan_group(p); - } else { - vg = br_vlan_group(br); - } + if (cfg.p) + vg = nbp_vlan_group(cfg.p); + else + vg = br_vlan_group(cfg.br); /* If vlan filtering is enabled and VLAN is not specified * delete mdb entry on all vlans configured on the port. */ - if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) { + if (br_vlan_enabled(cfg.br->dev) && vg && cfg.entry->vid == 0) { list_for_each_entry(v, &vg->vlan_list, vlist) { - entry->vid = v->vid; - err = __br_mdb_del(br, entry, mdb_attrs); + cfg.entry->vid = v->vid; + cfg.group.vid = v->vid; + err = __br_mdb_del(&cfg); } } else { - err = __br_mdb_del(br, entry, mdb_attrs); + err = __br_mdb_del(&cfg); } + br_mdb_config_fini(&cfg); return err; } diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c index db4f2641d1cd..dea1ee1bd095 100644 --- a/net/bridge/br_multicast.c +++ b/net/bridge/br_multicast.c @@ -552,7 +552,8 @@ static void br_multicast_fwd_src_remove(struct net_bridge_group_src *src, continue; if (p->rt_protocol != RTPROT_KERNEL && - (p->flags & MDB_PG_FLAGS_PERMANENT)) + (p->flags & MDB_PG_FLAGS_PERMANENT) && + !(src->flags & BR_SGRP_F_USER_ADDED)) break; if (fastleave) @@ -605,7 +606,7 @@ static void br_multicast_destroy_mdb_entry(struct net_bridge_mcast_gc *gc) WARN_ON(!hlist_unhashed(&mp->mdb_node)); WARN_ON(mp->ports); - del_timer_sync(&mp->timer); + timer_shutdown_sync(&mp->timer); kfree_rcu(mp, rcu); } @@ -646,22 +647,27 @@ static void br_multicast_destroy_group_src(struct net_bridge_mcast_gc *gc) src = container_of(gc, struct net_bridge_group_src, mcast_gc); WARN_ON(!hlist_unhashed(&src->node)); - del_timer_sync(&src->timer); + timer_shutdown_sync(&src->timer); kfree_rcu(src, rcu); } -void br_multicast_del_group_src(struct net_bridge_group_src *src, - bool fastleave) +void __br_multicast_del_group_src(struct net_bridge_group_src *src) { struct net_bridge *br = src->pg->key.port->br; - br_multicast_fwd_src_remove(src, fastleave); hlist_del_init_rcu(&src->node); src->pg->src_ents--; hlist_add_head(&src->mcast_gc.gc_node, &br->mcast_gc_list); queue_work(system_long_wq, &br->mcast_gc_work); } +void br_multicast_del_group_src(struct net_bridge_group_src *src, + bool fastleave) +{ + br_multicast_fwd_src_remove(src, fastleave); + __br_multicast_del_group_src(src); +} + static void br_multicast_destroy_port_group(struct net_bridge_mcast_gc *gc) { struct net_bridge_port_group *pg; @@ -670,8 +676,8 @@ static void br_multicast_destroy_port_group(struct net_bridge_mcast_gc *gc) WARN_ON(!hlist_unhashed(&pg->mglist)); WARN_ON(!hlist_empty(&pg->src_list)); - del_timer_sync(&pg->rexmit_timer); - del_timer_sync(&pg->timer); + timer_shutdown_sync(&pg->rexmit_timer); + timer_shutdown_sync(&pg->timer); kfree_rcu(pg, rcu); } @@ -1232,7 +1238,7 @@ br_multicast_find_group_src(struct net_bridge_port_group *pg, struct br_ip *ip) return NULL; } -static struct net_bridge_group_src * +struct net_bridge_group_src * br_multicast_new_group_src(struct net_bridge_port_group *pg, struct br_ip *src_ip) { struct net_bridge_group_src *grp_src; @@ -1273,7 +1279,7 @@ br_multicast_new_group_src(struct net_bridge_port_group *pg, struct br_ip *src_i struct net_bridge_port_group *br_multicast_new_port_group( struct net_bridge_port *port, - struct br_ip *group, + const struct br_ip *group, struct net_bridge_port_group __rcu *next, unsigned char flags, const unsigned char *src, @@ -2669,7 +2675,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge_mcast *brmctx, if (!pmctx || igmpv2) continue; - spin_lock_bh(&brmctx->br->multicast_lock); + spin_lock(&brmctx->br->multicast_lock); if (!br_multicast_ctx_should_use(brmctx, pmctx)) goto unlock_continue; @@ -2717,7 +2723,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge_mcast *brmctx, if (changed) br_mdb_notify(brmctx->br->dev, mdst, pg, RTM_NEWMDB); unlock_continue: - spin_unlock_bh(&brmctx->br->multicast_lock); + spin_unlock(&brmctx->br->multicast_lock); } return err; @@ -2807,7 +2813,7 @@ static int br_ip6_multicast_mld2_report(struct net_bridge_mcast *brmctx, if (!pmctx || mldv1) continue; - spin_lock_bh(&brmctx->br->multicast_lock); + spin_lock(&brmctx->br->multicast_lock); if (!br_multicast_ctx_should_use(brmctx, pmctx)) goto unlock_continue; @@ -2859,7 +2865,7 @@ static int br_ip6_multicast_mld2_report(struct net_bridge_mcast *brmctx, if (changed) br_mdb_notify(brmctx->br->dev, mdst, pg, RTM_NEWMDB); unlock_continue: - spin_unlock_bh(&brmctx->br->multicast_lock); + spin_unlock(&brmctx->br->multicast_lock); } return err; @@ -4899,9 +4905,9 @@ void br_multicast_get_stats(const struct net_bridge *br, unsigned int start; do { - start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + start = u64_stats_fetch_begin(&cpu_stats->syncp); memcpy(&temp, &cpu_stats->mstats, sizeof(temp)); - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); mcast_stats_add_dir(tdst.igmp_v1queries, temp.igmp_v1queries); mcast_stats_add_dir(tdst.igmp_v2queries, temp.igmp_v2queries); diff --git a/net/bridge/br_multicast_eht.c b/net/bridge/br_multicast_eht.c index f91c071d1608..c126aa4e7551 100644 --- a/net/bridge/br_multicast_eht.c +++ b/net/bridge/br_multicast_eht.c @@ -142,7 +142,7 @@ static void br_multicast_destroy_eht_set_entry(struct net_bridge_mcast_gc *gc) set_h = container_of(gc, struct net_bridge_group_eht_set_entry, mcast_gc); WARN_ON(!RB_EMPTY_NODE(&set_h->rb_node)); - del_timer_sync(&set_h->timer); + timer_shutdown_sync(&set_h->timer); kfree(set_h); } @@ -154,7 +154,7 @@ static void br_multicast_destroy_eht_set(struct net_bridge_mcast_gc *gc) WARN_ON(!RB_EMPTY_NODE(&eht_set->rb_node)); WARN_ON(!RB_EMPTY_ROOT(&eht_set->entry_tree)); - del_timer_sync(&eht_set->timer); + timer_shutdown_sync(&eht_set->timer); kfree(eht_set); } diff --git a/net/bridge/br_netfilter_hooks.c b/net/bridge/br_netfilter_hooks.c index f20f4373ff40..9554abcfd5b4 100644 --- a/net/bridge/br_netfilter_hooks.c +++ b/net/bridge/br_netfilter_hooks.c @@ -871,6 +871,7 @@ static unsigned int ip_sabotage_in(void *priv, if (nf_bridge && !nf_bridge->in_prerouting && !netif_is_l3_master(skb->dev) && !netif_is_l3_slave(skb->dev)) { + nf_bridge_info_free(skb); state->okfn(state->net, state->sk, skb); return NF_STOLEN; } diff --git a/net/bridge/br_netlink.c b/net/bridge/br_netlink.c index d087fd4c784a..4316cc82ae17 100644 --- a/net/bridge/br_netlink.c +++ b/net/bridge/br_netlink.c @@ -188,6 +188,7 @@ static inline size_t br_port_info_size(void) + nla_total_size(1) /* IFLA_BRPORT_NEIGH_SUPPRESS */ + nla_total_size(1) /* IFLA_BRPORT_ISOLATED */ + nla_total_size(1) /* IFLA_BRPORT_LOCKED */ + + nla_total_size(1) /* IFLA_BRPORT_MAB */ + nla_total_size(sizeof(struct ifla_bridge_id)) /* IFLA_BRPORT_ROOT_ID */ + nla_total_size(sizeof(struct ifla_bridge_id)) /* IFLA_BRPORT_BRIDGE_ID */ + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_DESIGNATED_PORT */ @@ -274,7 +275,8 @@ static int br_port_fill_attrs(struct sk_buff *skb, nla_put_u8(skb, IFLA_BRPORT_MRP_IN_OPEN, !!(p->flags & BR_MRP_LOST_IN_CONT)) || nla_put_u8(skb, IFLA_BRPORT_ISOLATED, !!(p->flags & BR_ISOLATED)) || - nla_put_u8(skb, IFLA_BRPORT_LOCKED, !!(p->flags & BR_PORT_LOCKED))) + nla_put_u8(skb, IFLA_BRPORT_LOCKED, !!(p->flags & BR_PORT_LOCKED)) || + nla_put_u8(skb, IFLA_BRPORT_MAB, !!(p->flags & BR_PORT_MAB))) return -EMSGSIZE; timerval = br_timer_value(&p->message_age_timer); @@ -876,6 +878,7 @@ static const struct nla_policy br_port_policy[IFLA_BRPORT_MAX + 1] = { [IFLA_BRPORT_NEIGH_SUPPRESS] = { .type = NLA_U8 }, [IFLA_BRPORT_ISOLATED] = { .type = NLA_U8 }, [IFLA_BRPORT_LOCKED] = { .type = NLA_U8 }, + [IFLA_BRPORT_MAB] = { .type = NLA_U8 }, [IFLA_BRPORT_BACKUP_PORT] = { .type = NLA_U32 }, [IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT] = { .type = NLA_U32 }, }; @@ -943,6 +946,22 @@ static int br_setport(struct net_bridge_port *p, struct nlattr *tb[], br_set_port_flag(p, tb, IFLA_BRPORT_NEIGH_SUPPRESS, BR_NEIGH_SUPPRESS); br_set_port_flag(p, tb, IFLA_BRPORT_ISOLATED, BR_ISOLATED); br_set_port_flag(p, tb, IFLA_BRPORT_LOCKED, BR_PORT_LOCKED); + br_set_port_flag(p, tb, IFLA_BRPORT_MAB, BR_PORT_MAB); + + if ((p->flags & BR_PORT_MAB) && + (!(p->flags & BR_PORT_LOCKED) || !(p->flags & BR_LEARNING))) { + NL_SET_ERR_MSG(extack, "Bridge port must be locked and have learning enabled when MAB is enabled"); + p->flags = old_flags; + return -EINVAL; + } else if (!(p->flags & BR_PORT_MAB) && (old_flags & BR_PORT_MAB)) { + struct net_bridge_fdb_flush_desc desc = { + .flags = BIT(BR_FDB_LOCKED), + .flags_mask = BIT(BR_FDB_LOCKED), + .port_ifindex = p->dev->ifindex, + }; + + br_fdb_flush(p->br, &desc); + } changed_mask = old_flags ^ p->flags; diff --git a/net/bridge/br_private.h b/net/bridge/br_private.h index 06e5f6faa431..15ef7fd508ee 100644 --- a/net/bridge/br_private.h +++ b/net/bridge/br_private.h @@ -92,6 +92,23 @@ struct bridge_mcast_stats { struct br_mcast_stats mstats; struct u64_stats_sync syncp; }; + +struct br_mdb_src_entry { + struct br_ip addr; +}; + +struct br_mdb_config { + struct net_bridge *br; + struct net_bridge_port *p; + struct br_mdb_entry *entry; + struct br_ip group; + bool src_entry; + u8 filter_mode; + u16 nlflags; + struct br_mdb_src_entry *src_entries; + int num_src_entries; + u8 rt_protocol; +}; #endif /* net_bridge_mcast_port must be always defined due to forwarding stubs */ @@ -251,7 +268,8 @@ enum { BR_FDB_ADDED_BY_EXT_LEARN, BR_FDB_OFFLOADED, BR_FDB_NOTIFY, - BR_FDB_NOTIFY_INACTIVE + BR_FDB_NOTIFY_INACTIVE, + BR_FDB_LOCKED, }; struct net_bridge_fdb_key { @@ -292,6 +310,7 @@ struct net_bridge_fdb_flush_desc { #define BR_SGRP_F_DELETE BIT(0) #define BR_SGRP_F_SEND BIT(1) #define BR_SGRP_F_INSTALLED BIT(2) +#define BR_SGRP_F_USER_ADDED BIT(3) struct net_bridge_mcast_gc { struct hlist_node gc_node; @@ -810,7 +829,7 @@ int br_fdb_sync_static(struct net_bridge *br, struct net_bridge_port *p); void br_fdb_unsync_static(struct net_bridge *br, struct net_bridge_port *p); int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, const unsigned char *addr, u16 vid, - bool swdev_notify); + bool locked, bool swdev_notify); int br_fdb_external_learn_del(struct net_bridge *br, struct net_bridge_port *p, const unsigned char *addr, u16 vid, bool swdev_notify); @@ -933,7 +952,8 @@ br_mdb_ip_get(struct net_bridge *br, struct br_ip *dst); struct net_bridge_mdb_entry * br_multicast_new_group(struct net_bridge *br, struct br_ip *group); struct net_bridge_port_group * -br_multicast_new_port_group(struct net_bridge_port *port, struct br_ip *group, +br_multicast_new_port_group(struct net_bridge_port *port, + const struct br_ip *group, struct net_bridge_port_group __rcu *next, unsigned char flags, const unsigned char *src, u8 filter_mode, u8 rt_protocol); @@ -965,6 +985,10 @@ void br_multicast_sg_add_exclude_ports(struct net_bridge_mdb_entry *star_mp, struct net_bridge_port_group *sg); struct net_bridge_group_src * br_multicast_find_group_src(struct net_bridge_port_group *pg, struct br_ip *ip); +struct net_bridge_group_src * +br_multicast_new_group_src(struct net_bridge_port_group *pg, + struct br_ip *src_ip); +void __br_multicast_del_group_src(struct net_bridge_group_src *src); void br_multicast_del_group_src(struct net_bridge_group_src *src, bool fastleave); void br_multicast_ctx_init(struct net_bridge *br, diff --git a/net/bridge/br_switchdev.c b/net/bridge/br_switchdev.c index 8f3d76c751dd..7eb6fd5bb917 100644 --- a/net/bridge/br_switchdev.c +++ b/net/bridge/br_switchdev.c @@ -71,7 +71,7 @@ bool nbp_switchdev_allowed_egress(const struct net_bridge_port *p, } /* Flags that can be offloaded to hardware */ -#define BR_PORT_FLAGS_HW_OFFLOAD (BR_LEARNING | BR_FLOOD | \ +#define BR_PORT_FLAGS_HW_OFFLOAD (BR_LEARNING | BR_FLOOD | BR_PORT_MAB | \ BR_MCAST_FLOOD | BR_BCAST_FLOOD | BR_PORT_LOCKED | \ BR_HAIRPIN_MODE | BR_ISOLATED | BR_MULTICAST_TO_UNICAST) @@ -136,6 +136,7 @@ static void br_switchdev_fdb_populate(struct net_bridge *br, item->added_by_user = test_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); item->offloaded = test_bit(BR_FDB_OFFLOADED, &fdb->flags); item->is_local = test_bit(BR_FDB_LOCAL, &fdb->flags); + item->locked = false; item->info.dev = (!p || item->is_local) ? br->dev : p->dev; item->info.ctx = ctx; } @@ -146,6 +147,9 @@ br_switchdev_fdb_notify(struct net_bridge *br, { struct switchdev_notifier_fdb_info item; + if (test_bit(BR_FDB_LOCKED, &fdb->flags)) + return; + br_switchdev_fdb_populate(br, &item, fdb, NULL); switch (type) { diff --git a/net/bridge/br_vlan.c b/net/bridge/br_vlan.c index 6e53dc991409..bc75fa1e4666 100644 --- a/net/bridge/br_vlan.c +++ b/net/bridge/br_vlan.c @@ -959,6 +959,8 @@ int __br_vlan_set_proto(struct net_bridge *br, __be16 proto, list_for_each_entry(p, &br->port_list, list) { vg = nbp_vlan_group(p); list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; err = vlan_vid_add(p->dev, proto, vlan->vid); if (err) goto err_filt; @@ -973,8 +975,11 @@ int __br_vlan_set_proto(struct net_bridge *br, __be16 proto, /* Delete VLANs for the old proto from the device filter. */ list_for_each_entry(p, &br->port_list, list) { vg = nbp_vlan_group(p); - list_for_each_entry(vlan, &vg->vlan_list, vlist) + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; vlan_vid_del(p->dev, oldproto, vlan->vid); + } } return 0; @@ -983,13 +988,19 @@ err_filt: attr.u.vlan_protocol = ntohs(oldproto); switchdev_port_attr_set(br->dev, &attr, NULL); - list_for_each_entry_continue_reverse(vlan, &vg->vlan_list, vlist) + list_for_each_entry_continue_reverse(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; vlan_vid_del(p->dev, proto, vlan->vid); + } list_for_each_entry_continue_reverse(p, &br->port_list, list) { vg = nbp_vlan_group(p); - list_for_each_entry(vlan, &vg->vlan_list, vlist) + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; vlan_vid_del(p->dev, proto, vlan->vid); + } } return err; @@ -1378,12 +1389,12 @@ void br_vlan_get_stats(const struct net_bridge_vlan *v, cpu_stats = per_cpu_ptr(v->stats, i); do { - start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + start = u64_stats_fetch_begin(&cpu_stats->syncp); rxpackets = u64_stats_read(&cpu_stats->rx_packets); rxbytes = u64_stats_read(&cpu_stats->rx_bytes); txbytes = u64_stats_read(&cpu_stats->tx_bytes); txpackets = u64_stats_read(&cpu_stats->tx_packets); - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); u64_stats_add(&stats->rx_packets, rxpackets); u64_stats_add(&stats->rx_bytes, rxbytes); diff --git a/net/bridge/netfilter/nf_conntrack_bridge.c b/net/bridge/netfilter/nf_conntrack_bridge.c index 73242962be5d..5c5dd437f1c2 100644 --- a/net/bridge/netfilter/nf_conntrack_bridge.c +++ b/net/bridge/netfilter/nf_conntrack_bridge.c @@ -366,42 +366,12 @@ static int nf_ct_bridge_refrag_post(struct net *net, struct sock *sk, return br_dev_queue_push_xmit(net, sk, skb); } -static unsigned int nf_ct_bridge_confirm(struct sk_buff *skb) -{ - enum ip_conntrack_info ctinfo; - struct nf_conn *ct; - int protoff; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return nf_conntrack_confirm(skb); - - switch (skb->protocol) { - case htons(ETH_P_IP): - protoff = skb_network_offset(skb) + ip_hdrlen(skb); - break; - case htons(ETH_P_IPV6): { - unsigned char pnum = ipv6_hdr(skb)->nexthdr; - __be16 frag_off; - - protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, - &frag_off); - if (protoff < 0 || (frag_off & htons(~0x7)) != 0) - return nf_conntrack_confirm(skb); - } - break; - default: - return NF_ACCEPT; - } - return nf_confirm(skb, protoff, ct, ctinfo); -} - static unsigned int nf_ct_bridge_post(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { int ret; - ret = nf_ct_bridge_confirm(skb); + ret = nf_confirm(priv, skb, state); if (ret != NF_ACCEPT) return ret; diff --git a/net/caif/cfctrl.c b/net/caif/cfctrl.c index cc405d8c7c30..8480684f2762 100644 --- a/net/caif/cfctrl.c +++ b/net/caif/cfctrl.c @@ -269,11 +269,15 @@ int cfctrl_linkup_request(struct cflayer *layer, default: pr_warn("Request setup of bad link type = %d\n", param->linktype); + cfpkt_destroy(pkt); return -EINVAL; } req = kzalloc(sizeof(*req), GFP_KERNEL); - if (!req) + if (!req) { + cfpkt_destroy(pkt); return -ENOMEM; + } + req->client_layer = user_layer; req->cmd = CFCTRL_CMD_LINK_SETUP; req->param = *param; diff --git a/net/caif/chnl_net.c b/net/caif/chnl_net.c index 4d63ef13a1fd..f35fc87c453a 100644 --- a/net/caif/chnl_net.c +++ b/net/caif/chnl_net.c @@ -310,9 +310,6 @@ static int chnl_net_open(struct net_device *dev) if (result == 0) { pr_debug("connect timeout\n"); - caif_disconnect_client(dev_net(dev), &priv->chnl); - priv->state = CAIF_DISCONNECTED; - pr_debug("state disconnected\n"); result = -ETIMEDOUT; goto error; } diff --git a/net/can/af_can.c b/net/can/af_can.c index 27dcdcc0b808..7343fd487dbe 100644 --- a/net/can/af_can.c +++ b/net/can/af_can.c @@ -446,7 +446,6 @@ int can_rx_register(struct net *net, struct net_device *dev, canid_t can_id, struct hlist_head *rcv_list; struct can_dev_rcv_lists *dev_rcv_lists; struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; - int err = 0; /* insert new receiver (dev,canid,mask) -> (func,data) */ @@ -481,7 +480,7 @@ int can_rx_register(struct net *net, struct net_device *dev, canid_t can_id, rcv_lists_stats->rcv_entries); spin_unlock_bh(&net->can.rcvlists_lock); - return err; + return 0; } EXPORT_SYMBOL(can_rx_register); @@ -677,7 +676,7 @@ static void can_receive(struct sk_buff *skb, struct net_device *dev) static int can_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) { - if (unlikely(dev->type != ARPHRD_CAN || (!can_is_can_skb(skb)))) { + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_can_skb(skb))) { pr_warn_once("PF_CAN: dropped non conform CAN skbuff: dev type %d, len %d\n", dev->type, skb->len); @@ -692,7 +691,7 @@ static int can_rcv(struct sk_buff *skb, struct net_device *dev, static int canfd_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) { - if (unlikely(dev->type != ARPHRD_CAN || (!can_is_canfd_skb(skb)))) { + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_canfd_skb(skb))) { pr_warn_once("PF_CAN: dropped non conform CAN FD skbuff: dev type %d, len %d\n", dev->type, skb->len); @@ -707,7 +706,7 @@ static int canfd_rcv(struct sk_buff *skb, struct net_device *dev, static int canxl_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) { - if (unlikely(dev->type != ARPHRD_CAN || (!can_is_canxl_skb(skb)))) { + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_canxl_skb(skb))) { pr_warn_once("PF_CAN: dropped non conform CAN XL skbuff: dev type %d, len %d\n", dev->type, skb->len); diff --git a/net/can/isotp.c b/net/can/isotp.c index 608f8c24ae46..fc81d77724a1 100644 --- a/net/can/isotp.c +++ b/net/can/isotp.c @@ -140,7 +140,7 @@ struct isotp_sock { canid_t rxid; ktime_t tx_gap; ktime_t lastrxcf_tstamp; - struct hrtimer rxtimer, txtimer; + struct hrtimer rxtimer, txtimer, txfrtimer; struct can_isotp_options opt; struct can_isotp_fc_options rxfc, txfc; struct can_isotp_ll_options ll; @@ -871,7 +871,7 @@ static void isotp_rcv_echo(struct sk_buff *skb, void *data) } /* start timer to send next consecutive frame with correct delay */ - hrtimer_start(&so->txtimer, so->tx_gap, HRTIMER_MODE_REL_SOFT); + hrtimer_start(&so->txfrtimer, so->tx_gap, HRTIMER_MODE_REL_SOFT); } static enum hrtimer_restart isotp_tx_timer_handler(struct hrtimer *hrtimer) @@ -879,49 +879,39 @@ static enum hrtimer_restart isotp_tx_timer_handler(struct hrtimer *hrtimer) struct isotp_sock *so = container_of(hrtimer, struct isotp_sock, txtimer); struct sock *sk = &so->sk; - enum hrtimer_restart restart = HRTIMER_NORESTART; - switch (so->tx.state) { - case ISOTP_SENDING: - - /* cfecho should be consumed by isotp_rcv_echo() here */ - if (!so->cfecho) { - /* start timeout for unlikely lost echo skb */ - hrtimer_set_expires(&so->txtimer, - ktime_add(ktime_get(), - ktime_set(ISOTP_ECHO_TIMEOUT, 0))); - restart = HRTIMER_RESTART; + /* don't handle timeouts in IDLE state */ + if (so->tx.state == ISOTP_IDLE) + return HRTIMER_NORESTART; - /* push out the next consecutive frame */ - isotp_send_cframe(so); - break; - } + /* we did not get any flow control or echo frame in time */ - /* cfecho has not been cleared in isotp_rcv_echo() */ - pr_notice_once("can-isotp: cfecho %08X timeout\n", so->cfecho); - fallthrough; + /* report 'communication error on send' */ + sk->sk_err = ECOMM; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); - case ISOTP_WAIT_FC: - case ISOTP_WAIT_FIRST_FC: + /* reset tx state */ + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); - /* we did not get any flow control frame in time */ + return HRTIMER_NORESTART; +} - /* report 'communication error on send' */ - sk->sk_err = ECOMM; - if (!sock_flag(sk, SOCK_DEAD)) - sk_error_report(sk); +static enum hrtimer_restart isotp_txfr_timer_handler(struct hrtimer *hrtimer) +{ + struct isotp_sock *so = container_of(hrtimer, struct isotp_sock, + txfrtimer); - /* reset tx state */ - so->tx.state = ISOTP_IDLE; - wake_up_interruptible(&so->wait); - break; + /* start echo timeout handling and cover below protocol error */ + hrtimer_start(&so->txtimer, ktime_set(ISOTP_ECHO_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); - default: - WARN_ONCE(1, "can-isotp: tx timer state %08X cfecho %08X\n", - so->tx.state, so->cfecho); - } + /* cfecho should be consumed by isotp_rcv_echo() here */ + if (so->tx.state == ISOTP_SENDING && !so->cfecho) + isotp_send_cframe(so); - return restart; + return HRTIMER_NORESTART; } static int isotp_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) @@ -1162,6 +1152,10 @@ static int isotp_release(struct socket *sock) /* wait for complete transmission of current pdu */ wait_event_interruptible(so->wait, so->tx.state == ISOTP_IDLE); + /* force state machines to be idle also when a signal occurred */ + so->tx.state = ISOTP_IDLE; + so->rx.state = ISOTP_IDLE; + spin_lock(&isotp_notifier_lock); while (isotp_busy_notifier == so) { spin_unlock(&isotp_notifier_lock); @@ -1194,6 +1188,7 @@ static int isotp_release(struct socket *sock) } } + hrtimer_cancel(&so->txfrtimer); hrtimer_cancel(&so->txtimer); hrtimer_cancel(&so->rxtimer); @@ -1597,6 +1592,8 @@ static int isotp_init(struct sock *sk) so->rxtimer.function = isotp_rx_timer_handler; hrtimer_init(&so->txtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); so->txtimer.function = isotp_tx_timer_handler; + hrtimer_init(&so->txfrtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); + so->txfrtimer.function = isotp_txfr_timer_handler; init_waitqueue_head(&so->wait); spin_lock_init(&so->rx_lock); diff --git a/net/can/j1939/socket.c b/net/can/j1939/socket.c index b670ba03a675..7e90f9e61d9b 100644 --- a/net/can/j1939/socket.c +++ b/net/can/j1939/socket.c @@ -189,7 +189,7 @@ activate_next: int time_ms = 0; if (err) - time_ms = 10 + prandom_u32_max(16); + time_ms = 10 + get_random_u32_below(16); j1939_tp_schedule_txtimer(first, time_ms); } diff --git a/net/can/j1939/transport.c b/net/can/j1939/transport.c index 55f29c9f9e08..fce9b9ebf13f 100644 --- a/net/can/j1939/transport.c +++ b/net/can/j1939/transport.c @@ -987,7 +987,7 @@ static int j1939_session_tx_eoma(struct j1939_session *session) /* wait for the EOMA packet to come in */ j1939_tp_set_rxtimeout(session, 1250); - netdev_dbg(session->priv->ndev, "%p: 0x%p\n", __func__, session); + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); return 0; } @@ -1092,10 +1092,6 @@ static bool j1939_session_deactivate(struct j1939_session *session) bool active; j1939_session_list_lock(priv); - /* This function should be called with a session ref-count of at - * least 2. - */ - WARN_ON_ONCE(kref_read(&session->kref) < 2); active = j1939_session_deactivate_locked(session); j1939_session_list_unlock(priv); @@ -1168,7 +1164,7 @@ static enum hrtimer_restart j1939_tp_txtimer(struct hrtimer *hrtimer) if (session->tx_retry < J1939_XTP_TX_RETRY_LIMIT) { session->tx_retry++; j1939_tp_schedule_txtimer(session, - 10 + prandom_u32_max(16)); + 10 + get_random_u32_below(16)); } else { netdev_alert(priv->ndev, "%s: 0x%p: tx retry count reached\n", __func__, session); diff --git a/net/can/raw.c b/net/can/raw.c index 3eb7d3e2b541..ba86782ba8bb 100644 --- a/net/can/raw.c +++ b/net/can/raw.c @@ -132,8 +132,8 @@ static void raw_rcv(struct sk_buff *oskb, void *data) return; /* make sure to not pass oversized frames to the socket */ - if ((can_is_canfd_skb(oskb) && !ro->fd_frames && !ro->xl_frames) || - (can_is_canxl_skb(oskb) && !ro->xl_frames)) + if ((!ro->fd_frames && can_is_canfd_skb(oskb)) || + (!ro->xl_frames && can_is_canxl_skb(oskb))) return; /* eliminate multiple filter matches for the same skb */ @@ -670,6 +670,11 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, if (copy_from_sockptr(&ro->fd_frames, optval, optlen)) return -EFAULT; + /* Enabling CAN XL includes CAN FD */ + if (ro->xl_frames && !ro->fd_frames) { + ro->fd_frames = ro->xl_frames; + return -EINVAL; + } break; case CAN_RAW_XL_FRAMES: @@ -679,6 +684,9 @@ static int raw_setsockopt(struct socket *sock, int level, int optname, if (copy_from_sockptr(&ro->xl_frames, optval, optlen)) return -EFAULT; + /* Enabling CAN XL includes CAN FD */ + if (ro->xl_frames) + ro->fd_frames = ro->xl_frames; break; case CAN_RAW_JOIN_FILTERS: @@ -786,6 +794,25 @@ static int raw_getsockopt(struct socket *sock, int level, int optname, return 0; } +static bool raw_bad_txframe(struct raw_sock *ro, struct sk_buff *skb, int mtu) +{ + /* Classical CAN -> no checks for flags and device capabilities */ + if (can_is_can_skb(skb)) + return false; + + /* CAN FD -> needs to be enabled and a CAN FD or CAN XL device */ + if (ro->fd_frames && can_is_canfd_skb(skb) && + (mtu == CANFD_MTU || can_is_canxl_dev_mtu(mtu))) + return false; + + /* CAN XL -> needs to be enabled and a CAN XL device */ + if (ro->xl_frames && can_is_canxl_skb(skb) && + can_is_canxl_dev_mtu(mtu)) + return false; + + return true; +} + static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) { struct sock *sk = sock->sk; @@ -833,20 +860,8 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) goto free_skb; err = -EINVAL; - if (ro->xl_frames && can_is_canxl_dev_mtu(dev->mtu)) { - /* CAN XL, CAN FD and Classical CAN */ - if (!can_is_canxl_skb(skb) && !can_is_canfd_skb(skb) && - !can_is_can_skb(skb)) - goto free_skb; - } else if (ro->fd_frames && dev->mtu == CANFD_MTU) { - /* CAN FD and Classical CAN */ - if (!can_is_canfd_skb(skb) && !can_is_can_skb(skb)) - goto free_skb; - } else { - /* Classical CAN */ - if (!can_is_can_skb(skb)) - goto free_skb; - } + if (raw_bad_txframe(ro, skb, dev->mtu)) + goto free_skb; sockcm_init(&sockc, sk); if (msg->msg_controllen) { @@ -857,6 +872,7 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) skb->dev = dev; skb->priority = sk->sk_priority; + skb->mark = sk->sk_mark; skb->tstamp = sockc.transmit_time; skb_setup_tx_timestamp(skb, sockc.tsflags); diff --git a/net/ceph/messenger.c b/net/ceph/messenger.c index dfa237fbd5a3..1d06e114ba3f 100644 --- a/net/ceph/messenger.c +++ b/net/ceph/messenger.c @@ -446,6 +446,7 @@ int ceph_tcp_connect(struct ceph_connection *con) if (ret) return ret; sock->sk->sk_allocation = GFP_NOFS; + sock->sk->sk_use_task_frag = false; #ifdef CONFIG_LOCKDEP lockdep_set_class(&sock->sk->sk_lock, &socket_class); diff --git a/net/ceph/messenger_v1.c b/net/ceph/messenger_v1.c index 3ddbde87e4d6..d1787d7d33ef 100644 --- a/net/ceph/messenger_v1.c +++ b/net/ceph/messenger_v1.c @@ -30,7 +30,7 @@ static int ceph_tcp_recvmsg(struct socket *sock, void *buf, size_t len) if (!buf) msg.msg_flags |= MSG_TRUNC; - iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, len); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, len); r = sock_recvmsg(sock, &msg, msg.msg_flags); if (r == -EAGAIN) r = 0; @@ -49,7 +49,7 @@ static int ceph_tcp_recvpage(struct socket *sock, struct page *page, int r; BUG_ON(page_offset + length > PAGE_SIZE); - iov_iter_bvec(&msg.msg_iter, READ, &bvec, 1, length); + iov_iter_bvec(&msg.msg_iter, ITER_DEST, &bvec, 1, length); r = sock_recvmsg(sock, &msg, msg.msg_flags); if (r == -EAGAIN) r = 0; diff --git a/net/ceph/messenger_v2.c b/net/ceph/messenger_v2.c index cc8ff81a50b7..3009028c4fa2 100644 --- a/net/ceph/messenger_v2.c +++ b/net/ceph/messenger_v2.c @@ -168,7 +168,7 @@ static int do_try_sendpage(struct socket *sock, struct iov_iter *it) bv.bv_offset, bv.bv_len, CEPH_MSG_FLAGS); } else { - iov_iter_bvec(&msg.msg_iter, WRITE, &bv, 1, bv.bv_len); + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bv, 1, bv.bv_len); ret = sock_sendmsg(sock, &msg); } if (ret <= 0) { @@ -225,7 +225,7 @@ static void reset_in_kvecs(struct ceph_connection *con) WARN_ON(iov_iter_count(&con->v2.in_iter)); con->v2.in_kvec_cnt = 0; - iov_iter_kvec(&con->v2.in_iter, READ, con->v2.in_kvecs, 0, 0); + iov_iter_kvec(&con->v2.in_iter, ITER_DEST, con->v2.in_kvecs, 0, 0); } static void set_in_bvec(struct ceph_connection *con, const struct bio_vec *bv) @@ -233,7 +233,7 @@ static void set_in_bvec(struct ceph_connection *con, const struct bio_vec *bv) WARN_ON(iov_iter_count(&con->v2.in_iter)); con->v2.in_bvec = *bv; - iov_iter_bvec(&con->v2.in_iter, READ, &con->v2.in_bvec, 1, bv->bv_len); + iov_iter_bvec(&con->v2.in_iter, ITER_DEST, &con->v2.in_bvec, 1, bv->bv_len); } static void set_in_skip(struct ceph_connection *con, int len) @@ -241,7 +241,7 @@ static void set_in_skip(struct ceph_connection *con, int len) WARN_ON(iov_iter_count(&con->v2.in_iter)); dout("%s con %p len %d\n", __func__, con, len); - iov_iter_discard(&con->v2.in_iter, READ, len); + iov_iter_discard(&con->v2.in_iter, ITER_DEST, len); } static void add_out_kvec(struct ceph_connection *con, void *buf, int len) @@ -265,7 +265,7 @@ static void reset_out_kvecs(struct ceph_connection *con) con->v2.out_kvec_cnt = 0; - iov_iter_kvec(&con->v2.out_iter, WRITE, con->v2.out_kvecs, 0, 0); + iov_iter_kvec(&con->v2.out_iter, ITER_SOURCE, con->v2.out_kvecs, 0, 0); con->v2.out_iter_sendpage = false; } @@ -277,7 +277,7 @@ static void set_out_bvec(struct ceph_connection *con, const struct bio_vec *bv, con->v2.out_bvec = *bv; con->v2.out_iter_sendpage = zerocopy; - iov_iter_bvec(&con->v2.out_iter, WRITE, &con->v2.out_bvec, 1, + iov_iter_bvec(&con->v2.out_iter, ITER_SOURCE, &con->v2.out_bvec, 1, con->v2.out_bvec.bv_len); } @@ -290,7 +290,7 @@ static void set_out_bvec_zero(struct ceph_connection *con) con->v2.out_bvec.bv_offset = 0; con->v2.out_bvec.bv_len = min(con->v2.out_zero, (int)PAGE_SIZE); con->v2.out_iter_sendpage = true; - iov_iter_bvec(&con->v2.out_iter, WRITE, &con->v2.out_bvec, 1, + iov_iter_bvec(&con->v2.out_iter, ITER_SOURCE, &con->v2.out_bvec, 1, con->v2.out_bvec.bv_len); } diff --git a/net/ceph/mon_client.c b/net/ceph/mon_client.c index db60217f911b..faabad6603db 100644 --- a/net/ceph/mon_client.c +++ b/net/ceph/mon_client.c @@ -222,7 +222,7 @@ static void pick_new_mon(struct ceph_mon_client *monc) max--; } - n = prandom_u32_max(max); + n = get_random_u32_below(max); if (o >= 0 && n >= o) n++; diff --git a/net/ceph/osd_client.c b/net/ceph/osd_client.c index 4e4f1e4bc265..11c04e7d928e 100644 --- a/net/ceph/osd_client.c +++ b/net/ceph/osd_client.c @@ -1479,7 +1479,7 @@ static bool target_should_be_paused(struct ceph_osd_client *osdc, static int pick_random_replica(const struct ceph_osds *acting) { - int i = prandom_u32_max(acting->size); + int i = get_random_u32_below(acting->size); dout("%s picked osd%d, primary osd%d\n", __func__, acting->osds[i], acting->primary); diff --git a/net/compat.c b/net/compat.c index 385f04a6be2f..161b7bea1f62 100644 --- a/net/compat.c +++ b/net/compat.c @@ -95,7 +95,8 @@ int get_compat_msghdr(struct msghdr *kmsg, if (err) return err; - err = import_iovec(save_addr ? READ : WRITE, compat_ptr(msg.msg_iov), msg.msg_iovlen, + err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE, + compat_ptr(msg.msg_iov), msg.msg_iovlen, UIO_FASTIOV, iov, &kmsg->msg_iter); return err < 0 ? err : 0; } diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c index 94374d529ea4..bb378c33f542 100644 --- a/net/core/bpf_sk_storage.c +++ b/net/core/bpf_sk_storage.c @@ -48,10 +48,8 @@ static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map) /* Called by __sk_destruct() & bpf_sk_storage_clone() */ void bpf_sk_storage_free(struct sock *sk) { - struct bpf_local_storage_elem *selem; struct bpf_local_storage *sk_storage; bool free_sk_storage = false; - struct hlist_node *n; rcu_read_lock(); sk_storage = rcu_dereference(sk->sk_bpf_storage); @@ -60,24 +58,8 @@ void bpf_sk_storage_free(struct sock *sk) return; } - /* Netiher the bpf_prog nor the bpf-map's syscall - * could be modifying the sk_storage->list now. - * Thus, no elem can be added-to or deleted-from the - * sk_storage->list by the bpf_prog or by the bpf-map's syscall. - * - * It is racing with bpf_local_storage_map_free() alone - * when unlinking elem from the sk_storage->list and - * the map's bucket->list. - */ raw_spin_lock_bh(&sk_storage->lock); - hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) { - /* Always unlink from map before unlinking from - * sk_storage. - */ - bpf_selem_unlink_map(selem); - free_sk_storage = bpf_selem_unlink_storage_nolock( - sk_storage, selem, true, false); - } + free_sk_storage = bpf_local_storage_unlink_nolock(sk_storage); raw_spin_unlock_bh(&sk_storage->lock); rcu_read_unlock(); @@ -87,23 +69,12 @@ void bpf_sk_storage_free(struct sock *sk) static void bpf_sk_storage_map_free(struct bpf_map *map) { - struct bpf_local_storage_map *smap; - - smap = (struct bpf_local_storage_map *)map; - bpf_local_storage_cache_idx_free(&sk_cache, smap->cache_idx); - bpf_local_storage_map_free(smap, NULL); + bpf_local_storage_map_free(map, &sk_cache, NULL); } static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr) { - struct bpf_local_storage_map *smap; - - smap = bpf_local_storage_map_alloc(attr); - if (IS_ERR(smap)) - return ERR_CAST(smap); - - smap->cache_idx = bpf_local_storage_cache_idx_get(&sk_cache); - return &smap->map; + return bpf_local_storage_map_alloc(attr, &sk_cache); } static int notsupp_get_next_key(struct bpf_map *map, void *key, @@ -176,7 +147,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk, if (!copy_selem) return NULL; - if (map_value_has_spin_lock(&smap->map)) + if (btf_record_has_field(smap->map.record, BPF_SPIN_LOCK)) copy_map_value_locked(&smap->map, SDATA(copy_selem)->data, SDATA(selem)->data, true); else @@ -339,7 +310,6 @@ bpf_sk_storage_ptr(void *owner) return &sk->sk_bpf_storage; } -BTF_ID_LIST_SINGLE(sk_storage_map_btf_ids, struct, bpf_local_storage_map) const struct bpf_map_ops sk_storage_map_ops = { .map_meta_equal = bpf_map_meta_equal, .map_alloc_check = bpf_local_storage_map_alloc_check, @@ -350,7 +320,7 @@ const struct bpf_map_ops sk_storage_map_ops = { .map_update_elem = bpf_fd_sk_storage_update_elem, .map_delete_elem = bpf_fd_sk_storage_delete_elem, .map_check_btf = bpf_local_storage_map_check_btf, - .map_btf_id = &sk_storage_map_btf_ids[0], + .map_btf_id = &bpf_local_storage_map_btf_id[0], .map_local_storage_charge = bpf_sk_storage_charge, .map_local_storage_uncharge = bpf_sk_storage_uncharge, .map_owner_storage_ptr = bpf_sk_storage_ptr, @@ -595,7 +565,7 @@ static int diag_get(struct bpf_local_storage_data *sdata, struct sk_buff *skb) if (!nla_value) goto errout; - if (map_value_has_spin_lock(&smap->map)) + if (btf_record_has_field(smap->map.record, BPF_SPIN_LOCK)) copy_map_value_locked(&smap->map, nla_data(nla_value), sdata->data, true); else diff --git a/net/core/dev.c b/net/core/dev.c index 3be256051e99..b76fb37b381e 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -1163,22 +1163,6 @@ int dev_change_name(struct net_device *dev, const char *newname) net = dev_net(dev); - /* Some auto-enslaved devices e.g. failover slaves are - * special, as userspace might rename the device after - * the interface had been brought up and running since - * the point kernel initiated auto-enslavement. Allow - * live name change even when these slave devices are - * up and running. - * - * Typically, users of these auto-enslaving devices - * don't actually care about slave name change, as - * they are supposed to operate on master interface - * directly. - */ - if (dev->flags & IFF_UP && - likely(!(dev->priv_flags & IFF_LIVE_RENAME_OK))) - return -EBUSY; - down_write(&devnet_rename_sem); if (strncmp(newname, dev->name, IFNAMSIZ) == 0) { @@ -1195,7 +1179,8 @@ int dev_change_name(struct net_device *dev, const char *newname) } if (oldname[0] && !strchr(oldname, '%')) - netdev_info(dev, "renamed from %s\n", oldname); + netdev_info(dev, "renamed from %s%s\n", oldname, + dev->flags & IFF_UP ? " (while UP)" : ""); old_assign_type = dev->name_assign_type; dev->name_assign_type = NET_NAME_RENAMED; @@ -1333,7 +1318,7 @@ void netdev_state_change(struct net_device *dev) call_netdevice_notifiers_info(NETDEV_CHANGE, &change_info.info); - rtmsg_ifinfo(RTM_NEWLINK, dev, 0, GFP_KERNEL); + rtmsg_ifinfo(RTM_NEWLINK, dev, 0, GFP_KERNEL, 0, NULL); } } EXPORT_SYMBOL(netdev_state_change); @@ -1469,7 +1454,7 @@ int dev_open(struct net_device *dev, struct netlink_ext_ack *extack) if (ret < 0) return ret; - rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP|IFF_RUNNING, GFP_KERNEL); + rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP | IFF_RUNNING, GFP_KERNEL, 0, NULL); call_netdevice_notifiers(NETDEV_UP, dev); return ret; @@ -1541,7 +1526,7 @@ void dev_close_many(struct list_head *head, bool unlink) __dev_close_many(head); list_for_each_entry_safe(dev, tmp, head, close_list) { - rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP|IFF_RUNNING, GFP_KERNEL); + rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP | IFF_RUNNING, GFP_KERNEL, 0, NULL); call_netdevice_notifiers(NETDEV_DOWN, dev); if (unlink) list_del_init(&dev->close_list); @@ -1621,10 +1606,10 @@ const char *netdev_cmd_to_name(enum netdev_cmd cmd) N(UP) N(DOWN) N(REBOOT) N(CHANGE) N(REGISTER) N(UNREGISTER) N(CHANGEMTU) N(CHANGEADDR) N(GOING_DOWN) N(CHANGENAME) N(FEAT_CHANGE) N(BONDING_FAILOVER) N(PRE_UP) N(PRE_TYPE_CHANGE) N(POST_TYPE_CHANGE) - N(POST_INIT) N(RELEASE) N(NOTIFY_PEERS) N(JOIN) N(CHANGEUPPER) - N(RESEND_IGMP) N(PRECHANGEMTU) N(CHANGEINFODATA) N(BONDING_INFO) - N(PRECHANGEUPPER) N(CHANGELOWERSTATE) N(UDP_TUNNEL_PUSH_INFO) - N(UDP_TUNNEL_DROP_INFO) N(CHANGE_TX_QUEUE_LEN) + N(POST_INIT) N(PRE_UNINIT) N(RELEASE) N(NOTIFY_PEERS) N(JOIN) + N(CHANGEUPPER) N(RESEND_IGMP) N(PRECHANGEMTU) N(CHANGEINFODATA) + N(BONDING_INFO) N(PRECHANGEUPPER) N(CHANGELOWERSTATE) + N(UDP_TUNNEL_PUSH_INFO) N(UDP_TUNNEL_DROP_INFO) N(CHANGE_TX_QUEUE_LEN) N(CVLAN_FILTER_PUSH_INFO) N(CVLAN_FILTER_DROP_INFO) N(SVLAN_FILTER_PUSH_INFO) N(SVLAN_FILTER_DROP_INFO) N(PRE_CHANGEADDR) N(OFFLOAD_XSTATS_ENABLE) N(OFFLOAD_XSTATS_DISABLE) @@ -1876,6 +1861,22 @@ int unregister_netdevice_notifier_net(struct net *net, } EXPORT_SYMBOL(unregister_netdevice_notifier_net); +static void __move_netdevice_notifier_net(struct net *src_net, + struct net *dst_net, + struct notifier_block *nb) +{ + __unregister_netdevice_notifier_net(src_net, nb); + __register_netdevice_notifier_net(dst_net, nb, true); +} + +void move_netdevice_notifier_net(struct net *src_net, struct net *dst_net, + struct notifier_block *nb) +{ + rtnl_lock(); + __move_netdevice_notifier_net(src_net, dst_net, nb); + rtnl_unlock(); +} + int register_netdevice_notifier_dev_net(struct net_device *dev, struct notifier_block *nb, struct netdev_net_notifier *nn) @@ -1912,10 +1913,8 @@ static void move_netdevice_notifiers_dev_net(struct net_device *dev, { struct netdev_net_notifier *nn; - list_for_each_entry(nn, &dev->net_notifier_list, list) { - __unregister_netdevice_notifier_net(dev_net(dev), nn->nb); - __register_netdevice_notifier_net(net, nn->nb, true); - } + list_for_each_entry(nn, &dev->net_notifier_list, list) + __move_netdevice_notifier_net(dev_net(dev), net, nn->nb); } /** @@ -2074,13 +2073,10 @@ static DECLARE_WORK(netstamp_work, netstamp_clear); void net_enable_timestamp(void) { #ifdef CONFIG_JUMP_LABEL - int wanted; + int wanted = atomic_read(&netstamp_wanted); - while (1) { - wanted = atomic_read(&netstamp_wanted); - if (wanted <= 0) - break; - if (atomic_cmpxchg(&netstamp_wanted, wanted, wanted + 1) == wanted) + while (wanted > 0) { + if (atomic_try_cmpxchg(&netstamp_wanted, &wanted, wanted + 1)) return; } atomic_inc(&netstamp_needed_deferred); @@ -2094,13 +2090,10 @@ EXPORT_SYMBOL(net_enable_timestamp); void net_disable_timestamp(void) { #ifdef CONFIG_JUMP_LABEL - int wanted; + int wanted = atomic_read(&netstamp_wanted); - while (1) { - wanted = atomic_read(&netstamp_wanted); - if (wanted <= 1) - break; - if (atomic_cmpxchg(&netstamp_wanted, wanted, wanted - 1) == wanted) + while (wanted > 1) { + if (atomic_try_cmpxchg(&netstamp_wanted, &wanted, wanted - 1)) return; } atomic_dec(&netstamp_needed_deferred); @@ -5986,10 +5979,9 @@ EXPORT_SYMBOL(__napi_schedule); */ bool napi_schedule_prep(struct napi_struct *n) { - unsigned long val, new; + unsigned long new, val = READ_ONCE(n->state); do { - val = READ_ONCE(n->state); if (unlikely(val & NAPIF_STATE_DISABLE)) return false; new = val | NAPIF_STATE_SCHED; @@ -6002,7 +5994,7 @@ bool napi_schedule_prep(struct napi_struct *n) */ new |= (val & NAPIF_STATE_SCHED) / NAPIF_STATE_SCHED * NAPIF_STATE_MISSED; - } while (cmpxchg(&n->state, val, new) != val); + } while (!try_cmpxchg(&n->state, &val, new)); return !(val & NAPIF_STATE_SCHED); } @@ -6070,9 +6062,8 @@ bool napi_complete_done(struct napi_struct *n, int work_done) local_irq_restore(flags); } + val = READ_ONCE(n->state); do { - val = READ_ONCE(n->state); - WARN_ON_ONCE(!(val & NAPIF_STATE_SCHED)); new = val & ~(NAPIF_STATE_MISSED | NAPIF_STATE_SCHED | @@ -6085,7 +6076,7 @@ bool napi_complete_done(struct napi_struct *n, int work_done) */ new |= (val & NAPIF_STATE_MISSED) / NAPIF_STATE_MISSED * NAPIF_STATE_SCHED; - } while (cmpxchg(&n->state, val, new) != val); + } while (!try_cmpxchg(&n->state, &val, new)); if (unlikely(val & NAPIF_STATE_MISSED)) { __napi_schedule(n); @@ -6406,19 +6397,16 @@ void napi_disable(struct napi_struct *n) might_sleep(); set_bit(NAPI_STATE_DISABLE, &n->state); - for ( ; ; ) { - val = READ_ONCE(n->state); - if (val & (NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC)) { + val = READ_ONCE(n->state); + do { + while (val & (NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC)) { usleep_range(20, 200); - continue; + val = READ_ONCE(n->state); } new = val | NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC; new &= ~(NAPIF_STATE_THREADED | NAPIF_STATE_PREFER_BUSY_POLL); - - if (cmpxchg(&n->state, val, new) == val) - break; - } + } while (!try_cmpxchg(&n->state, &val, new)); hrtimer_cancel(&n->timer); @@ -6435,16 +6423,15 @@ EXPORT_SYMBOL(napi_disable); */ void napi_enable(struct napi_struct *n) { - unsigned long val, new; + unsigned long new, val = READ_ONCE(n->state); do { - val = READ_ONCE(n->state); BUG_ON(!test_bit(NAPI_STATE_SCHED, &val)); new = val & ~(NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC); if (n->dev->threaded && n->thread) new |= NAPIF_STATE_THREADED; - } while (cmpxchg(&n->state, val, new) != val); + } while (!try_cmpxchg(&n->state, &val, new)); } EXPORT_SYMBOL(napi_enable); @@ -8351,7 +8338,7 @@ static int __dev_set_promiscuity(struct net_device *dev, int inc, bool notify) dev_change_rx_flags(dev, IFF_PROMISC); } if (notify) - __dev_notify_flags(dev, old_flags, IFF_PROMISC); + __dev_notify_flags(dev, old_flags, IFF_PROMISC, 0, NULL); return 0; } @@ -8406,7 +8393,7 @@ static int __dev_set_allmulti(struct net_device *dev, int inc, bool notify) dev_set_rx_mode(dev); if (notify) __dev_notify_flags(dev, old_flags, - dev->gflags ^ old_gflags); + dev->gflags ^ old_gflags, 0, NULL); } return 0; } @@ -8569,12 +8556,13 @@ int __dev_change_flags(struct net_device *dev, unsigned int flags, } void __dev_notify_flags(struct net_device *dev, unsigned int old_flags, - unsigned int gchanges) + unsigned int gchanges, u32 portid, + const struct nlmsghdr *nlh) { unsigned int changes = dev->flags ^ old_flags; if (gchanges) - rtmsg_ifinfo(RTM_NEWLINK, dev, gchanges, GFP_ATOMIC); + rtmsg_ifinfo(RTM_NEWLINK, dev, gchanges, GFP_ATOMIC, portid, nlh); if (changes & IFF_UP) { if (dev->flags & IFF_UP) @@ -8616,7 +8604,7 @@ int dev_change_flags(struct net_device *dev, unsigned int flags, return ret; changes = (old_flags ^ dev->flags) | (old_gflags ^ dev->gflags); - __dev_notify_flags(dev, old_flags, changes); + __dev_notify_flags(dev, old_flags, changes, 0, NULL); return ret; } EXPORT_SYMBOL(dev_change_flags); @@ -8822,7 +8810,7 @@ EXPORT_SYMBOL(dev_set_mac_address_user); int dev_get_mac_address(struct sockaddr *sa, struct net *net, char *dev_name) { - size_t size = sizeof(sa->sa_data); + size_t size = sizeof(sa->sa_data_min); struct net_device *dev; int ret = 0; @@ -10059,7 +10047,7 @@ int register_netdevice(struct net_device *dev) dev->reg_state = ret ? NETREG_UNREGISTERED : NETREG_REGISTERED; write_unlock(&dev_base_lock); if (ret) - goto err_uninit; + goto err_uninit_notify; __netdev_update_features(dev); @@ -10101,11 +10089,13 @@ int register_netdevice(struct net_device *dev) */ if (!dev->rtnl_link_ops || dev->rtnl_link_state == RTNL_LINK_INITIALIZED) - rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL); + rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL, 0, NULL); out: return ret; +err_uninit_notify: + call_netdevice_notifiers(NETDEV_PRE_UNINIT, dev); err_uninit: if (dev->netdev_ops->ndo_uninit) dev->netdev_ops->ndo_uninit(dev); @@ -10379,24 +10369,16 @@ void netdev_run_todo(void) void netdev_stats_to_stats64(struct rtnl_link_stats64 *stats64, const struct net_device_stats *netdev_stats) { -#if BITS_PER_LONG == 64 - BUILD_BUG_ON(sizeof(*stats64) < sizeof(*netdev_stats)); - memcpy(stats64, netdev_stats, sizeof(*netdev_stats)); - /* zero out counters that only exist in rtnl_link_stats64 */ - memset((char *)stats64 + sizeof(*netdev_stats), 0, - sizeof(*stats64) - sizeof(*netdev_stats)); -#else - size_t i, n = sizeof(*netdev_stats) / sizeof(unsigned long); - const unsigned long *src = (const unsigned long *)netdev_stats; + size_t i, n = sizeof(*netdev_stats) / sizeof(atomic_long_t); + const atomic_long_t *src = (atomic_long_t *)netdev_stats; u64 *dst = (u64 *)stats64; BUILD_BUG_ON(n > sizeof(*stats64) / sizeof(u64)); for (i = 0; i < n; i++) - dst[i] = src[i]; + dst[i] = atomic_long_read(&src[i]); /* zero out counters that only exist in rtnl_link_stats64 */ memset((char *)stats64 + n * sizeof(u64), 0, sizeof(*stats64) - n * sizeof(u64)); -#endif } EXPORT_SYMBOL(netdev_stats_to_stats64); @@ -10477,12 +10459,12 @@ void dev_fetch_sw_netstats(struct rtnl_link_stats64 *s, stats = per_cpu_ptr(netstats, cpu); do { - start = u64_stats_fetch_begin_irq(&stats->syncp); + start = u64_stats_fetch_begin(&stats->syncp); rx_packets = u64_stats_read(&stats->rx_packets); rx_bytes = u64_stats_read(&stats->rx_bytes); tx_packets = u64_stats_read(&stats->tx_packets); tx_bytes = u64_stats_read(&stats->tx_bytes); - } while (u64_stats_fetch_retry_irq(&stats->syncp, start)); + } while (u64_stats_fetch_retry(&stats->syncp, start)); s->rx_packets += rx_packets; s->rx_bytes += rx_bytes; @@ -10535,6 +10517,22 @@ void netdev_set_default_ethtool_ops(struct net_device *dev, } EXPORT_SYMBOL_GPL(netdev_set_default_ethtool_ops); +/** + * netdev_sw_irq_coalesce_default_on() - enable SW IRQ coalescing by default + * @dev: netdev to enable the IRQ coalescing on + * + * Sets a conservative default for SW IRQ coalescing. Users can use + * sysfs attributes to override the default values. + */ +void netdev_sw_irq_coalesce_default_on(struct net_device *dev) +{ + WARN_ON(dev->reg_state == NETREG_REGISTERED); + + dev->gro_flush_timeout = 20000; + dev->napi_defer_hard_irqs = 1; +} +EXPORT_SYMBOL_GPL(netdev_sw_irq_coalesce_default_on); + void netdev_freemem(struct net_device *dev) { char *addr = (char *)dev - dev->padded; @@ -10780,14 +10778,8 @@ void unregister_netdevice_queue(struct net_device *dev, struct list_head *head) } EXPORT_SYMBOL(unregister_netdevice_queue); -/** - * unregister_netdevice_many - unregister many devices - * @head: list of devices - * - * Note: As most callers use a stack allocated list_head, - * we force a list_del() to make sure stack wont be corrupted later. - */ -void unregister_netdevice_many(struct list_head *head) +void unregister_netdevice_many_notify(struct list_head *head, + u32 portid, const struct nlmsghdr *nlh) { struct net_device *dev, *tmp; LIST_HEAD(close_head); @@ -10849,7 +10841,8 @@ void unregister_netdevice_many(struct list_head *head) if (!dev->rtnl_link_ops || dev->rtnl_link_state == RTNL_LINK_INITIALIZED) skb = rtmsg_ifinfo_build_skb(RTM_DELLINK, dev, ~0U, 0, - GFP_KERNEL, NULL, 0); + GFP_KERNEL, NULL, 0, + portid, nlmsg_seq(nlh)); /* * Flush the unicast and multicast chains @@ -10860,11 +10853,13 @@ void unregister_netdevice_many(struct list_head *head) netdev_name_node_alt_flush(dev); netdev_name_node_free(dev->name_node); + call_netdevice_notifiers(NETDEV_PRE_UNINIT, dev); + if (dev->netdev_ops->ndo_uninit) dev->netdev_ops->ndo_uninit(dev); if (skb) - rtmsg_ifinfo_send(skb, dev, GFP_KERNEL); + rtmsg_ifinfo_send(skb, dev, GFP_KERNEL, portid, nlh); /* Notifier chain MUST detach us all upper devices. */ WARN_ON(netdev_has_any_upper_dev(dev)); @@ -10887,6 +10882,18 @@ void unregister_netdevice_many(struct list_head *head) list_del(head); } + +/** + * unregister_netdevice_many - unregister many devices + * @head: list of devices + * + * Note: As most callers use a stack allocated list_head, + * we force a list_del() to make sure stack wont be corrupted later. + */ +void unregister_netdevice_many(struct list_head *head) +{ + unregister_netdevice_many_notify(head, 0, NULL); +} EXPORT_SYMBOL(unregister_netdevice_many); /** @@ -11042,7 +11049,7 @@ int __dev_change_net_namespace(struct net_device *dev, struct net *net, * Prevent userspace races by waiting until the network * device is fully setup before sending notifications. */ - rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL); + rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL, 0, NULL); synchronize_net(); err = 0; diff --git a/net/core/dev.h b/net/core/dev.h index cbb8a925175a..814ed5b7b960 100644 --- a/net/core/dev.h +++ b/net/core/dev.h @@ -88,6 +88,13 @@ int dev_change_carrier(struct net_device *dev, bool new_carrier); void __dev_set_rx_mode(struct net_device *dev); +void __dev_notify_flags(struct net_device *dev, unsigned int old_flags, + unsigned int gchanges, u32 portid, + const struct nlmsghdr *nlh); + +void unregister_netdevice_many_notify(struct list_head *head, + u32 portid, const struct nlmsghdr *nlh); + static inline void netif_set_gso_max_size(struct net_device *dev, unsigned int size) { diff --git a/net/core/dev_addr_lists_test.c b/net/core/dev_addr_lists_test.c index 049cfbc58aa9..90e7e3811ae7 100644 --- a/net/core/dev_addr_lists_test.c +++ b/net/core/dev_addr_lists_test.c @@ -71,11 +71,11 @@ static void dev_addr_test_basic(struct kunit *test) memset(addr, 2, sizeof(addr)); eth_hw_addr_set(netdev, addr); - KUNIT_EXPECT_EQ(test, 0, memcmp(netdev->dev_addr, addr, sizeof(addr))); + KUNIT_EXPECT_MEMEQ(test, netdev->dev_addr, addr, sizeof(addr)); memset(addr, 3, sizeof(addr)); dev_addr_set(netdev, addr); - KUNIT_EXPECT_EQ(test, 0, memcmp(netdev->dev_addr, addr, sizeof(addr))); + KUNIT_EXPECT_MEMEQ(test, netdev->dev_addr, addr, sizeof(addr)); } static void dev_addr_test_sync_one(struct kunit *test) diff --git a/net/core/dev_ioctl.c b/net/core/dev_ioctl.c index 7674bb9f3076..5cdbfbf9a7dc 100644 --- a/net/core/dev_ioctl.c +++ b/net/core/dev_ioctl.c @@ -342,7 +342,7 @@ static int dev_ifsioc(struct net *net, struct ifreq *ifr, void __user *data, if (ifr->ifr_hwaddr.sa_family != dev->type) return -EINVAL; memcpy(dev->broadcast, ifr->ifr_hwaddr.sa_data, - min(sizeof(ifr->ifr_hwaddr.sa_data), + min(sizeof(ifr->ifr_hwaddr.sa_data_min), (size_t)dev->addr_len)); call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); return 0; diff --git a/net/core/devlink.c b/net/core/devlink.c index 89baa7c0938b..032d6d0a5ce6 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -41,7 +41,7 @@ struct devlink_dev_stats { struct devlink { u32 index; - struct list_head port_list; + struct xarray ports; struct list_head rate_list; struct list_head sb_list; struct list_head dpipe_table_list; @@ -71,6 +71,7 @@ struct devlink { refcount_t refcount; struct completion comp; struct rcu_head rcu; + struct notifier_block netdevice_nb; char priv[] __aligned(NETDEV_ALIGN); }; @@ -194,11 +195,16 @@ EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwmsg); EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwerr); EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_trap_report); +#define DEVLINK_PORT_FN_CAPS_VALID_MASK \ + (_BITUL(__DEVLINK_PORT_FN_ATTR_CAPS_MAX) - 1) + static const struct nla_policy devlink_function_nl_policy[DEVLINK_PORT_FUNCTION_ATTR_MAX + 1] = { [DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR] = { .type = NLA_BINARY }, [DEVLINK_PORT_FN_ATTR_STATE] = NLA_POLICY_RANGE(NLA_U8, DEVLINK_PORT_FN_STATE_INACTIVE, DEVLINK_PORT_FN_STATE_ACTIVE), + [DEVLINK_PORT_FN_ATTR_CAPS] = + NLA_POLICY_BITFIELD32(DEVLINK_PORT_FN_CAPS_VALID_MASK), }; static const struct nla_policy devlink_selftest_nl_policy[DEVLINK_ATTR_SELFTEST_ID_MAX + 1] = { @@ -381,19 +387,7 @@ static struct devlink *devlink_get_from_attrs(struct net *net, static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink, unsigned int port_index) { - struct devlink_port *devlink_port; - - list_for_each_entry(devlink_port, &devlink->port_list, list) { - if (devlink_port->index == port_index) - return devlink_port; - } - return NULL; -} - -static bool devlink_port_index_exists(struct devlink *devlink, - unsigned int port_index) -{ - return devlink_port_get_by_index(devlink, port_index); + return xa_load(&devlink->ports, port_index); } static struct devlink_port *devlink_port_get_from_attrs(struct devlink *devlink, @@ -691,6 +685,87 @@ devlink_sb_tc_index_get_from_attrs(struct devlink_sb *devlink_sb, return 0; } +static void devlink_port_fn_cap_fill(struct nla_bitfield32 *caps, + u32 cap, bool is_enable) +{ + caps->selector |= cap; + if (is_enable) + caps->value |= cap; +} + +static int devlink_port_fn_roce_fill(const struct devlink_ops *ops, + struct devlink_port *devlink_port, + struct nla_bitfield32 *caps, + struct netlink_ext_ack *extack) +{ + bool is_enable; + int err; + + if (!ops->port_fn_roce_get) + return 0; + + err = ops->port_fn_roce_get(devlink_port, &is_enable, extack); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + + devlink_port_fn_cap_fill(caps, DEVLINK_PORT_FN_CAP_ROCE, is_enable); + return 0; +} + +static int devlink_port_fn_migratable_fill(const struct devlink_ops *ops, + struct devlink_port *devlink_port, + struct nla_bitfield32 *caps, + struct netlink_ext_ack *extack) +{ + bool is_enable; + int err; + + if (!ops->port_fn_migratable_get || + devlink_port->attrs.flavour != DEVLINK_PORT_FLAVOUR_PCI_VF) + return 0; + + err = ops->port_fn_migratable_get(devlink_port, &is_enable, extack); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + + devlink_port_fn_cap_fill(caps, DEVLINK_PORT_FN_CAP_MIGRATABLE, is_enable); + return 0; +} + +static int devlink_port_fn_caps_fill(const struct devlink_ops *ops, + struct devlink_port *devlink_port, + struct sk_buff *msg, + struct netlink_ext_ack *extack, + bool *msg_updated) +{ + struct nla_bitfield32 caps = {}; + int err; + + err = devlink_port_fn_roce_fill(ops, devlink_port, &caps, extack); + if (err) + return err; + + err = devlink_port_fn_migratable_fill(ops, devlink_port, &caps, extack); + if (err) + return err; + + if (!caps.selector) + return 0; + err = nla_put_bitfield32(msg, DEVLINK_PORT_FN_ATTR_CAPS, caps.value, + caps.selector); + if (err) + return err; + + *msg_updated = true; + return 0; +} + static int devlink_sb_tc_index_get_from_info(struct devlink_sb *devlink_sb, struct genl_info *info, @@ -769,7 +844,7 @@ devlink_region_snapshot_get_by_id(struct devlink_region *region, u32 id) #define DEVLINK_NL_FLAG_NEED_RATE_NODE BIT(3) #define DEVLINK_NL_FLAG_NEED_LINECARD BIT(4) -static int devlink_nl_pre_doit(const struct genl_ops *ops, +static int devlink_nl_pre_doit(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { struct devlink_linecard *linecard; @@ -827,7 +902,7 @@ unlock: return err; } -static void devlink_nl_post_doit(const struct genl_ops *ops, +static void devlink_nl_post_doit(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { struct devlink_linecard *linecard; @@ -879,6 +954,24 @@ nla_put_failure: return -EMSGSIZE; } +int devlink_nl_port_handle_fill(struct sk_buff *msg, struct devlink_port *devlink_port) +{ + if (devlink_nl_put_handle(msg, devlink_port->devlink)) + return -EMSGSIZE; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index)) + return -EMSGSIZE; + return 0; +} + +size_t devlink_nl_port_handle_size(struct devlink_port *devlink_port) +{ + struct devlink *devlink = devlink_port->devlink; + + return nla_total_size(strlen(devlink->dev->bus->name) + 1) /* DEVLINK_ATTR_BUS_NAME */ + + nla_total_size(strlen(dev_name(devlink->dev)) + 1) /* DEVLINK_ATTR_DEV_NAME */ + + nla_total_size(4); /* DEVLINK_ATTR_PORT_INDEX */ +} + struct devlink_reload_combination { enum devlink_reload_action action; enum devlink_reload_limit limit; @@ -1184,6 +1277,14 @@ static int devlink_nl_rate_fill(struct sk_buff *msg, devlink_rate->tx_max, DEVLINK_ATTR_PAD)) goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_RATE_TX_PRIORITY, + devlink_rate->tx_priority)) + goto nla_put_failure; + + if (nla_put_u32(msg, DEVLINK_ATTR_RATE_TX_WEIGHT, + devlink_rate->tx_weight)) + goto nla_put_failure; + if (devlink_rate->parent) if (nla_put_string(msg, DEVLINK_ATTR_RATE_PARENT_NODE_NAME, devlink_rate->parent->name)) @@ -1249,6 +1350,51 @@ static int devlink_port_fn_state_fill(const struct devlink_ops *ops, } static int +devlink_port_fn_mig_set(struct devlink_port *devlink_port, bool enable, + struct netlink_ext_ack *extack) +{ + const struct devlink_ops *ops = devlink_port->devlink->ops; + + return ops->port_fn_migratable_set(devlink_port, enable, extack); +} + +static int +devlink_port_fn_roce_set(struct devlink_port *devlink_port, bool enable, + struct netlink_ext_ack *extack) +{ + const struct devlink_ops *ops = devlink_port->devlink->ops; + + return ops->port_fn_roce_set(devlink_port, enable, extack); +} + +static int devlink_port_fn_caps_set(struct devlink_port *devlink_port, + const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nla_bitfield32 caps; + u32 caps_value; + int err; + + caps = nla_get_bitfield32(attr); + caps_value = caps.value & caps.selector; + if (caps.selector & DEVLINK_PORT_FN_CAP_ROCE) { + err = devlink_port_fn_roce_set(devlink_port, + caps_value & DEVLINK_PORT_FN_CAP_ROCE, + extack); + if (err) + return err; + } + if (caps.selector & DEVLINK_PORT_FN_CAP_MIGRATABLE) { + err = devlink_port_fn_mig_set(devlink_port, caps_value & + DEVLINK_PORT_FN_CAP_MIGRATABLE, + extack); + if (err) + return err; + } + return 0; +} + +static int devlink_nl_port_function_attrs_put(struct sk_buff *msg, struct devlink_port *port, struct netlink_ext_ack *extack) { @@ -1266,6 +1412,10 @@ devlink_nl_port_function_attrs_put(struct sk_buff *msg, struct devlink_port *por &msg_updated); if (err) goto out; + err = devlink_port_fn_caps_fill(ops, port, msg, extack, + &msg_updated); + if (err) + goto out; err = devlink_port_fn_state_fill(ops, port, msg, extack, &msg_updated); out: if (err || !msg_updated) @@ -1292,8 +1442,6 @@ static int devlink_nl_port_fill(struct sk_buff *msg, if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index)) goto nla_put_failure; - /* Hold rtnl lock while accessing port's netdev attributes. */ - rtnl_lock(); spin_lock_bh(&devlink_port->type_lock); if (nla_put_u16(msg, DEVLINK_ATTR_PORT_TYPE, devlink_port->type)) goto nla_put_failure_type_locked; @@ -1302,18 +1450,15 @@ static int devlink_nl_port_fill(struct sk_buff *msg, devlink_port->desired_type)) goto nla_put_failure_type_locked; if (devlink_port->type == DEVLINK_PORT_TYPE_ETH) { - struct net *net = devlink_net(devlink_port->devlink); - struct net_device *netdev = devlink_port->type_dev; - - if (netdev && net_eq(net, dev_net(netdev)) && + if (devlink_port->type_eth.netdev && (nla_put_u32(msg, DEVLINK_ATTR_PORT_NETDEV_IFINDEX, - netdev->ifindex) || + devlink_port->type_eth.ifindex) || nla_put_string(msg, DEVLINK_ATTR_PORT_NETDEV_NAME, - netdev->name))) + devlink_port->type_eth.ifname))) goto nla_put_failure_type_locked; } if (devlink_port->type == DEVLINK_PORT_TYPE_IB) { - struct ib_device *ibdev = devlink_port->type_dev; + struct ib_device *ibdev = devlink_port->type_ib.ibdev; if (ibdev && nla_put_string(msg, DEVLINK_ATTR_PORT_IBDEV_NAME, @@ -1321,7 +1466,6 @@ static int devlink_nl_port_fill(struct sk_buff *msg, goto nla_put_failure_type_locked; } spin_unlock_bh(&devlink_port->type_lock); - rtnl_unlock(); if (devlink_nl_port_attrs_put(msg, devlink_port)) goto nla_put_failure; if (devlink_nl_port_function_attrs_put(msg, devlink_port, extack)) @@ -1336,7 +1480,6 @@ static int devlink_nl_port_fill(struct sk_buff *msg, nla_put_failure_type_locked: spin_unlock_bh(&devlink_port->type_lock); - rtnl_unlock(); nla_put_failure: genlmsg_cancel(msg, hdr); return -EMSGSIZE; @@ -1505,10 +1648,13 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg, continue; } + devl_lock(devlink); err = devlink_nl_fill(msg, devlink, DEVLINK_CMD_NEW, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI); + devl_unlock(devlink); devlink_put(devlink); + if (err) goto out; idx++; @@ -1545,14 +1691,14 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg, { struct devlink *devlink; struct devlink_port *devlink_port; + unsigned long index, port_index; int start = cb->args[0]; - unsigned long index; int idx = 0; int err; devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); - list_for_each_entry(devlink_port, &devlink->port_list, list) { + xa_for_each(&devlink->ports, port_index, devlink_port) { if (idx < start) { idx++; continue; @@ -1624,11 +1770,6 @@ static int devlink_port_function_hw_addr_set(struct devlink_port *port, } } - if (!ops->port_function_hw_addr_set) { - NL_SET_ERR_MSG_MOD(extack, "Port doesn't support function attributes"); - return -EOPNOTSUPP; - } - return ops->port_function_hw_addr_set(port, hw_addr, hw_addr_len, extack); } @@ -1642,12 +1783,52 @@ static int devlink_port_fn_state_set(struct devlink_port *port, state = nla_get_u8(attr); ops = port->devlink->ops; - if (!ops->port_fn_state_set) { - NL_SET_ERR_MSG_MOD(extack, - "Function does not support state setting"); + return ops->port_fn_state_set(port, state, extack); +} + +static int devlink_port_function_validate(struct devlink_port *devlink_port, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + const struct devlink_ops *ops = devlink_port->devlink->ops; + struct nlattr *attr; + + if (tb[DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR] && + !ops->port_function_hw_addr_set) { + NL_SET_ERR_MSG_ATTR(extack, tb[DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR], + "Port doesn't support function attributes"); return -EOPNOTSUPP; } - return ops->port_fn_state_set(port, state, extack); + if (tb[DEVLINK_PORT_FN_ATTR_STATE] && !ops->port_fn_state_set) { + NL_SET_ERR_MSG_ATTR(extack, tb[DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR], + "Function does not support state setting"); + return -EOPNOTSUPP; + } + attr = tb[DEVLINK_PORT_FN_ATTR_CAPS]; + if (attr) { + struct nla_bitfield32 caps; + + caps = nla_get_bitfield32(attr); + if (caps.selector & DEVLINK_PORT_FN_CAP_ROCE && + !ops->port_fn_roce_set) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Port doesn't support RoCE function attribute"); + return -EOPNOTSUPP; + } + if (caps.selector & DEVLINK_PORT_FN_CAP_MIGRATABLE) { + if (!ops->port_fn_migratable_set) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Port doesn't support migratable function attribute"); + return -EOPNOTSUPP; + } + if (devlink_port->attrs.flavour != DEVLINK_PORT_FLAVOUR_PCI_VF) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "migratable function attribute supported for VFs only"); + return -EOPNOTSUPP; + } + } + } + return 0; } static int devlink_port_function_set(struct devlink_port *port, @@ -1664,12 +1845,24 @@ static int devlink_port_function_set(struct devlink_port *port, return err; } + err = devlink_port_function_validate(port, tb, extack); + if (err) + return err; + attr = tb[DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR]; if (attr) { err = devlink_port_function_hw_addr_set(port, attr, extack); if (err) return err; } + + attr = tb[DEVLINK_PORT_FN_ATTR_CAPS]; + if (attr) { + err = devlink_port_fn_caps_set(port, attr, extack); + if (err) + return err; + } + /* Keep this as the last function attribute set, so that when * multiple port function attributes are set along with state, * Those can be applied first before activating the state. @@ -1867,10 +2060,8 @@ devlink_nl_rate_parent_node_set(struct devlink_rate *devlink_rate, int err = -EOPNOTSUPP; parent = devlink_rate->parent; - if (parent && len) { - NL_SET_ERR_MSG_MOD(info->extack, "Rate object already has parent."); - return -EBUSY; - } else if (parent && !len) { + + if (parent && !len) { if (devlink_rate_is_leaf(devlink_rate)) err = ops->rate_leaf_parent_set(devlink_rate, NULL, devlink_rate->priv, NULL, @@ -1884,7 +2075,7 @@ devlink_nl_rate_parent_node_set(struct devlink_rate *devlink_rate, refcount_dec(&parent->refcnt); devlink_rate->parent = NULL; - } else if (!parent && len) { + } else if (len) { parent = devlink_rate_node_get_by_name(devlink, parent_name); if (IS_ERR(parent)) return -ENODEV; @@ -1911,6 +2102,10 @@ devlink_nl_rate_parent_node_set(struct devlink_rate *devlink_rate, if (err) return err; + if (devlink_rate->parent) + /* we're reassigning to other parent in this case */ + refcount_dec(&devlink_rate->parent->refcnt); + refcount_inc(&parent->refcnt); devlink_rate->parent = parent; } @@ -1924,6 +2119,8 @@ static int devlink_nl_rate_set(struct devlink_rate *devlink_rate, { struct nlattr *nla_parent, **attrs = info->attrs; int err = -EOPNOTSUPP; + u32 priority; + u32 weight; u64 rate; if (attrs[DEVLINK_ATTR_RATE_TX_SHARE]) { @@ -1952,6 +2149,34 @@ static int devlink_nl_rate_set(struct devlink_rate *devlink_rate, devlink_rate->tx_max = rate; } + if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY]) { + priority = nla_get_u32(attrs[DEVLINK_ATTR_RATE_TX_PRIORITY]); + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_tx_priority_set(devlink_rate, devlink_rate->priv, + priority, info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_tx_priority_set(devlink_rate, devlink_rate->priv, + priority, info->extack); + + if (err) + return err; + devlink_rate->tx_priority = priority; + } + + if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT]) { + weight = nla_get_u32(attrs[DEVLINK_ATTR_RATE_TX_WEIGHT]); + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_tx_weight_set(devlink_rate, devlink_rate->priv, + weight, info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_tx_weight_set(devlink_rate, devlink_rate->priv, + weight, info->extack); + + if (err) + return err; + devlink_rate->tx_weight = weight; + } + nla_parent = attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME]; if (nla_parent) { err = devlink_nl_rate_parent_node_set(devlink_rate, info, @@ -1983,6 +2208,18 @@ static bool devlink_rate_set_ops_supported(const struct devlink_ops *ops, NL_SET_ERR_MSG_MOD(info->extack, "Parent set isn't supported for the leafs"); return false; } + if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY] && !ops->rate_leaf_tx_priority_set) { + NL_SET_ERR_MSG_ATTR(info->extack, + attrs[DEVLINK_ATTR_RATE_TX_PRIORITY], + "TX priority set isn't supported for the leafs"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT] && !ops->rate_leaf_tx_weight_set) { + NL_SET_ERR_MSG_ATTR(info->extack, + attrs[DEVLINK_ATTR_RATE_TX_WEIGHT], + "TX weight set isn't supported for the leafs"); + return false; + } } else if (type == DEVLINK_RATE_TYPE_NODE) { if (attrs[DEVLINK_ATTR_RATE_TX_SHARE] && !ops->rate_node_tx_share_set) { NL_SET_ERR_MSG_MOD(info->extack, "TX share set isn't supported for the nodes"); @@ -1997,6 +2234,18 @@ static bool devlink_rate_set_ops_supported(const struct devlink_ops *ops, NL_SET_ERR_MSG_MOD(info->extack, "Parent set isn't supported for the nodes"); return false; } + if (attrs[DEVLINK_ATTR_RATE_TX_PRIORITY] && !ops->rate_node_tx_priority_set) { + NL_SET_ERR_MSG_ATTR(info->extack, + attrs[DEVLINK_ATTR_RATE_TX_PRIORITY], + "TX priority set isn't supported for the nodes"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_TX_WEIGHT] && !ops->rate_node_tx_weight_set) { + NL_SET_ERR_MSG_ATTR(info->extack, + attrs[DEVLINK_ATTR_RATE_TX_WEIGHT], + "TX weight set isn't supported for the nodes"); + return false; + } } else { WARN(1, "Unknown type of rate object"); return false; @@ -2810,10 +3059,11 @@ static int __sb_port_pool_get_dumpit(struct sk_buff *msg, int start, int *p_idx, { struct devlink_port *devlink_port; u16 pool_count = devlink_sb_pool_count(devlink_sb); + unsigned long port_index; u16 pool_index; int err; - list_for_each_entry(devlink_port, &devlink->port_list, list) { + xa_for_each(&devlink->ports, port_index, devlink_port) { for (pool_index = 0; pool_index < pool_count; pool_index++) { if (*p_idx < start) { (*p_idx)++; @@ -3031,10 +3281,11 @@ static int __sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, u32 portid, u32 seq) { struct devlink_port *devlink_port; + unsigned long port_index; u16 tc_index; int err; - list_for_each_entry(devlink_port, &devlink->port_list, list) { + xa_for_each(&devlink->ports, port_index, devlink_port) { for (tc_index = 0; tc_index < devlink_sb->ingress_tc_count; tc_index++) { if (*p_idx < start) { @@ -4193,9 +4444,10 @@ static int devlink_resource_put(struct devlink *devlink, struct sk_buff *skb, nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_ID, resource->id, DEVLINK_ATTR_PAD)) goto nla_put_failure; - if (resource->size != resource->size_new) - nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_NEW, - resource->size_new, DEVLINK_ATTR_PAD); + if (resource->size != resource->size_new && + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_NEW, + resource->size_new, DEVLINK_ATTR_PAD)) + goto nla_put_failure; if (devlink_resource_occ_put(resource, skb)) goto nla_put_failure; if (devlink_resource_size_params_put(resource, skb)) @@ -4490,8 +4742,11 @@ static int devlink_reload(struct devlink *devlink, struct net *dest_net, if (err) return err; - if (dest_net && !net_eq(dest_net, curr_net)) + if (dest_net && !net_eq(dest_net, curr_net)) { + move_netdevice_notifier_net(curr_net, dest_net, + &devlink->netdevice_nb); write_pnet(&devlink->_net, dest_net); + } err = devlink->ops->reload_up(devlink, action, limit, actions_performed, extack); devlink_reload_failed_set(devlink, !!err); @@ -6128,6 +6383,7 @@ static int devlink_nl_cmd_region_get_devlink_dumpit(struct sk_buff *msg, { struct devlink_region *region; struct devlink_port *port; + unsigned long port_index; int err = 0; devl_lock(devlink); @@ -6146,7 +6402,7 @@ static int devlink_nl_cmd_region_get_devlink_dumpit(struct sk_buff *msg, (*idx)++; } - list_for_each_entry(port, &devlink->port_list, list) { + xa_for_each(&devlink->ports, port_index, port) { err = devlink_nl_cmd_region_get_port_dumpit(msg, cb, port, idx, start); if (err) @@ -6352,7 +6608,6 @@ unlock: } static int devlink_nl_cmd_region_read_chunk_fill(struct sk_buff *msg, - struct devlink *devlink, u8 *chunk, u32 chunk_size, u64 addr) { @@ -6382,39 +6637,37 @@ nla_put_failure: #define DEVLINK_REGION_READ_CHUNK_SIZE 256 -static int devlink_nl_region_read_snapshot_fill(struct sk_buff *skb, - struct devlink *devlink, - struct devlink_region *region, - struct nlattr **attrs, - u64 start_offset, - u64 end_offset, - u64 *new_offset) +typedef int devlink_chunk_fill_t(void *cb_priv, u8 *chunk, u32 chunk_size, + u64 curr_offset, + struct netlink_ext_ack *extack); + +static int +devlink_nl_region_read_fill(struct sk_buff *skb, devlink_chunk_fill_t *cb, + void *cb_priv, u64 start_offset, u64 end_offset, + u64 *new_offset, struct netlink_ext_ack *extack) { - struct devlink_snapshot *snapshot; u64 curr_offset = start_offset; - u32 snapshot_id; int err = 0; + u8 *data; - *new_offset = start_offset; + /* Allocate and re-use a single buffer */ + data = kmalloc(DEVLINK_REGION_READ_CHUNK_SIZE, GFP_KERNEL); + if (!data) + return -ENOMEM; - snapshot_id = nla_get_u32(attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]); - snapshot = devlink_region_snapshot_get_by_id(region, snapshot_id); - if (!snapshot) - return -EINVAL; + *new_offset = start_offset; while (curr_offset < end_offset) { u32 data_size; - u8 *data; - if (end_offset - curr_offset < DEVLINK_REGION_READ_CHUNK_SIZE) - data_size = end_offset - curr_offset; - else - data_size = DEVLINK_REGION_READ_CHUNK_SIZE; + data_size = min_t(u32, end_offset - curr_offset, + DEVLINK_REGION_READ_CHUNK_SIZE); - data = &snapshot->data[curr_offset]; - err = devlink_nl_cmd_region_read_chunk_fill(skb, devlink, - data, data_size, - curr_offset); + err = cb(cb_priv, data, data_size, curr_offset, extack); + if (err) + break; + + err = devlink_nl_cmd_region_read_chunk_fill(skb, data, data_size, curr_offset); if (err) break; @@ -6422,21 +6675,57 @@ static int devlink_nl_region_read_snapshot_fill(struct sk_buff *skb, } *new_offset = curr_offset; + kfree(data); + return err; } +static int +devlink_region_snapshot_fill(void *cb_priv, u8 *chunk, u32 chunk_size, + u64 curr_offset, + struct netlink_ext_ack __always_unused *extack) +{ + struct devlink_snapshot *snapshot = cb_priv; + + memcpy(chunk, &snapshot->data[curr_offset], chunk_size); + + return 0; +} + +static int +devlink_region_port_direct_fill(void *cb_priv, u8 *chunk, u32 chunk_size, + u64 curr_offset, struct netlink_ext_ack *extack) +{ + struct devlink_region *region = cb_priv; + + return region->port_ops->read(region->port, region->port_ops, extack, + curr_offset, chunk_size, chunk); +} + +static int +devlink_region_direct_fill(void *cb_priv, u8 *chunk, u32 chunk_size, + u64 curr_offset, struct netlink_ext_ack *extack) +{ + struct devlink_region *region = cb_priv; + + return region->ops->read(region->devlink, region->ops, extack, + curr_offset, chunk_size, chunk); +} + static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct nlattr *chunks_attr, *region_attr, *snapshot_attr; u64 ret_offset, start_offset, end_offset = U64_MAX; struct nlattr **attrs = info->attrs; struct devlink_port *port = NULL; + devlink_chunk_fill_t *region_cb; struct devlink_region *region; - struct nlattr *chunks_attr; const char *region_name; struct devlink *devlink; unsigned int index; + void *region_cb_priv; void *hdr; int err; @@ -6448,8 +6737,8 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, devl_lock(devlink); - if (!attrs[DEVLINK_ATTR_REGION_NAME] || - !attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]) { + if (!attrs[DEVLINK_ATTR_REGION_NAME]) { + NL_SET_ERR_MSG(cb->extack, "No region name provided"); err = -EINVAL; goto out_unlock; } @@ -6464,7 +6753,8 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, } } - region_name = nla_data(attrs[DEVLINK_ATTR_REGION_NAME]); + region_attr = attrs[DEVLINK_ATTR_REGION_NAME]; + region_name = nla_data(region_attr); if (port) region = devlink_port_region_get_by_name(port, region_name); @@ -6472,10 +6762,51 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, region = devlink_region_get_by_name(devlink, region_name); if (!region) { + NL_SET_ERR_MSG_ATTR(cb->extack, region_attr, "Requested region does not exist"); err = -EINVAL; goto out_unlock; } + snapshot_attr = attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]; + if (!snapshot_attr) { + if (!nla_get_flag(attrs[DEVLINK_ATTR_REGION_DIRECT])) { + NL_SET_ERR_MSG(cb->extack, "No snapshot id provided"); + err = -EINVAL; + goto out_unlock; + } + + if (!region->ops->read) { + NL_SET_ERR_MSG(cb->extack, "Requested region does not support direct read"); + err = -EOPNOTSUPP; + goto out_unlock; + } + + if (port) + region_cb = &devlink_region_port_direct_fill; + else + region_cb = &devlink_region_direct_fill; + region_cb_priv = region; + } else { + struct devlink_snapshot *snapshot; + u32 snapshot_id; + + if (nla_get_flag(attrs[DEVLINK_ATTR_REGION_DIRECT])) { + NL_SET_ERR_MSG_ATTR(cb->extack, snapshot_attr, "Direct region read does not use snapshot"); + err = -EINVAL; + goto out_unlock; + } + + snapshot_id = nla_get_u32(snapshot_attr); + snapshot = devlink_region_snapshot_get_by_id(region, snapshot_id); + if (!snapshot) { + NL_SET_ERR_MSG_ATTR(cb->extack, snapshot_attr, "Requested snapshot does not exist"); + err = -EINVAL; + goto out_unlock; + } + region_cb = &devlink_region_snapshot_fill; + region_cb_priv = snapshot; + } + if (attrs[DEVLINK_ATTR_REGION_CHUNK_ADDR] && attrs[DEVLINK_ATTR_REGION_CHUNK_LEN]) { if (!start_offset) @@ -6524,10 +6855,9 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, goto nla_put_failure; } - err = devlink_nl_region_read_snapshot_fill(skb, devlink, - region, attrs, - start_offset, - end_offset, &ret_offset); + err = devlink_nl_region_read_fill(skb, region_cb, region_cb_priv, + start_offset, end_offset, &ret_offset, + cb->extack); if (err && err != -EMSGSIZE) goto nla_put_failure; @@ -6554,14 +6884,6 @@ out_unlock: return err; } -int devlink_info_driver_name_put(struct devlink_info_req *req, const char *name) -{ - if (!req->msg) - return 0; - return nla_put_string(req->msg, DEVLINK_ATTR_INFO_DRIVER_NAME, name); -} -EXPORT_SYMBOL_GPL(devlink_info_driver_name_put); - int devlink_info_serial_number_put(struct devlink_info_req *req, const char *sn) { if (!req->msg) @@ -6670,11 +6992,25 @@ int devlink_info_version_running_put_ext(struct devlink_info_req *req, } EXPORT_SYMBOL_GPL(devlink_info_version_running_put_ext); +static int devlink_nl_driver_info_get(struct device_driver *drv, + struct devlink_info_req *req) +{ + if (!drv) + return 0; + + if (drv->name[0]) + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_DRIVER_NAME, + drv->name); + + return 0; +} + static int devlink_nl_info_fill(struct sk_buff *msg, struct devlink *devlink, enum devlink_command cmd, u32 portid, u32 seq, int flags, struct netlink_ext_ack *extack) { + struct device *dev = devlink_to_dev(devlink); struct devlink_info_req req = {}; void *hdr; int err; @@ -6688,7 +7024,13 @@ devlink_nl_info_fill(struct sk_buff *msg, struct devlink *devlink, goto err_cancel_msg; req.msg = msg; - err = devlink->ops->info_get(devlink, &req, extack); + if (devlink->ops->info_get) { + err = devlink->ops->info_get(devlink, &req, extack); + if (err) + goto err_cancel_msg; + } + + err = devlink_nl_driver_info_get(dev->driver, &req); if (err) goto err_cancel_msg; @@ -6707,9 +7049,6 @@ static int devlink_nl_cmd_info_get_doit(struct sk_buff *skb, struct sk_buff *msg; int err; - if (!devlink->ops->info_get) - return -EOPNOTSUPP; - msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!msg) return -ENOMEM; @@ -6735,7 +7074,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, int err = 0; devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { - if (idx < start || !devlink->ops->info_get) + if (idx < start) goto inc; devl_lock(devlink); @@ -7767,8 +8106,6 @@ int devlink_health_report(struct devlink_health_reporter *reporter, return -ECANCELED; } - reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_ERROR; - if (reporter->auto_dump) { mutex_lock(&reporter->dump_lock); /* store current dump of current error, for later analysis */ @@ -7897,10 +8234,10 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, struct netlink_callback *cb) { struct devlink_health_reporter *reporter; + unsigned long index, port_index; struct devlink_port *port; struct devlink *devlink; int start = cb->args[0]; - unsigned long index; int idx = 0; int err; @@ -7929,7 +8266,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { devl_lock(devlink); - list_for_each_entry(port, &devlink->port_list, list) { + xa_for_each(&devlink->ports, port_index, port) { mutex_lock(&port->reporters_lock); list_for_each_entry(reporter, &port->reporter_list, list) { if (idx < start) { @@ -8304,10 +8641,10 @@ static void devlink_trap_stats_read(struct devlink_stats __percpu *trap_stats, cpu_stats = per_cpu_ptr(trap_stats, i); do { - start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + start = u64_stats_fetch_begin(&cpu_stats->syncp); rx_packets = u64_stats_read(&cpu_stats->rx_packets); rx_bytes = u64_stats_read(&cpu_stats->rx_bytes); - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); u64_stats_add(&stats->rx_packets, rx_packets); u64_stats_add(&stats->rx_bytes, rx_bytes); @@ -9172,6 +9509,9 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { [DEVLINK_ATTR_LINECARD_INDEX] = { .type = NLA_U32 }, [DEVLINK_ATTR_LINECARD_TYPE] = { .type = NLA_NUL_STRING }, [DEVLINK_ATTR_SELFTESTS] = { .type = NLA_NESTED }, + [DEVLINK_ATTR_RATE_TX_PRIORITY] = { .type = NLA_U32 }, + [DEVLINK_ATTR_RATE_TX_WEIGHT] = { .type = NLA_U32 }, + [DEVLINK_ATTR_REGION_DIRECT] = { .type = NLA_FLAG }, }; static const struct genl_small_ops devlink_nl_ops[] = { @@ -9602,6 +9942,9 @@ void devlink_set_features(struct devlink *devlink, u64 features) } EXPORT_SYMBOL_GPL(devlink_set_features); +static int devlink_netdevice_event(struct notifier_block *nb, + unsigned long event, void *ptr); + /** * devlink_alloc_ns - Allocate new devlink instance resources * in specific namespace @@ -9632,16 +9975,19 @@ struct devlink *devlink_alloc_ns(const struct devlink_ops *ops, ret = xa_alloc_cyclic(&devlinks, &devlink->index, devlink, xa_limit_31b, &last_id, GFP_KERNEL); - if (ret < 0) { - kfree(devlink); - return NULL; - } + if (ret < 0) + goto err_xa_alloc; + + devlink->netdevice_nb.notifier_call = devlink_netdevice_event; + ret = register_netdevice_notifier_net(net, &devlink->netdevice_nb); + if (ret) + goto err_register_netdevice_notifier; devlink->dev = dev; devlink->ops = ops; + xa_init_flags(&devlink->ports, XA_FLAGS_ALLOC); xa_init_flags(&devlink->snapshot_ids, XA_FLAGS_ALLOC); write_pnet(&devlink->_net, net); - INIT_LIST_HEAD(&devlink->port_list); INIT_LIST_HEAD(&devlink->rate_list); INIT_LIST_HEAD(&devlink->linecard_list); INIT_LIST_HEAD(&devlink->sb_list); @@ -9662,6 +10008,12 @@ struct devlink *devlink_alloc_ns(const struct devlink_ops *ops, init_completion(&devlink->comp); return devlink; + +err_register_netdevice_notifier: + xa_erase(&devlinks, devlink->index); +err_xa_alloc: + kfree(devlink); + return NULL; } EXPORT_SYMBOL_GPL(devlink_alloc_ns); @@ -9687,12 +10039,13 @@ static void devlink_notify_register(struct devlink *devlink) struct devlink_linecard *linecard; struct devlink_rate *rate_node; struct devlink_region *region; + unsigned long port_index; devlink_notify(devlink, DEVLINK_CMD_NEW); list_for_each_entry(linecard, &devlink->linecard_list, list) devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); - list_for_each_entry(devlink_port, &devlink->port_list, list) + xa_for_each(&devlink->ports, port_index, devlink_port) devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); list_for_each_entry(policer_item, &devlink->trap_policer_list, list) @@ -9726,6 +10079,7 @@ static void devlink_notify_unregister(struct devlink *devlink) struct devlink_port *devlink_port; struct devlink_rate *rate_node; struct devlink_region *region; + unsigned long port_index; list_for_each_entry_reverse(param_item, &devlink->param_list, list) devlink_param_notify(devlink, 0, param_item, @@ -9748,7 +10102,7 @@ static void devlink_notify_unregister(struct devlink *devlink) devlink_trap_policer_notify(devlink, policer_item, DEVLINK_CMD_TRAP_POLICER_DEL); - list_for_each_entry_reverse(devlink_port, &devlink->port_list, list) + xa_for_each(&devlink->ports, port_index, devlink_port) devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_DEL); devlink_notify(devlink, DEVLINK_CMD_DEL); } @@ -9812,9 +10166,14 @@ void devlink_free(struct devlink *devlink) WARN_ON(!list_empty(&devlink->sb_list)); WARN_ON(!list_empty(&devlink->rate_list)); WARN_ON(!list_empty(&devlink->linecard_list)); - WARN_ON(!list_empty(&devlink->port_list)); + WARN_ON(!xa_empty(&devlink->ports)); xa_destroy(&devlink->snapshot_ids); + xa_destroy(&devlink->ports); + + WARN_ON_ONCE(unregister_netdevice_notifier_net(devlink_net(devlink), + &devlink->netdevice_nb)); + xa_erase(&devlinks, devlink->index); kfree(devlink); @@ -9909,10 +10268,9 @@ int devl_port_register(struct devlink *devlink, struct devlink_port *devlink_port, unsigned int port_index) { - devl_assert_locked(devlink); + int err; - if (devlink_port_index_exists(devlink, port_index)) - return -EEXIST; + devl_assert_locked(devlink); ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); @@ -9922,7 +10280,11 @@ int devl_port_register(struct devlink *devlink, spin_lock_init(&devlink_port->type_lock); INIT_LIST_HEAD(&devlink_port->reporter_list); mutex_init(&devlink_port->reporters_lock); - list_add_tail(&devlink_port->list, &devlink->port_list); + err = xa_insert(&devlink->ports, port_index, devlink_port, GFP_KERNEL); + if (err) { + mutex_destroy(&devlink_port->reporters_lock); + return err; + } INIT_DELAYED_WORK(&devlink_port->type_warn_dw, &devlink_port_type_warn); devlink_port_type_warn_schedule(devlink_port); @@ -9967,10 +10329,11 @@ EXPORT_SYMBOL_GPL(devlink_port_register); void devl_port_unregister(struct devlink_port *devlink_port) { lockdep_assert_held(&devlink_port->devlink->lock); + WARN_ON(devlink_port->type != DEVLINK_PORT_TYPE_NOTSET); devlink_port_type_warn_cancel(devlink_port); devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_DEL); - list_del(&devlink_port->list); + xa_erase(&devlink_port->devlink->ports, devlink_port->index); WARN_ON(!list_empty(&devlink_port->reporter_list)); mutex_destroy(&devlink_port->reporters_lock); devlink_port->registered = false; @@ -9994,20 +10357,6 @@ void devlink_port_unregister(struct devlink_port *devlink_port) } EXPORT_SYMBOL_GPL(devlink_port_unregister); -static void __devlink_port_type_set(struct devlink_port *devlink_port, - enum devlink_port_type type, - void *type_dev) -{ - ASSERT_DEVLINK_PORT_REGISTERED(devlink_port); - - devlink_port_type_warn_cancel(devlink_port); - spin_lock_bh(&devlink_port->type_lock); - devlink_port->type = type; - devlink_port->type_dev = type_dev; - spin_unlock_bh(&devlink_port->type_lock); - devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); -} - static void devlink_port_type_netdev_checks(struct devlink_port *devlink_port, struct net_device *netdev) { @@ -10045,23 +10394,58 @@ static void devlink_port_type_netdev_checks(struct devlink_port *devlink_port, } } +static void __devlink_port_type_set(struct devlink_port *devlink_port, + enum devlink_port_type type, + void *type_dev) +{ + struct net_device *netdev = type_dev; + + ASSERT_DEVLINK_PORT_REGISTERED(devlink_port); + + if (type == DEVLINK_PORT_TYPE_NOTSET) { + devlink_port_type_warn_schedule(devlink_port); + } else { + devlink_port_type_warn_cancel(devlink_port); + if (type == DEVLINK_PORT_TYPE_ETH && netdev) + devlink_port_type_netdev_checks(devlink_port, netdev); + } + + spin_lock_bh(&devlink_port->type_lock); + devlink_port->type = type; + switch (type) { + case DEVLINK_PORT_TYPE_ETH: + devlink_port->type_eth.netdev = netdev; + if (netdev) { + ASSERT_RTNL(); + devlink_port->type_eth.ifindex = netdev->ifindex; + BUILD_BUG_ON(sizeof(devlink_port->type_eth.ifname) != + sizeof(netdev->name)); + strcpy(devlink_port->type_eth.ifname, netdev->name); + } + break; + case DEVLINK_PORT_TYPE_IB: + devlink_port->type_ib.ibdev = type_dev; + break; + default: + break; + } + spin_unlock_bh(&devlink_port->type_lock); + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); +} + /** * devlink_port_type_eth_set - Set port type to Ethernet * * @devlink_port: devlink port - * @netdev: related netdevice + * + * If driver is calling this, most likely it is doing something wrong. */ -void devlink_port_type_eth_set(struct devlink_port *devlink_port, - struct net_device *netdev) +void devlink_port_type_eth_set(struct devlink_port *devlink_port) { - if (netdev) - devlink_port_type_netdev_checks(devlink_port, netdev); - else - dev_warn(devlink_port->devlink->dev, - "devlink port type for port %d set to Ethernet without a software interface reference, device type not supported by the kernel?\n", - devlink_port->index); - - __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_ETH, netdev); + dev_warn(devlink_port->devlink->dev, + "devlink port type for port %d set to Ethernet without a software interface reference, device type not supported by the kernel?\n", + devlink_port->index); + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_ETH, NULL); } EXPORT_SYMBOL_GPL(devlink_port_type_eth_set); @@ -10082,14 +10466,71 @@ EXPORT_SYMBOL_GPL(devlink_port_type_ib_set); * devlink_port_type_clear - Clear port type * * @devlink_port: devlink port + * + * If driver is calling this for clearing Ethernet type, most likely + * it is doing something wrong. */ void devlink_port_type_clear(struct devlink_port *devlink_port) { + if (devlink_port->type == DEVLINK_PORT_TYPE_ETH) + dev_warn(devlink_port->devlink->dev, + "devlink port type for port %d cleared without a software interface reference, device type not supported by the kernel?\n", + devlink_port->index); __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_NOTSET, NULL); - devlink_port_type_warn_schedule(devlink_port); } EXPORT_SYMBOL_GPL(devlink_port_type_clear); +static int devlink_netdevice_event(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct net_device *netdev = netdev_notifier_info_to_dev(ptr); + struct devlink_port *devlink_port = netdev->devlink_port; + struct devlink *devlink; + + devlink = container_of(nb, struct devlink, netdevice_nb); + + if (!devlink_port || devlink_port->devlink != devlink) + return NOTIFY_OK; + + switch (event) { + case NETDEV_POST_INIT: + /* Set the type but not netdev pointer. It is going to be set + * later on by NETDEV_REGISTER event. Happens once during + * netdevice register + */ + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_ETH, + NULL); + break; + case NETDEV_REGISTER: + case NETDEV_CHANGENAME: + /* Set the netdev on top of previously set type. Note this + * event happens also during net namespace change so here + * we take into account netdev pointer appearing in this + * namespace. + */ + __devlink_port_type_set(devlink_port, devlink_port->type, + netdev); + break; + case NETDEV_UNREGISTER: + /* Clear netdev pointer, but not the type. This event happens + * also during net namespace change so we need to clear + * pointer to netdev that is going to another net namespace. + */ + __devlink_port_type_set(devlink_port, devlink_port->type, + NULL); + break; + case NETDEV_PRE_UNINIT: + /* Clear the type and the netdev pointer. Happens one during + * netdevice unregister. + */ + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_NOTSET, + NULL); + break; + } + + return NOTIFY_OK; +} + static int __devlink_port_attrs_set(struct devlink_port *devlink_port, enum devlink_port_flavour flavour) { @@ -10211,13 +10652,60 @@ void devlink_port_attrs_pci_sf_set(struct devlink_port *devlink_port, u32 contro EXPORT_SYMBOL_GPL(devlink_port_attrs_pci_sf_set); /** + * devl_rate_node_create - create devlink rate node + * @devlink: devlink instance + * @priv: driver private data + * @node_name: name of the resulting node + * @parent: parent devlink_rate struct + * + * Create devlink rate object of type node + */ +struct devlink_rate * +devl_rate_node_create(struct devlink *devlink, void *priv, char *node_name, + struct devlink_rate *parent) +{ + struct devlink_rate *rate_node; + + rate_node = devlink_rate_node_get_by_name(devlink, node_name); + if (!IS_ERR(rate_node)) + return ERR_PTR(-EEXIST); + + rate_node = kzalloc(sizeof(*rate_node), GFP_KERNEL); + if (!rate_node) + return ERR_PTR(-ENOMEM); + + if (parent) { + rate_node->parent = parent; + refcount_inc(&rate_node->parent->refcnt); + } + + rate_node->type = DEVLINK_RATE_TYPE_NODE; + rate_node->devlink = devlink; + rate_node->priv = priv; + + rate_node->name = kstrdup(node_name, GFP_KERNEL); + if (!rate_node->name) { + kfree(rate_node); + return ERR_PTR(-ENOMEM); + } + + refcount_set(&rate_node->refcnt, 1); + list_add(&rate_node->list, &devlink->rate_list); + devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW); + return rate_node; +} +EXPORT_SYMBOL_GPL(devl_rate_node_create); + +/** * devl_rate_leaf_create - create devlink rate leaf * @devlink_port: devlink port object to create rate object on * @priv: driver private data + * @parent: parent devlink_rate struct * * Create devlink rate object of type leaf on provided @devlink_port. */ -int devl_rate_leaf_create(struct devlink_port *devlink_port, void *priv) +int devl_rate_leaf_create(struct devlink_port *devlink_port, void *priv, + struct devlink_rate *parent) { struct devlink *devlink = devlink_port->devlink; struct devlink_rate *devlink_rate; @@ -10231,6 +10719,11 @@ int devl_rate_leaf_create(struct devlink_port *devlink_port, void *priv) if (!devlink_rate) return -ENOMEM; + if (parent) { + devlink_rate->parent = parent; + refcount_inc(&devlink_rate->parent->refcnt); + } + devlink_rate->type = DEVLINK_RATE_TYPE_LEAF; devlink_rate->devlink = devlink; devlink_rate->devlink_port = devlink_port; @@ -11435,8 +11928,10 @@ void devl_region_destroy(struct devlink_region *region) devl_assert_locked(devlink); /* Free all snapshots of region */ + mutex_lock(®ion->snapshot_lock); list_for_each_entry_safe(snapshot, ts, ®ion->snapshot_list, list) devlink_region_snapshot_del(region, snapshot); + mutex_unlock(®ion->snapshot_lock); list_del(®ion->list); mutex_destroy(®ion->snapshot_lock); @@ -11624,6 +12119,8 @@ static const struct devlink_trap devlink_trap_generic[] = { DEVLINK_TRAP(ESP_PARSING, DROP), DEVLINK_TRAP(BLACKHOLE_NEXTHOP, DROP), DEVLINK_TRAP(DMAC_FILTER, DROP), + DEVLINK_TRAP(EAPOL, CONTROL), + DEVLINK_TRAP(LOCKED_PORT, DROP), }; #define DEVLINK_TRAP_GROUP(_id) \ @@ -11659,6 +12156,7 @@ static const struct devlink_trap_group devlink_trap_group_generic[] = { DEVLINK_TRAP_GROUP(ACL_SAMPLE), DEVLINK_TRAP_GROUP(ACL_TRAP), DEVLINK_TRAP_GROUP(PARSER_ERROR_DROPS), + DEVLINK_TRAP_GROUP(EAPOL), }; static int devlink_trap_generic_verify(const struct devlink_trap *trap) @@ -12016,7 +12514,7 @@ devlink_trap_report_metadata_set(struct devlink_trap_metadata *metadata, spin_lock(&in_devlink_port->type_lock); if (in_devlink_port->type == DEVLINK_PORT_TYPE_ETH) - metadata->input_dev = in_devlink_port->type_dev; + metadata->input_dev = in_devlink_port->type_eth.netdev; spin_unlock(&in_devlink_port->type_lock); } @@ -12416,14 +12914,6 @@ free_msg: nlmsg_free(msg); } -static struct devlink_port *netdev_to_devlink_port(struct net_device *dev) -{ - if (!dev->netdev_ops->ndo_get_devlink_port) - return NULL; - - return dev->netdev_ops->ndo_get_devlink_port(dev); -} - void devlink_compat_running_version(struct devlink *devlink, char *buf, size_t len) { @@ -12469,7 +12959,7 @@ int devlink_compat_phys_port_name_get(struct net_device *dev, */ ASSERT_RTNL(); - devlink_port = netdev_to_devlink_port(dev); + devlink_port = dev->devlink_port; if (!devlink_port) return -EOPNOTSUPP; @@ -12485,7 +12975,7 @@ int devlink_compat_switch_id_get(struct net_device *dev, * devlink_port instance cannot disappear in the middle. No need to take * any devlink lock as only permanent values are accessed. */ - devlink_port = netdev_to_devlink_port(dev); + devlink_port = dev->devlink_port; if (!devlink_port || !devlink_port->switch_port) return -EOPNOTSUPP; diff --git a/net/core/drop_monitor.c b/net/core/drop_monitor.c index f084a4a6b7ab..5a782d1d8fd3 100644 --- a/net/core/drop_monitor.c +++ b/net/core/drop_monitor.c @@ -1432,9 +1432,9 @@ static void net_dm_stats_read(struct net_dm_stats *stats) u64 dropped; do { - start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + start = u64_stats_fetch_begin(&cpu_stats->syncp); dropped = u64_stats_read(&cpu_stats->dropped); - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); u64_stats_add(&stats->dropped, dropped); } @@ -1476,9 +1476,9 @@ static void net_dm_hw_stats_read(struct net_dm_stats *stats) u64 dropped; do { - start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + start = u64_stats_fetch_begin(&cpu_stats->syncp); dropped = u64_stats_read(&cpu_stats->dropped); - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); u64_stats_add(&stats->dropped, dropped); } @@ -1620,7 +1620,7 @@ static const struct genl_small_ops dropmon_ops[] = { }, }; -static int net_dm_nl_pre_doit(const struct genl_ops *ops, +static int net_dm_nl_pre_doit(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { mutex_lock(&net_dm_mutex); @@ -1628,7 +1628,7 @@ static int net_dm_nl_pre_doit(const struct genl_ops *ops, return 0; } -static void net_dm_nl_post_doit(const struct genl_ops *ops, +static void net_dm_nl_post_doit(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { mutex_unlock(&net_dm_mutex); diff --git a/net/core/dst.c b/net/core/dst.c index bc9c9be4e080..6d2dd03dafa8 100644 --- a/net/core/dst.c +++ b/net/core/dst.c @@ -174,7 +174,7 @@ void dst_release(struct dst_entry *dst) net_warn_ratelimited("%s: dst:%p refcnt:%d\n", __func__, dst, newrefcnt); if (!newrefcnt) - call_rcu(&dst->rcu_head, dst_destroy_rcu); + call_rcu_hurry(&dst->rcu_head, dst_destroy_rcu); } } EXPORT_SYMBOL(dst_release); @@ -316,6 +316,8 @@ void metadata_dst_free(struct metadata_dst *md_dst) if (md_dst->type == METADATA_IP_TUNNEL) dst_cache_destroy(&md_dst->u.tun_info.dst_cache); #endif + if (md_dst->type == METADATA_XFRM) + dst_release(md_dst->u.xfrm_info.dst_orig); kfree(md_dst); } EXPORT_SYMBOL_GPL(metadata_dst_free); @@ -340,16 +342,18 @@ EXPORT_SYMBOL_GPL(metadata_dst_alloc_percpu); void metadata_dst_free_percpu(struct metadata_dst __percpu *md_dst) { -#ifdef CONFIG_DST_CACHE int cpu; for_each_possible_cpu(cpu) { struct metadata_dst *one_md_dst = per_cpu_ptr(md_dst, cpu); +#ifdef CONFIG_DST_CACHE if (one_md_dst->type == METADATA_IP_TUNNEL) dst_cache_destroy(&one_md_dst->u.tun_info.dst_cache); - } #endif + if (one_md_dst->type == METADATA_XFRM) + dst_release(one_md_dst->u.xfrm_info.dst_orig); + } free_percpu(md_dst); } EXPORT_SYMBOL_GPL(metadata_dst_free_percpu); diff --git a/net/core/failover.c b/net/core/failover.c index 864d2d83eff4..2a140b3ea669 100644 --- a/net/core/failover.c +++ b/net/core/failover.c @@ -80,14 +80,14 @@ static int failover_slave_register(struct net_device *slave_dev) goto err_upper_link; } - slave_dev->priv_flags |= (IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); + slave_dev->priv_flags |= (IFF_FAILOVER_SLAVE | IFF_NO_ADDRCONF); if (fops && fops->slave_register && !fops->slave_register(slave_dev, failover_dev)) return NOTIFY_OK; netdev_upper_dev_unlink(slave_dev, failover_dev); - slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); + slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_NO_ADDRCONF); err_upper_link: netdev_rx_handler_unregister(slave_dev); done: @@ -121,7 +121,7 @@ int failover_slave_unregister(struct net_device *slave_dev) netdev_rx_handler_unregister(slave_dev); netdev_upper_dev_unlink(slave_dev, failover_dev); - slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); + slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_NO_ADDRCONF); if (fops && fops->slave_unregister && !fops->slave_unregister(slave_dev, failover_dev)) diff --git a/net/core/filter.c b/net/core/filter.c index bb0136e7a8e4..43cc1fe58a2c 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -80,6 +80,7 @@ #include <net/tls.h> #include <net/xdp.h> #include <net/mptcp.h> +#include <net/netfilter/nf_conntrack_bpf.h> static const struct bpf_func_proto * bpf_sk_base_func_proto(enum bpf_func_id func_id); @@ -325,11 +326,11 @@ static u32 convert_skb_access(int skb_field, int dst_reg, int src_reg, offsetof(struct sk_buff, vlan_tci)); break; case SKF_AD_VLAN_TAG_PRESENT: - *insn++ = BPF_LDX_MEM(BPF_B, dst_reg, src_reg, PKT_VLAN_PRESENT_OFFSET); - if (PKT_VLAN_PRESENT_BIT) - *insn++ = BPF_ALU32_IMM(BPF_RSH, dst_reg, PKT_VLAN_PRESENT_BIT); - if (PKT_VLAN_PRESENT_BIT < 7) - *insn++ = BPF_ALU32_IMM(BPF_AND, dst_reg, 1); + BUILD_BUG_ON(sizeof_field(struct sk_buff, vlan_all) != 4); + *insn++ = BPF_LDX_MEM(BPF_W, dst_reg, src_reg, + offsetof(struct sk_buff, vlan_all)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, dst_reg, 0, 1); + *insn++ = BPF_ALU32_IMM(BPF_MOV, dst_reg, 1); break; } @@ -2124,6 +2125,11 @@ static int __bpf_redirect_no_mac(struct sk_buff *skb, struct net_device *dev, { unsigned int mlen = skb_network_offset(skb); + if (unlikely(skb->len <= mlen)) { + kfree_skb(skb); + return -ERANGE; + } + if (mlen) { __skb_pull(skb, mlen); @@ -2145,7 +2151,7 @@ static int __bpf_redirect_common(struct sk_buff *skb, struct net_device *dev, u32 flags) { /* Verify that a link layer header is carried */ - if (unlikely(skb->mac_header >= skb->network_header)) { + if (unlikely(skb->mac_header >= skb->network_header || skb->len == 0)) { kfree_skb(skb); return -ERANGE; } @@ -3174,15 +3180,18 @@ static int bpf_skb_generic_push(struct sk_buff *skb, u32 off, u32 len) static int bpf_skb_generic_pop(struct sk_buff *skb, u32 off, u32 len) { + void *old_data; + /* skb_ensure_writable() is not needed here, as we're * already working on an uncloned skb. */ if (unlikely(!pskb_may_pull(skb, off + len))) return -ENOMEM; - skb_postpull_rcsum(skb, skb->data + off, len); - memmove(skb->data + len, skb->data, off); + old_data = skb->data; __skb_pull(skb, len); + skb_postpull_rcsum(skb, old_data + off, len); + memmove(skb->data, old_data, off); return 0; } @@ -4104,7 +4113,10 @@ static const struct bpf_func_proto bpf_xdp_adjust_meta_proto = { .arg2_type = ARG_ANYTHING, }; -/* XDP_REDIRECT works by a three-step process, implemented in the functions +/** + * DOC: xdp redirect + * + * XDP_REDIRECT works by a three-step process, implemented in the functions * below: * * 1. The bpf_redirect() and bpf_redirect_map() helpers will lookup the target @@ -4119,7 +4131,8 @@ static const struct bpf_func_proto bpf_xdp_adjust_meta_proto = { * 3. Before exiting its NAPI poll loop, the driver will call xdp_do_flush(), * which will flush all the different bulk queues, thus completing the * redirect. - * + */ +/* * Pointers to the map entries will be kept around for this whole sequence of * steps, protected by RCU. However, there is no top-level rcu_read_lock() in * the core code; instead, the RCU protection relies on everything happening @@ -4410,10 +4423,10 @@ static const struct bpf_func_proto bpf_xdp_redirect_proto = { .arg2_type = ARG_ANYTHING, }; -BPF_CALL_3(bpf_xdp_redirect_map, struct bpf_map *, map, u32, ifindex, +BPF_CALL_3(bpf_xdp_redirect_map, struct bpf_map *, map, u64, key, u64, flags) { - return map->ops->map_redirect(map, ifindex, flags); + return map->ops->map_redirect(map, key, flags); } static const struct bpf_func_proto bpf_xdp_redirect_map_proto = { @@ -5621,6 +5634,15 @@ static const struct bpf_func_proto bpf_bind_proto = { }; #ifdef CONFIG_XFRM + +#if (IS_BUILTIN(CONFIG_XFRM_INTERFACE) && IS_ENABLED(CONFIG_DEBUG_INFO_BTF)) || \ + (IS_MODULE(CONFIG_XFRM_INTERFACE) && IS_ENABLED(CONFIG_DEBUG_INFO_BTF_MODULES)) + +struct metadata_dst __percpu *xfrm_bpf_md_dst; +EXPORT_SYMBOL_GPL(xfrm_bpf_md_dst); + +#endif + BPF_CALL_5(bpf_skb_get_xfrm_state, struct sk_buff *, skb, u32, index, struct bpf_xfrm_state *, to, u32, size, u64, flags) { @@ -6428,7 +6450,7 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple, else sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport, dst4, tuple->ipv4.dport, - dif, sdif, &udp_table, NULL); + dif, sdif, net->ipv4.udp_table, NULL); #if IS_ENABLED(CONFIG_IPV6) } else { struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr; @@ -6444,7 +6466,7 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple, src6, tuple->ipv6.sport, dst6, tuple->ipv6.dport, dif, sdif, - &udp_table, NULL); + net->ipv4.udp_table, NULL); #endif } @@ -7983,6 +8005,19 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) default: return bpf_sk_base_func_proto(func_id); } + +#if IS_MODULE(CONFIG_NF_CONNTRACK) && IS_ENABLED(CONFIG_DEBUG_INFO_BTF_MODULES) + /* The nf_conn___init type is used in the NF_CONNTRACK kfuncs. The + * kfuncs are defined in two different modules, and we want to be able + * to use them interchangably with the same BTF type ID. Because modules + * can't de-duplicate BTF IDs between each other, we need the type to be + * referenced in the vmlinux BTF or the verifier will get confused about + * the different types. So we add this dummy type reference which will + * be included in vmlinux BTF, allowing both modules to refer to the + * same type ID. + */ + BTF_TYPE_EMIT(struct nf_conn___init); +#endif } const struct bpf_func_proto bpf_sock_map_update_proto __weak; @@ -8647,28 +8682,25 @@ static bool tc_cls_act_is_valid_access(int off, int size, DEFINE_MUTEX(nf_conn_btf_access_lock); EXPORT_SYMBOL_GPL(nf_conn_btf_access_lock); -int (*nfct_btf_struct_access)(struct bpf_verifier_log *log, const struct btf *btf, - const struct btf_type *t, int off, int size, - enum bpf_access_type atype, u32 *next_btf_id, - enum bpf_type_flag *flag); +int (*nfct_btf_struct_access)(struct bpf_verifier_log *log, + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, + u32 *next_btf_id, enum bpf_type_flag *flag); EXPORT_SYMBOL_GPL(nfct_btf_struct_access); static int tc_cls_act_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, - u32 *next_btf_id, - enum bpf_type_flag *flag) + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, + u32 *next_btf_id, enum bpf_type_flag *flag) { int ret = -EACCES; if (atype == BPF_READ) - return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, - flag); + return btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); mutex_lock(&nf_conn_btf_access_lock); if (nfct_btf_struct_access) - ret = nfct_btf_struct_access(log, btf, t, off, size, atype, next_btf_id, flag); + ret = nfct_btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); mutex_unlock(&nf_conn_btf_access_lock); return ret; @@ -8734,21 +8766,18 @@ void bpf_warn_invalid_xdp_action(struct net_device *dev, struct bpf_prog *prog, EXPORT_SYMBOL_GPL(bpf_warn_invalid_xdp_action); static int xdp_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, - u32 *next_btf_id, - enum bpf_type_flag *flag) + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, + u32 *next_btf_id, enum bpf_type_flag *flag) { int ret = -EACCES; if (atype == BPF_READ) - return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, - flag); + return btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); mutex_lock(&nf_conn_btf_access_lock); if (nfct_btf_struct_access) - ret = nfct_btf_struct_access(log, btf, t, off, size, atype, next_btf_id, flag); + ret = nfct_btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); mutex_unlock(&nf_conn_btf_access_lock); return ret; @@ -8921,6 +8950,10 @@ static bool sock_ops_is_valid_access(int off, int size, bpf_ctx_record_field_size(info, size_default); return bpf_ctx_narrow_access_ok(off, size, size_default); + case offsetof(struct bpf_sock_ops, skb_hwtstamp): + if (size != sizeof(__u64)) + return false; + break; default: if (size != size_default) return false; @@ -9104,21 +9137,21 @@ static struct bpf_insn *bpf_convert_tstamp_type_read(const struct bpf_insn *si, return insn; } -static struct bpf_insn *bpf_convert_shinfo_access(const struct bpf_insn *si, +static struct bpf_insn *bpf_convert_shinfo_access(__u8 dst_reg, __u8 skb_reg, struct bpf_insn *insn) { /* si->dst_reg = skb_shinfo(SKB); */ #ifdef NET_SKBUFF_DATA_USES_OFFSET *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), - BPF_REG_AX, si->src_reg, + BPF_REG_AX, skb_reg, offsetof(struct sk_buff, end)); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head), - si->dst_reg, si->src_reg, + dst_reg, skb_reg, offsetof(struct sk_buff, head)); - *insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX); + *insn++ = BPF_ALU64_REG(BPF_ADD, dst_reg, BPF_REG_AX); #else *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), - si->dst_reg, si->src_reg, + dst_reg, skb_reg, offsetof(struct sk_buff, end)); #endif @@ -9290,13 +9323,11 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct __sk_buff, vlan_present): - *target_size = 1; - *insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->src_reg, - PKT_VLAN_PRESENT_OFFSET); - if (PKT_VLAN_PRESENT_BIT) - *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, PKT_VLAN_PRESENT_BIT); - if (PKT_VLAN_PRESENT_BIT < 7) - *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, 1); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, + vlan_all, 4, target_size)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_ALU32_IMM(BPF_MOV, si->dst_reg, 1); break; case offsetof(struct __sk_buff, vlan_tci): @@ -9511,7 +9542,7 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct __sk_buff, gso_segs): - insn = bpf_convert_shinfo_access(si, insn); + insn = bpf_convert_shinfo_access(si->dst_reg, si->src_reg, insn); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_segs), si->dst_reg, si->dst_reg, bpf_target_off(struct skb_shared_info, @@ -9519,7 +9550,7 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, target_size)); break; case offsetof(struct __sk_buff, gso_size): - insn = bpf_convert_shinfo_access(si, insn); + insn = bpf_convert_shinfo_access(si->dst_reg, si->src_reg, insn); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_size), si->dst_reg, si->dst_reg, bpf_target_off(struct skb_shared_info, @@ -9546,7 +9577,7 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, BUILD_BUG_ON(sizeof_field(struct skb_shared_hwtstamps, hwtstamp) != 8); BUILD_BUG_ON(offsetof(struct skb_shared_hwtstamps, hwtstamp) != 0); - insn = bpf_convert_shinfo_access(si, insn); + insn = bpf_convert_shinfo_access(si->dst_reg, si->src_reg, insn); *insn++ = BPF_LDX_MEM(BPF_DW, si->dst_reg, si->dst_reg, bpf_target_off(struct skb_shared_info, @@ -10396,6 +10427,25 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type, tcp_flags), si->dst_reg, si->dst_reg, off); break; + case offsetof(struct bpf_sock_ops, skb_hwtstamp): { + struct bpf_insn *jmp_on_null_skb; + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + skb), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + skb)); + /* Reserve one insn to test skb == NULL */ + jmp_on_null_skb = insn++; + insn = bpf_convert_shinfo_access(si->dst_reg, si->dst_reg, insn); + *insn++ = BPF_LDX_MEM(BPF_DW, si->dst_reg, si->dst_reg, + bpf_target_off(struct skb_shared_info, + hwtstamps, 8, + target_size)); + *jmp_on_null_skb = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, + insn - jmp_on_null_skb - 1); + break; + } } return insn - insn_buf; } diff --git a/net/core/flow_dissector.c b/net/core/flow_dissector.c index 25cd35f5922e..25fb0bbc310f 100644 --- a/net/core/flow_dissector.c +++ b/net/core/flow_dissector.c @@ -296,7 +296,7 @@ skb_flow_dissect_ct(const struct sk_buff *skb, key->ct_zone = ct->zone.id; #endif #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) - key->ct_mark = ct->mark; + key->ct_mark = READ_ONCE(ct->mark); #endif cl = nf_ct_labels_find(ct); @@ -971,12 +971,14 @@ bool __skb_flow_dissect(const struct net *net, #if IS_ENABLED(CONFIG_NET_DSA) if (unlikely(skb->dev && netdev_uses_dsa(skb->dev) && proto == htons(ETH_P_XDSA))) { + struct metadata_dst *md_dst = skb_metadata_dst(skb); const struct dsa_device_ops *ops; int offset = 0; ops = skb->dev->dsa_ptr->tag_ops; /* Only DSA header taggers break flow dissection */ - if (ops->needed_headroom) { + if (ops->needed_headroom && + (!md_dst || md_dst->type != METADATA_HW_PORT_MUX)) { if (ops->flow_dissect) ops->flow_dissect(skb, &proto, &offset); else diff --git a/net/core/flow_offload.c b/net/core/flow_offload.c index abe423fd5736..acfc1f88ea79 100644 --- a/net/core/flow_offload.c +++ b/net/core/flow_offload.c @@ -97,6 +97,13 @@ void flow_rule_match_cvlan(const struct flow_rule *rule, } EXPORT_SYMBOL(flow_rule_match_cvlan); +void flow_rule_match_arp(const struct flow_rule *rule, + struct flow_match_arp *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ARP, out); +} +EXPORT_SYMBOL(flow_rule_match_arp); + void flow_rule_match_ipv4_addrs(const struct flow_rule *rule, struct flow_match_ipv4_addrs *out) { diff --git a/net/core/gen_estimator.c b/net/core/gen_estimator.c index 4fcbdd71c59f..fae9c4694186 100644 --- a/net/core/gen_estimator.c +++ b/net/core/gen_estimator.c @@ -208,7 +208,7 @@ void gen_kill_estimator(struct net_rate_estimator __rcu **rate_est) est = xchg((__force struct net_rate_estimator **)rate_est, NULL); if (est) { - del_timer_sync(&est->timer); + timer_shutdown_sync(&est->timer); kfree_rcu(est, rcu); } } diff --git a/net/core/gen_stats.c b/net/core/gen_stats.c index c8d137ef5980..b71ccaec0991 100644 --- a/net/core/gen_stats.c +++ b/net/core/gen_stats.c @@ -135,10 +135,10 @@ static void gnet_stats_add_basic_cpu(struct gnet_stats_basic_sync *bstats, u64 bytes, packets; do { - start = u64_stats_fetch_begin_irq(&bcpu->syncp); + start = u64_stats_fetch_begin(&bcpu->syncp); bytes = u64_stats_read(&bcpu->bytes); packets = u64_stats_read(&bcpu->packets); - } while (u64_stats_fetch_retry_irq(&bcpu->syncp, start)); + } while (u64_stats_fetch_retry(&bcpu->syncp, start)); t_bytes += bytes; t_packets += packets; @@ -162,10 +162,10 @@ void gnet_stats_add_basic(struct gnet_stats_basic_sync *bstats, } do { if (running) - start = u64_stats_fetch_begin_irq(&b->syncp); + start = u64_stats_fetch_begin(&b->syncp); bytes = u64_stats_read(&b->bytes); packets = u64_stats_read(&b->packets); - } while (running && u64_stats_fetch_retry_irq(&b->syncp, start)); + } while (running && u64_stats_fetch_retry(&b->syncp, start)); _bstats_update(bstats, bytes, packets); } @@ -187,10 +187,10 @@ static void gnet_stats_read_basic(u64 *ret_bytes, u64 *ret_packets, u64 bytes, packets; do { - start = u64_stats_fetch_begin_irq(&bcpu->syncp); + start = u64_stats_fetch_begin(&bcpu->syncp); bytes = u64_stats_read(&bcpu->bytes); packets = u64_stats_read(&bcpu->packets); - } while (u64_stats_fetch_retry_irq(&bcpu->syncp, start)); + } while (u64_stats_fetch_retry(&bcpu->syncp, start)); t_bytes += bytes; t_packets += packets; @@ -201,10 +201,10 @@ static void gnet_stats_read_basic(u64 *ret_bytes, u64 *ret_packets, } do { if (running) - start = u64_stats_fetch_begin_irq(&b->syncp); + start = u64_stats_fetch_begin(&b->syncp); *ret_bytes = u64_stats_read(&b->bytes); *ret_packets = u64_stats_read(&b->packets); - } while (running && u64_stats_fetch_retry_irq(&b->syncp, start)); + } while (running && u64_stats_fetch_retry(&b->syncp, start)); } static int diff --git a/net/core/gro.c b/net/core/gro.c index bc9451743307..4bac7ea6e025 100644 --- a/net/core/gro.c +++ b/net/core/gro.c @@ -162,6 +162,15 @@ int skb_gro_receive(struct sk_buff *p, struct sk_buff *skb) struct sk_buff *lp; int segs; + /* Do not splice page pool based packets w/ non-page pool + * packets. This can result in reference count issues as page + * pool pages will not decrement the reference count and will + * instead be immediately returned to the pool or have frag + * count decremented. + */ + if (p->pp_recycle != skb->pp_recycle) + return -ETOOMANYREFS; + /* pairs with WRITE_ONCE() in netif_set_gro_max_size() */ gro_max_size = READ_ONCE(p->dev->gro_max_size); @@ -370,9 +379,7 @@ static void gro_list_prepare(const struct list_head *head, } diffs = (unsigned long)p->dev ^ (unsigned long)skb->dev; - diffs |= skb_vlan_tag_present(p) ^ skb_vlan_tag_present(skb); - if (skb_vlan_tag_present(p)) - diffs |= skb_vlan_tag_get(p) ^ skb_vlan_tag_get(skb); + diffs |= p->vlan_all ^ skb->vlan_all; diffs |= skb_metadata_differs(p, skb); if (maclen == ETH_HLEN) diffs |= compare_ether_header(skb_mac_header(p), @@ -489,45 +496,46 @@ static enum gro_result dev_gro_receive(struct napi_struct *napi, struct sk_buff rcu_read_lock(); list_for_each_entry_rcu(ptype, head, list) { - if (ptype->type != type || !ptype->callbacks.gro_receive) - continue; - - skb_set_network_header(skb, skb_gro_offset(skb)); - skb_reset_mac_len(skb); - BUILD_BUG_ON(sizeof_field(struct napi_gro_cb, zeroed) != sizeof(u32)); - BUILD_BUG_ON(!IS_ALIGNED(offsetof(struct napi_gro_cb, zeroed), - sizeof(u32))); /* Avoid slow unaligned acc */ - *(u32 *)&NAPI_GRO_CB(skb)->zeroed = 0; - NAPI_GRO_CB(skb)->flush = skb_has_frag_list(skb); - NAPI_GRO_CB(skb)->is_atomic = 1; - NAPI_GRO_CB(skb)->count = 1; - if (unlikely(skb_is_gso(skb))) { - NAPI_GRO_CB(skb)->count = skb_shinfo(skb)->gso_segs; - /* Only support TCP at the moment. */ - if (!skb_is_gso_tcp(skb)) - NAPI_GRO_CB(skb)->flush = 1; - } - - /* Setup for GRO checksum validation */ - switch (skb->ip_summed) { - case CHECKSUM_COMPLETE: - NAPI_GRO_CB(skb)->csum = skb->csum; - NAPI_GRO_CB(skb)->csum_valid = 1; - break; - case CHECKSUM_UNNECESSARY: - NAPI_GRO_CB(skb)->csum_cnt = skb->csum_level + 1; - break; - } + if (ptype->type == type && ptype->callbacks.gro_receive) + goto found_ptype; + } + rcu_read_unlock(); + goto normal; + +found_ptype: + skb_set_network_header(skb, skb_gro_offset(skb)); + skb_reset_mac_len(skb); + BUILD_BUG_ON(sizeof_field(struct napi_gro_cb, zeroed) != sizeof(u32)); + BUILD_BUG_ON(!IS_ALIGNED(offsetof(struct napi_gro_cb, zeroed), + sizeof(u32))); /* Avoid slow unaligned acc */ + *(u32 *)&NAPI_GRO_CB(skb)->zeroed = 0; + NAPI_GRO_CB(skb)->flush = skb_has_frag_list(skb); + NAPI_GRO_CB(skb)->is_atomic = 1; + NAPI_GRO_CB(skb)->count = 1; + if (unlikely(skb_is_gso(skb))) { + NAPI_GRO_CB(skb)->count = skb_shinfo(skb)->gso_segs; + /* Only support TCP and non DODGY users. */ + if (!skb_is_gso_tcp(skb) || + (skb_shinfo(skb)->gso_type & SKB_GSO_DODGY)) + NAPI_GRO_CB(skb)->flush = 1; + } - pp = INDIRECT_CALL_INET(ptype->callbacks.gro_receive, - ipv6_gro_receive, inet_gro_receive, - &gro_list->list, skb); + /* Setup for GRO checksum validation */ + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + NAPI_GRO_CB(skb)->csum = skb->csum; + NAPI_GRO_CB(skb)->csum_valid = 1; + break; + case CHECKSUM_UNNECESSARY: + NAPI_GRO_CB(skb)->csum_cnt = skb->csum_level + 1; break; } - rcu_read_unlock(); - if (&ptype->list == head) - goto normal; + pp = INDIRECT_CALL_INET(ptype->callbacks.gro_receive, + ipv6_gro_receive, inet_gro_receive, + &gro_list->list, skb); + + rcu_read_unlock(); if (PTR_ERR(pp) == -EINPROGRESS) { ret = GRO_CONSUMED; diff --git a/net/core/link_watch.c b/net/core/link_watch.c index aa6cb1f90966..c469d1c4db5d 100644 --- a/net/core/link_watch.c +++ b/net/core/link_watch.c @@ -38,9 +38,23 @@ static unsigned char default_operstate(const struct net_device *dev) if (netif_testing(dev)) return IF_OPER_TESTING; - if (!netif_carrier_ok(dev)) - return (dev->ifindex != dev_get_iflink(dev) ? - IF_OPER_LOWERLAYERDOWN : IF_OPER_DOWN); + /* Some uppers (DSA) have additional sources for being down, so + * first check whether lower is indeed the source of its down state. + */ + if (!netif_carrier_ok(dev)) { + int iflink = dev_get_iflink(dev); + struct net_device *peer; + + if (iflink == dev->ifindex) + return IF_OPER_DOWN; + + peer = __dev_get_by_index(dev_net(dev), iflink); + if (!peer) + return IF_OPER_DOWN; + + return netif_carrier_ok(peer) ? IF_OPER_DOWN : + IF_OPER_LOWERLAYERDOWN; + } if (netif_dormant(dev)) return IF_OPER_DORMANT; diff --git a/net/core/lwtunnel.c b/net/core/lwtunnel.c index 6fac2f0ef074..711cd3b4347a 100644 --- a/net/core/lwtunnel.c +++ b/net/core/lwtunnel.c @@ -48,9 +48,11 @@ static const char *lwtunnel_encap_str(enum lwtunnel_encap_types encap_type) return "RPL"; case LWTUNNEL_ENCAP_IOAM6: return "IOAM6"; + case LWTUNNEL_ENCAP_XFRM: + /* module autoload not supported for encap type */ + return NULL; case LWTUNNEL_ENCAP_IP6: case LWTUNNEL_ENCAP_IP: - case LWTUNNEL_ENCAP_XFRM: case LWTUNNEL_ENCAP_NONE: case __LWTUNNEL_ENCAP_MAX: /* should not have got here */ diff --git a/net/core/neighbour.c b/net/core/neighbour.c index a77a85e357e0..f00a79fc301b 100644 --- a/net/core/neighbour.c +++ b/net/core/neighbour.c @@ -111,7 +111,7 @@ static void neigh_cleanup_and_release(struct neighbour *neigh) unsigned long neigh_rand_reach_time(unsigned long base) { - return base ? prandom_u32_max(base) + (base >> 1) : 0; + return base ? get_random_u32_below(base) + (base >> 1) : 0; } EXPORT_SYMBOL(neigh_rand_reach_time); @@ -307,7 +307,31 @@ static int neigh_del_timer(struct neighbour *n) return 0; } -static void pneigh_queue_purge(struct sk_buff_head *list, struct net *net) +static struct neigh_parms *neigh_get_dev_parms_rcu(struct net_device *dev, + int family) +{ + switch (family) { + case AF_INET: + return __in_dev_arp_parms_get_rcu(dev); + case AF_INET6: + return __in6_dev_nd_parms_get_rcu(dev); + } + return NULL; +} + +static void neigh_parms_qlen_dec(struct net_device *dev, int family) +{ + struct neigh_parms *p; + + rcu_read_lock(); + p = neigh_get_dev_parms_rcu(dev, family); + if (p) + p->qlen--; + rcu_read_unlock(); +} + +static void pneigh_queue_purge(struct sk_buff_head *list, struct net *net, + int family) { struct sk_buff_head tmp; unsigned long flags; @@ -321,13 +345,7 @@ static void pneigh_queue_purge(struct sk_buff_head *list, struct net *net) struct net_device *dev = skb->dev; if (net == NULL || net_eq(dev_net(dev), net)) { - struct in_device *in_dev; - - rcu_read_lock(); - in_dev = __in_dev_get_rcu(dev); - if (in_dev) - in_dev->arp_parms->qlen--; - rcu_read_unlock(); + neigh_parms_qlen_dec(dev, family); __skb_unlink(skb, list); __skb_queue_tail(&tmp, skb); } @@ -409,7 +427,8 @@ static int __neigh_ifdown(struct neigh_table *tbl, struct net_device *dev, write_lock_bh(&tbl->lock); neigh_flush_dev(tbl, dev, skip_perm); pneigh_ifdown_and_unlock(tbl, dev); - pneigh_queue_purge(&tbl->proxy_queue, dev ? dev_net(dev) : NULL); + pneigh_queue_purge(&tbl->proxy_queue, dev ? dev_net(dev) : NULL, + tbl->family); if (skb_queue_empty_lockless(&tbl->proxy_queue)) del_timer_sync(&tbl->proxy_timer); return 0; @@ -1621,13 +1640,8 @@ static void neigh_proxy_process(struct timer_list *t) if (tdif <= 0) { struct net_device *dev = skb->dev; - struct in_device *in_dev; - rcu_read_lock(); - in_dev = __in_dev_get_rcu(dev); - if (in_dev) - in_dev->arp_parms->qlen--; - rcu_read_unlock(); + neigh_parms_qlen_dec(dev, tbl->family); __skb_unlink(skb, &tbl->proxy_queue); if (tbl->proxy_redo && netif_running(dev)) { @@ -1652,7 +1666,7 @@ void pneigh_enqueue(struct neigh_table *tbl, struct neigh_parms *p, struct sk_buff *skb) { unsigned long sched_next = jiffies + - prandom_u32_max(NEIGH_VAR(p, PROXY_DELAY)); + get_random_u32_below(NEIGH_VAR(p, PROXY_DELAY)); if (p->qlen > NEIGH_VAR(p, PROXY_QLEN)) { kfree_skb(skb); @@ -1821,7 +1835,7 @@ int neigh_table_clear(int index, struct neigh_table *tbl) cancel_delayed_work_sync(&tbl->managed_work); cancel_delayed_work_sync(&tbl->gc_work); del_timer_sync(&tbl->proxy_timer); - pneigh_queue_purge(&tbl->proxy_queue, NULL); + pneigh_queue_purge(&tbl->proxy_queue, NULL, tbl->family); neigh_ifdown(tbl, NULL); if (atomic_read(&tbl->entries)) pr_crit("neighbour leakage\n"); @@ -3539,18 +3553,6 @@ static int proc_unres_qlen(struct ctl_table *ctl, int write, return ret; } -static struct neigh_parms *neigh_get_dev_parms_rcu(struct net_device *dev, - int family) -{ - switch (family) { - case AF_INET: - return __in_dev_arp_parms_get_rcu(dev); - case AF_INET6: - return __in6_dev_nd_parms_get_rcu(dev); - } - return NULL; -} - static void neigh_copy_dflt_parms(struct net *net, struct neigh_parms *p, int index) { diff --git a/net/core/net-sysfs.c b/net/core/net-sysfs.c index 8409d41405df..ca55dd747d6c 100644 --- a/net/core/net-sysfs.c +++ b/net/core/net-sysfs.c @@ -532,7 +532,7 @@ static ssize_t phys_port_name_show(struct device *dev, * returning early without hitting the trylock/restart below. */ if (!netdev->netdev_ops->ndo_get_phys_port_name && - !netdev->netdev_ops->ndo_get_devlink_port) + !netdev->devlink_port) return -EOPNOTSUPP; if (!rtnl_trylock()) @@ -562,7 +562,7 @@ static ssize_t phys_switch_id_show(struct device *dev, * because recurse is false when calling dev_get_port_parent_id. */ if (!netdev->netdev_ops->ndo_get_port_parent_id && - !netdev->netdev_ops->ndo_get_devlink_port) + !netdev->devlink_port) return -EOPNOTSUPP; if (!rtnl_trylock()) @@ -1020,7 +1020,7 @@ static void rx_queue_release(struct kobject *kobj) netdev_put(queue->dev, &queue->dev_tracker); } -static const void *rx_queue_namespace(struct kobject *kobj) +static const void *rx_queue_namespace(const struct kobject *kobj) { struct netdev_rx_queue *queue = to_rx_queue(kobj); struct device *dev = &queue->dev->dev; @@ -1032,7 +1032,7 @@ static const void *rx_queue_namespace(struct kobject *kobj) return ns; } -static void rx_queue_get_ownership(struct kobject *kobj, +static void rx_queue_get_ownership(const struct kobject *kobj, kuid_t *uid, kgid_t *gid) { const struct net *net = rx_queue_namespace(kobj); @@ -1623,7 +1623,7 @@ static void netdev_queue_release(struct kobject *kobj) netdev_put(queue->dev, &queue->dev_tracker); } -static const void *netdev_queue_namespace(struct kobject *kobj) +static const void *netdev_queue_namespace(const struct kobject *kobj) { struct netdev_queue *queue = to_netdev_queue(kobj); struct device *dev = &queue->dev->dev; @@ -1635,7 +1635,7 @@ static const void *netdev_queue_namespace(struct kobject *kobj) return ns; } -static void netdev_queue_get_ownership(struct kobject *kobj, +static void netdev_queue_get_ownership(const struct kobject *kobj, kuid_t *uid, kgid_t *gid) { const struct net *net = netdev_queue_namespace(kobj); @@ -1873,9 +1873,9 @@ const struct kobj_ns_type_operations net_ns_type_operations = { }; EXPORT_SYMBOL_GPL(net_ns_type_operations); -static int netdev_uevent(struct device *d, struct kobj_uevent_env *env) +static int netdev_uevent(const struct device *d, struct kobj_uevent_env *env) { - struct net_device *dev = to_net_dev(d); + const struct net_device *dev = to_net_dev(d); int retval; /* pass interface to uevent. */ @@ -1910,16 +1910,16 @@ static void netdev_release(struct device *d) netdev_freemem(dev); } -static const void *net_namespace(struct device *d) +static const void *net_namespace(const struct device *d) { - struct net_device *dev = to_net_dev(d); + const struct net_device *dev = to_net_dev(d); return dev_net(dev); } -static void net_get_ownership(struct device *d, kuid_t *uid, kgid_t *gid) +static void net_get_ownership(const struct device *d, kuid_t *uid, kgid_t *gid) { - struct net_device *dev = to_net_dev(d); + const struct net_device *dev = to_net_dev(d); const struct net *net = dev_net(dev); net_ns_get_ownership(net, uid, gid); diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index f64654df71a2..078a0a420c8a 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -137,12 +137,12 @@ static int ops_init(const struct pernet_operations *ops, struct net *net) return 0; if (ops->id && ops->size) { -cleanup: ng = rcu_dereference_protected(net->gen, lockdep_is_held(&pernet_ops_rwsem)); ng->ptr[*ops->id] = NULL; } +cleanup: kfree(data); out: @@ -316,6 +316,7 @@ static __net_init int setup_net(struct net *net, struct user_namespace *user_ns) refcount_set(&net->ns.count, 1); ref_tracker_dir_init(&net->refcnt_tracker, 128); + ref_tracker_dir_init(&net->notrefcnt_tracker, 128); refcount_set(&net->passive, 1); get_random_bytes(&net->hash_mix, sizeof(u32)); @@ -436,6 +437,10 @@ static void net_free(struct net *net) { if (refcount_dec_and_test(&net->passive)) { kfree(rcu_access_pointer(net->gen)); + + /* There should not be any trackers left there. */ + ref_tracker_dir_exit(&net->notrefcnt_tracker); + kmem_cache_free(net_cachep, net); } } diff --git a/net/core/of_net.c b/net/core/of_net.c index f1a9bf7578e7..55d3fe229269 100644 --- a/net/core/of_net.c +++ b/net/core/of_net.c @@ -57,7 +57,7 @@ static int of_get_mac_addr(struct device_node *np, const char *name, u8 *addr) return -ENODEV; } -static int of_get_mac_addr_nvmem(struct device_node *np, u8 *addr) +int of_get_mac_address_nvmem(struct device_node *np, u8 *addr) { struct platform_device *pdev = of_find_device_by_node(np); struct nvmem_cell *cell; @@ -94,6 +94,7 @@ static int of_get_mac_addr_nvmem(struct device_node *np, u8 *addr) return 0; } +EXPORT_SYMBOL(of_get_mac_address_nvmem); /** * of_get_mac_address() @@ -140,7 +141,7 @@ int of_get_mac_address(struct device_node *np, u8 *addr) if (!ret) return 0; - return of_get_mac_addr_nvmem(np, addr); + return of_get_mac_address_nvmem(np, addr); } EXPORT_SYMBOL(of_get_mac_address); diff --git a/net/core/pktgen.c b/net/core/pktgen.c index c3763056c554..760238196db1 100644 --- a/net/core/pktgen.c +++ b/net/core/pktgen.c @@ -2324,7 +2324,7 @@ static inline int f_pick(struct pktgen_dev *pkt_dev) pkt_dev->curfl = 0; /*reset */ } } else { - flow = prandom_u32_max(pkt_dev->cflows); + flow = get_random_u32_below(pkt_dev->cflows); pkt_dev->curfl = flow; if (pkt_dev->flows[flow].count > pkt_dev->lflow) { @@ -2380,9 +2380,8 @@ static void set_cur_queue_map(struct pktgen_dev *pkt_dev) else if (pkt_dev->queue_map_min <= pkt_dev->queue_map_max) { __u16 t; if (pkt_dev->flags & F_QUEUE_MAP_RND) { - t = prandom_u32_max(pkt_dev->queue_map_max - - pkt_dev->queue_map_min + 1) + - pkt_dev->queue_map_min; + t = get_random_u32_inclusive(pkt_dev->queue_map_min, + pkt_dev->queue_map_max); } else { t = pkt_dev->cur_queue_map + 1; if (t > pkt_dev->queue_map_max) @@ -2411,7 +2410,7 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) __u32 tmp; if (pkt_dev->flags & F_MACSRC_RND) - mc = prandom_u32_max(pkt_dev->src_mac_count); + mc = get_random_u32_below(pkt_dev->src_mac_count); else { mc = pkt_dev->cur_src_mac_offset++; if (pkt_dev->cur_src_mac_offset >= @@ -2437,7 +2436,7 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) __u32 tmp; if (pkt_dev->flags & F_MACDST_RND) - mc = prandom_u32_max(pkt_dev->dst_mac_count); + mc = get_random_u32_below(pkt_dev->dst_mac_count); else { mc = pkt_dev->cur_dst_mac_offset++; @@ -2469,18 +2468,17 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) } if ((pkt_dev->flags & F_VID_RND) && (pkt_dev->vlan_id != 0xffff)) { - pkt_dev->vlan_id = prandom_u32_max(4096); + pkt_dev->vlan_id = get_random_u32_below(4096); } if ((pkt_dev->flags & F_SVID_RND) && (pkt_dev->svlan_id != 0xffff)) { - pkt_dev->svlan_id = prandom_u32_max(4096); + pkt_dev->svlan_id = get_random_u32_below(4096); } if (pkt_dev->udp_src_min < pkt_dev->udp_src_max) { if (pkt_dev->flags & F_UDPSRC_RND) - pkt_dev->cur_udp_src = prandom_u32_max( - pkt_dev->udp_src_max - pkt_dev->udp_src_min) + - pkt_dev->udp_src_min; + pkt_dev->cur_udp_src = get_random_u32_inclusive(pkt_dev->udp_src_min, + pkt_dev->udp_src_max - 1); else { pkt_dev->cur_udp_src++; @@ -2491,9 +2489,8 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) if (pkt_dev->udp_dst_min < pkt_dev->udp_dst_max) { if (pkt_dev->flags & F_UDPDST_RND) { - pkt_dev->cur_udp_dst = prandom_u32_max( - pkt_dev->udp_dst_max - pkt_dev->udp_dst_min) + - pkt_dev->udp_dst_min; + pkt_dev->cur_udp_dst = get_random_u32_inclusive(pkt_dev->udp_dst_min, + pkt_dev->udp_dst_max - 1); } else { pkt_dev->cur_udp_dst++; if (pkt_dev->cur_udp_dst >= pkt_dev->udp_dst_max) @@ -2508,7 +2505,7 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) if (imn < imx) { __u32 t; if (pkt_dev->flags & F_IPSRC_RND) - t = prandom_u32_max(imx - imn) + imn; + t = get_random_u32_inclusive(imn, imx - 1); else { t = ntohl(pkt_dev->cur_saddr); t++; @@ -2530,8 +2527,7 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) if (pkt_dev->flags & F_IPDST_RND) { do { - t = prandom_u32_max(imx - imn) + - imn; + t = get_random_u32_inclusive(imn, imx - 1); s = htonl(t); } while (ipv4_is_loopback(s) || ipv4_is_multicast(s) || @@ -2578,9 +2574,8 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) if (pkt_dev->min_pkt_size < pkt_dev->max_pkt_size) { __u32 t; if (pkt_dev->flags & F_TXSIZE_RND) { - t = prandom_u32_max(pkt_dev->max_pkt_size - - pkt_dev->min_pkt_size) + - pkt_dev->min_pkt_size; + t = get_random_u32_inclusive(pkt_dev->min_pkt_size, + pkt_dev->max_pkt_size - 1); } else { t = pkt_dev->cur_pkt_size + 1; if (t > pkt_dev->max_pkt_size) @@ -2589,7 +2584,7 @@ static void mod_cur_headers(struct pktgen_dev *pkt_dev) pkt_dev->cur_pkt_size = t; } else if (pkt_dev->n_imix_entries > 0) { struct imix_pkt *entry; - __u32 t = prandom_u32_max(IMIX_PRECISION); + __u32 t = get_random_u32_below(IMIX_PRECISION); __u8 entry_index = pkt_dev->imix_distribution[t]; entry = &pkt_dev->imix_entries[entry_index]; diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c index 74864dc46a7e..64289bc98887 100644 --- a/net/core/rtnetlink.c +++ b/net/core/rtnetlink.c @@ -53,6 +53,7 @@ #include <net/fib_rules.h> #include <net/rtnetlink.h> #include <net/net_namespace.h> +#include <net/devlink.h> #include "dev.h" @@ -760,7 +761,7 @@ int rtnl_unicast(struct sk_buff *skb, struct net *net, u32 pid) EXPORT_SYMBOL(rtnl_unicast); void rtnl_notify(struct sk_buff *skb, struct net *net, u32 pid, u32 group, - struct nlmsghdr *nlh, gfp_t flags) + const struct nlmsghdr *nlh, gfp_t flags) { struct sock *rtnl = net->rtnl; @@ -1038,6 +1039,16 @@ static size_t rtnl_proto_down_size(const struct net_device *dev) return size; } +static size_t rtnl_devlink_port_size(const struct net_device *dev) +{ + size_t size = nla_total_size(0); /* nest IFLA_DEVLINK_PORT */ + + if (dev->devlink_port) + size += devlink_nl_port_handle_size(dev->devlink_port); + + return size; +} + static noinline size_t if_nlmsg_size(const struct net_device *dev, u32 ext_filter_mask) { @@ -1091,6 +1102,7 @@ static noinline size_t if_nlmsg_size(const struct net_device *dev, + nla_total_size(4) /* IFLA_MAX_MTU */ + rtnl_prop_list_size(dev) + nla_total_size(MAX_ADDR_LEN) /* IFLA_PERM_ADDRESS */ + + rtnl_devlink_port_size(dev) + 0; } @@ -1728,6 +1740,30 @@ nla_put_failure: return -EMSGSIZE; } +static int rtnl_fill_devlink_port(struct sk_buff *skb, + const struct net_device *dev) +{ + struct nlattr *devlink_port_nest; + int ret; + + devlink_port_nest = nla_nest_start(skb, IFLA_DEVLINK_PORT); + if (!devlink_port_nest) + return -EMSGSIZE; + + if (dev->devlink_port) { + ret = devlink_nl_port_handle_fill(skb, dev->devlink_port); + if (ret < 0) + goto nest_cancel; + } + + nla_nest_end(skb, devlink_port_nest); + return 0; + +nest_cancel: + nla_nest_cancel(skb, devlink_port_nest); + return ret; +} + static int rtnl_fill_ifinfo(struct sk_buff *skb, struct net_device *dev, struct net *src_net, int type, u32 pid, u32 seq, u32 change, @@ -1865,6 +1901,9 @@ static int rtnl_fill_ifinfo(struct sk_buff *skb, dev->dev.parent->bus->name)) goto nla_put_failure; + if (rtnl_fill_devlink_port(skb, dev)) + goto nla_put_failure; + nlmsg_end(skb, nlh); return 0; @@ -3110,7 +3149,7 @@ static int rtnl_group_dellink(const struct net *net, int group) return 0; } -int rtnl_delete_link(struct net_device *dev) +int rtnl_delete_link(struct net_device *dev, u32 portid, const struct nlmsghdr *nlh) { const struct rtnl_link_ops *ops; LIST_HEAD(list_kill); @@ -3120,7 +3159,7 @@ int rtnl_delete_link(struct net_device *dev) return -EOPNOTSUPP; ops->dellink(dev, &list_kill); - unregister_netdevice_many(&list_kill); + unregister_netdevice_many_notify(&list_kill, portid, nlh); return 0; } @@ -3130,6 +3169,7 @@ static int rtnl_dellink(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { struct net *net = sock_net(skb->sk); + u32 portid = NETLINK_CB(skb).portid; struct net *tgt_net = net; struct net_device *dev = NULL; struct ifinfomsg *ifm; @@ -3171,7 +3211,7 @@ static int rtnl_dellink(struct sk_buff *skb, struct nlmsghdr *nlh, goto out; } - err = rtnl_delete_link(dev); + err = rtnl_delete_link(dev, portid, nlh); out: if (netnsid >= 0) @@ -3180,7 +3220,8 @@ out: return err; } -int rtnl_configure_link(struct net_device *dev, const struct ifinfomsg *ifm) +int rtnl_configure_link(struct net_device *dev, const struct ifinfomsg *ifm, + u32 portid, const struct nlmsghdr *nlh) { unsigned int old_flags; int err; @@ -3194,10 +3235,10 @@ int rtnl_configure_link(struct net_device *dev, const struct ifinfomsg *ifm) } if (dev->rtnl_link_state == RTNL_LINK_INITIALIZED) { - __dev_notify_flags(dev, old_flags, (old_flags ^ dev->flags)); + __dev_notify_flags(dev, old_flags, (old_flags ^ dev->flags), portid, nlh); } else { dev->rtnl_link_state = RTNL_LINK_INITIALIZED; - __dev_notify_flags(dev, old_flags, ~0U); + __dev_notify_flags(dev, old_flags, ~0U, portid, nlh); } return 0; } @@ -3311,11 +3352,13 @@ static int rtnl_group_changelink(const struct sk_buff *skb, static int rtnl_newlink_create(struct sk_buff *skb, struct ifinfomsg *ifm, const struct rtnl_link_ops *ops, + const struct nlmsghdr *nlh, struct nlattr **tb, struct nlattr **data, struct netlink_ext_ack *extack) { unsigned char name_assign_type = NET_NAME_USER; struct net *net = sock_net(skb->sk); + u32 portid = NETLINK_CB(skb).portid; struct net *dest_net, *link_net; struct net_device *dev; char ifname[IFNAMSIZ]; @@ -3369,7 +3412,7 @@ static int rtnl_newlink_create(struct sk_buff *skb, struct ifinfomsg *ifm, goto out; } - err = rtnl_configure_link(dev, ifm); + err = rtnl_configure_link(dev, ifm, portid, nlh); if (err < 0) goto out_unregister; if (link_net) { @@ -3578,7 +3621,7 @@ replay: return -EOPNOTSUPP; } - return rtnl_newlink_create(skb, ifm, ops, tb, data, extack); + return rtnl_newlink_create(skb, ifm, ops, nlh, tb, data, extack); } static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh, @@ -3896,7 +3939,7 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb) struct sk_buff *rtmsg_ifinfo_build_skb(int type, struct net_device *dev, unsigned int change, u32 event, gfp_t flags, int *new_nsid, - int new_ifindex) + int new_ifindex, u32 portid, u32 seq) { struct net *net = dev_net(dev); struct sk_buff *skb; @@ -3907,7 +3950,7 @@ struct sk_buff *rtmsg_ifinfo_build_skb(int type, struct net_device *dev, goto errout; err = rtnl_fill_ifinfo(skb, dev, dev_net(dev), - type, 0, 0, change, 0, 0, event, + type, portid, seq, change, 0, 0, event, new_nsid, new_ifindex, -1, flags); if (err < 0) { /* -EMSGSIZE implies BUG in if_nlmsg_size() */ @@ -3922,16 +3965,18 @@ errout: return NULL; } -void rtmsg_ifinfo_send(struct sk_buff *skb, struct net_device *dev, gfp_t flags) +void rtmsg_ifinfo_send(struct sk_buff *skb, struct net_device *dev, gfp_t flags, + u32 portid, const struct nlmsghdr *nlh) { struct net *net = dev_net(dev); - rtnl_notify(skb, net, 0, RTNLGRP_LINK, NULL, flags); + rtnl_notify(skb, net, portid, RTNLGRP_LINK, nlh, flags); } static void rtmsg_ifinfo_event(int type, struct net_device *dev, unsigned int change, u32 event, - gfp_t flags, int *new_nsid, int new_ifindex) + gfp_t flags, int *new_nsid, int new_ifindex, + u32 portid, const struct nlmsghdr *nlh) { struct sk_buff *skb; @@ -3939,23 +3984,23 @@ static void rtmsg_ifinfo_event(int type, struct net_device *dev, return; skb = rtmsg_ifinfo_build_skb(type, dev, change, event, flags, new_nsid, - new_ifindex); + new_ifindex, portid, nlmsg_seq(nlh)); if (skb) - rtmsg_ifinfo_send(skb, dev, flags); + rtmsg_ifinfo_send(skb, dev, flags, portid, nlh); } void rtmsg_ifinfo(int type, struct net_device *dev, unsigned int change, - gfp_t flags) + gfp_t flags, u32 portid, const struct nlmsghdr *nlh) { rtmsg_ifinfo_event(type, dev, change, rtnl_get_event(0), flags, - NULL, 0); + NULL, 0, portid, nlh); } void rtmsg_ifinfo_newnet(int type, struct net_device *dev, unsigned int change, gfp_t flags, int *new_nsid, int new_ifindex) { rtmsg_ifinfo_event(type, dev, change, rtnl_get_event(0), flags, - new_nsid, new_ifindex); + new_nsid, new_ifindex, 0, NULL); } static int nlmsg_populate_fdb_fill(struct sk_buff *skb, @@ -4045,6 +4090,11 @@ int ndo_dflt_fdb_add(struct ndmsg *ndm, return err; } + if (tb[NDA_FLAGS_EXT]) { + netdev_info(dev, "invalid flags given to default FDB implementation\n"); + return err; + } + if (vid) { netdev_info(dev, "vlans aren't supported yet for dev_uc|mc_add()\n"); return err; @@ -6140,7 +6190,7 @@ static int rtnetlink_event(struct notifier_block *this, unsigned long event, voi case NETDEV_CHANGELOWERSTATE: case NETDEV_CHANGE_TX_QUEUE_LEN: rtmsg_ifinfo_event(RTM_NEWLINK, dev, 0, rtnl_get_event(event), - GFP_KERNEL, NULL, 0); + GFP_KERNEL, NULL, 0, 0, NULL); break; default: break; diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 88fa40571d0c..a31ff4d83ecc 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -94,6 +94,7 @@ EXPORT_SYMBOL(sysctl_max_skb_frags); #undef FN #define FN(reason) [SKB_DROP_REASON_##reason] = #reason, const char * const drop_reasons[] = { + [SKB_CONSUMED] = "CONSUMED", DEFINE_DROP_REASON(FN, FN) }; EXPORT_SYMBOL(drop_reasons); @@ -269,12 +270,10 @@ static struct sk_buff *napi_skb_cache_get(void) return skb; } -/* Caller must provide SKB that is memset cleared */ -static void __build_skb_around(struct sk_buff *skb, void *data, - unsigned int frag_size) +static inline void __finalize_skb_around(struct sk_buff *skb, void *data, + unsigned int size) { struct skb_shared_info *shinfo; - unsigned int size = frag_size ? : ksize(data); size -= SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); @@ -296,15 +295,71 @@ static void __build_skb_around(struct sk_buff *skb, void *data, skb_set_kcov_handle(skb, kcov_common_handle()); } +static inline void *__slab_build_skb(struct sk_buff *skb, void *data, + unsigned int *size) +{ + void *resized; + + /* Must find the allocation size (and grow it to match). */ + *size = ksize(data); + /* krealloc() will immediately return "data" when + * "ksize(data)" is requested: it is the existing upper + * bounds. As a result, GFP_ATOMIC will be ignored. Note + * that this "new" pointer needs to be passed back to the + * caller for use so the __alloc_size hinting will be + * tracked correctly. + */ + resized = krealloc(data, *size, GFP_ATOMIC); + WARN_ON_ONCE(resized != data); + return resized; +} + +/* build_skb() variant which can operate on slab buffers. + * Note that this should be used sparingly as slab buffers + * cannot be combined efficiently by GRO! + */ +struct sk_buff *slab_build_skb(void *data) +{ + struct sk_buff *skb; + unsigned int size; + + skb = kmem_cache_alloc(skbuff_head_cache, GFP_ATOMIC); + if (unlikely(!skb)) + return NULL; + + memset(skb, 0, offsetof(struct sk_buff, tail)); + data = __slab_build_skb(skb, data, &size); + __finalize_skb_around(skb, data, size); + + return skb; +} +EXPORT_SYMBOL(slab_build_skb); + +/* Caller must provide SKB that is memset cleared */ +static void __build_skb_around(struct sk_buff *skb, void *data, + unsigned int frag_size) +{ + unsigned int size = frag_size; + + /* frag_size == 0 is considered deprecated now. Callers + * using slab buffer should use slab_build_skb() instead. + */ + if (WARN_ONCE(size == 0, "Use slab_build_skb() instead")) + data = __slab_build_skb(skb, data, &size); + + __finalize_skb_around(skb, data, size); +} + /** * __build_skb - build a network buffer * @data: data buffer provided by caller - * @frag_size: size of data, or 0 if head was kmalloced + * @frag_size: size of data (must not be 0) * * Allocate a new &sk_buff. Caller provides space holding head and - * skb_shared_info. @data must have been allocated by kmalloc() only if - * @frag_size is 0, otherwise data should come from the page allocator - * or vmalloc() + * skb_shared_info. @data must have been allocated from the page + * allocator or vmalloc(). (A @frag_size of 0 to indicate a kmalloc() + * allocation is deprecated, and callers should use slab_build_skb() + * instead.) * The return is the new skb buffer. * On a failure the return is %NULL, and @data is not freed. * Notes : @@ -506,14 +561,14 @@ struct sk_buff *__alloc_skb(unsigned int size, gfp_t gfp_mask, */ size = SKB_DATA_ALIGN(size); size += SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); - data = kmalloc_reserve(size, gfp_mask, node, &pfmemalloc); + osize = kmalloc_size_roundup(size); + data = kmalloc_reserve(osize, gfp_mask, node, &pfmemalloc); if (unlikely(!data)) goto nodata; - /* kmalloc(size) might give us more room than requested. + /* kmalloc_size_roundup() might give us more room than requested. * Put skb_shared_info exactly at the end of allocated zone, * to allow max possible filling before reallocation. */ - osize = ksize(data); size = SKB_WITH_OVERHEAD(osize); prefetchw(data + size); @@ -748,6 +803,13 @@ static void skb_clone_fraglist(struct sk_buff *skb) skb_get(list); } +static bool skb_pp_recycle(struct sk_buff *skb, void *data) +{ + if (!IS_ENABLED(CONFIG_PAGE_POOL) || !skb->pp_recycle) + return false; + return page_pool_return_skb_page(virt_to_page(data)); +} + static void skb_free_head(struct sk_buff *skb) { unsigned char *head = skb->head; @@ -761,7 +823,7 @@ static void skb_free_head(struct sk_buff *skb) } } -static void skb_release_data(struct sk_buff *skb) +static void skb_release_data(struct sk_buff *skb, enum skb_drop_reason reason) { struct skb_shared_info *shinfo = skb_shinfo(skb); int i; @@ -784,7 +846,7 @@ static void skb_release_data(struct sk_buff *skb) free_head: if (shinfo->frag_list) - kfree_skb_list(shinfo->frag_list); + kfree_skb_list_reason(shinfo->frag_list, reason); skb_free_head(skb); exit: @@ -847,11 +909,11 @@ void skb_release_head_state(struct sk_buff *skb) } /* Free everything but the sk_buff shell. */ -static void skb_release_all(struct sk_buff *skb) +static void skb_release_all(struct sk_buff *skb, enum skb_drop_reason reason) { skb_release_head_state(skb); if (likely(skb->head)) - skb_release_data(skb); + skb_release_data(skb, reason); } /** @@ -865,7 +927,7 @@ static void skb_release_all(struct sk_buff *skb) void __kfree_skb(struct sk_buff *skb) { - skb_release_all(skb); + skb_release_all(skb, SKB_DROP_REASON_NOT_SPECIFIED); kfree_skbmem(skb); } EXPORT_SYMBOL(__kfree_skb); @@ -887,7 +949,10 @@ kfree_skb_reason(struct sk_buff *skb, enum skb_drop_reason reason) DEBUG_NET_WARN_ON_ONCE(reason <= 0 || reason >= SKB_DROP_REASON_MAX); - trace_kfree_skb(skb, __builtin_return_address(0), reason); + if (reason == SKB_CONSUMED) + trace_consume_skb(skb); + else + trace_kfree_skb(skb, __builtin_return_address(0), reason); __kfree_skb(skb); } EXPORT_SYMBOL(kfree_skb_reason); @@ -1045,7 +1110,7 @@ EXPORT_SYMBOL(consume_skb); void __consume_stateless_skb(struct sk_buff *skb) { trace_consume_skb(skb); - skb_release_data(skb); + skb_release_data(skb, SKB_CONSUMED); kfree_skbmem(skb); } @@ -1070,7 +1135,7 @@ static void napi_skb_cache_put(struct sk_buff *skb) void __kfree_skb_defer(struct sk_buff *skb) { - skb_release_all(skb); + skb_release_all(skb, SKB_DROP_REASON_NOT_SPECIFIED); napi_skb_cache_put(skb); } @@ -1108,7 +1173,7 @@ void napi_consume_skb(struct sk_buff *skb, int budget) return; } - skb_release_all(skb); + skb_release_all(skb, SKB_CONSUMED); napi_skb_cache_put(skb); } EXPORT_SYMBOL(napi_consume_skb); @@ -1239,7 +1304,7 @@ EXPORT_SYMBOL_GPL(alloc_skb_for_msg); */ struct sk_buff *skb_morph(struct sk_buff *dst, struct sk_buff *src) { - skb_release_all(dst); + skb_release_all(dst, SKB_CONSUMED); return __skb_clone(dst, src); } EXPORT_SYMBOL_GPL(skb_morph); @@ -1256,13 +1321,12 @@ int mm_account_pinned_pages(struct mmpin *mmp, size_t size) max_pg = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; user = mmp->user ? : current_user(); + old_pg = atomic_long_read(&user->locked_vm); do { - old_pg = atomic_long_read(&user->locked_vm); new_pg = old_pg + num_pg; if (new_pg > max_pg) return -ENOBUFS; - } while (atomic_long_cmpxchg(&user->locked_vm, old_pg, new_pg) != - old_pg); + } while (!atomic_long_try_cmpxchg(&user->locked_vm, &old_pg, new_pg)); if (!mmp->user) { mmp->user = get_uid(user); @@ -1814,10 +1878,11 @@ EXPORT_SYMBOL(__pskb_copy_fclone); int pskb_expand_head(struct sk_buff *skb, int nhead, int ntail, gfp_t gfp_mask) { - int i, osize = skb_end_offset(skb); - int size = osize + nhead + ntail; + unsigned int osize = skb_end_offset(skb); + unsigned int size = osize + nhead + ntail; long off; u8 *data; + int i; BUG_ON(nhead < 0); @@ -1825,15 +1890,16 @@ int pskb_expand_head(struct sk_buff *skb, int nhead, int ntail, skb_zcopy_downgrade_managed(skb); - size = SKB_DATA_ALIGN(size); - if (skb_pfmemalloc(skb)) gfp_mask |= __GFP_MEMALLOC; - data = kmalloc_reserve(size + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)), - gfp_mask, NUMA_NO_NODE, NULL); + + size = SKB_DATA_ALIGN(size); + size += SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + size = kmalloc_size_roundup(size); + data = kmalloc_reserve(size, gfp_mask, NUMA_NO_NODE, NULL); if (!data) goto nodata; - size = SKB_WITH_OVERHEAD(ksize(data)); + size = SKB_WITH_OVERHEAD(size); /* Copy only real data... and, alas, header. This should be * optimized for the cases when header is void. @@ -1860,7 +1926,7 @@ int pskb_expand_head(struct sk_buff *skb, int nhead, int ntail, if (skb_has_frag_list(skb)) skb_clone_fraglist(skb); - skb_release_data(skb); + skb_release_data(skb, SKB_CONSUMED); } else { skb_free_head(skb); } @@ -2416,6 +2482,9 @@ void *__pskb_pull_tail(struct sk_buff *skb, int delta) insp = list; } else { /* Eaten partially. */ + if (skb_is_gso(skb) && !list->head_frag && + skb_headlen(list)) + skb_shinfo(skb)->gso_type |= SKB_GSO_DODGY; if (skb_shared(list)) { /* Sucks! We need to fork list. :-( */ @@ -4031,7 +4100,7 @@ struct sk_buff *skb_segment_list(struct sk_buff *skb, skb_shinfo(skb)->frag_list = NULL; - do { + while (list_skb) { nskb = list_skb; list_skb = list_skb->next; @@ -4077,8 +4146,7 @@ struct sk_buff *skb_segment_list(struct sk_buff *skb, if (skb_needs_linearize(nskb, features) && __skb_linearize(nskb)) goto err_linearize; - - } while (list_skb); + } skb->truesize = skb->truesize - delta_truesize; skb->data_len = skb->data_len - delta_len; @@ -6169,21 +6237,20 @@ static int pskb_carve_inside_header(struct sk_buff *skb, const u32 off, const int headlen, gfp_t gfp_mask) { int i; - int size = skb_end_offset(skb); + unsigned int size = skb_end_offset(skb); int new_hlen = headlen - off; u8 *data; - size = SKB_DATA_ALIGN(size); - if (skb_pfmemalloc(skb)) gfp_mask |= __GFP_MEMALLOC; - data = kmalloc_reserve(size + - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)), - gfp_mask, NUMA_NO_NODE, NULL); + + size = SKB_DATA_ALIGN(size); + size += SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + size = kmalloc_size_roundup(size); + data = kmalloc_reserve(size, gfp_mask, NUMA_NO_NODE, NULL); if (!data) return -ENOMEM; - - size = SKB_WITH_OVERHEAD(ksize(data)); + size = SKB_WITH_OVERHEAD(size); /* Copy real data, and all frags */ skb_copy_from_linear_data_offset(skb, off, data, new_hlen); @@ -6203,7 +6270,7 @@ static int pskb_carve_inside_header(struct sk_buff *skb, const u32 off, skb_frag_ref(skb, i); if (skb_has_frag_list(skb)) skb_clone_fraglist(skb); - skb_release_data(skb); + skb_release_data(skb, SKB_CONSUMED); } else { /* we can reuse existing recount- all we did was * relocate values @@ -6288,22 +6355,21 @@ static int pskb_carve_inside_nonlinear(struct sk_buff *skb, const u32 off, int pos, gfp_t gfp_mask) { int i, k = 0; - int size = skb_end_offset(skb); + unsigned int size = skb_end_offset(skb); u8 *data; const int nfrags = skb_shinfo(skb)->nr_frags; struct skb_shared_info *shinfo; - size = SKB_DATA_ALIGN(size); - if (skb_pfmemalloc(skb)) gfp_mask |= __GFP_MEMALLOC; - data = kmalloc_reserve(size + - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)), - gfp_mask, NUMA_NO_NODE, NULL); + + size = SKB_DATA_ALIGN(size); + size += SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + size = kmalloc_size_roundup(size); + data = kmalloc_reserve(size, gfp_mask, NUMA_NO_NODE, NULL); if (!data) return -ENOMEM; - - size = SKB_WITH_OVERHEAD(ksize(data)); + size = SKB_WITH_OVERHEAD(size); memcpy((struct skb_shared_info *)(data + size), skb_shinfo(skb), offsetof(struct skb_shared_info, frags[0])); @@ -6347,7 +6413,7 @@ static int pskb_carve_inside_nonlinear(struct sk_buff *skb, const u32 off, kfree(data); return -ENOMEM; } - skb_release_data(skb); + skb_release_data(skb, SKB_CONSUMED); skb->head = data; skb->head_frag = 0; @@ -6426,6 +6492,7 @@ void skb_condense(struct sk_buff *skb) */ skb->truesize = SKB_TRUESIZE(skb_end_offset(skb)); } +EXPORT_SYMBOL(skb_condense); #ifdef CONFIG_SKB_EXTENSIONS static void *skb_ext_get_ptr(struct skb_ext *ext, enum skb_ext_id id) diff --git a/net/core/skmsg.c b/net/core/skmsg.c index e6b9ced3eda8..53d0251788aa 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -886,13 +886,16 @@ int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, ret = sk_psock_map_verd(ret, msg->sk_redir); psock->apply_bytes = msg->apply_bytes; if (ret == __SK_REDIRECT) { - if (psock->sk_redir) + if (psock->sk_redir) { sock_put(psock->sk_redir); - psock->sk_redir = msg->sk_redir; - if (!psock->sk_redir) { + psock->sk_redir = NULL; + } + if (!msg->sk_redir) { ret = __SK_DROP; goto out; } + psock->redir_ingress = sk_msg_to_ingress(msg); + psock->sk_redir = msg->sk_redir; sock_hold(psock->sk_redir); } out: diff --git a/net/core/sock.c b/net/core/sock.c index a3ba0358c77c..f954d5893e79 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -901,13 +901,20 @@ int sock_set_timestamping(struct sock *sk, int optname, if (val & ~SOF_TIMESTAMPING_MASK) return -EINVAL; + if (val & SOF_TIMESTAMPING_OPT_ID_TCP && + !(val & SOF_TIMESTAMPING_OPT_ID)) + return -EINVAL; + if (val & SOF_TIMESTAMPING_OPT_ID && !(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID)) { if (sk_is_tcp(sk)) { if ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) return -EINVAL; - atomic_set(&sk->sk_tskey, tcp_sk(sk)->snd_una); + if (val & SOF_TIMESTAMPING_OPT_ID_TCP) + atomic_set(&sk->sk_tskey, tcp_sk(sk)->write_seq); + else + atomic_set(&sk->sk_tskey, tcp_sk(sk)->snd_una); } else { atomic_set(&sk->sk_tskey, 0); } @@ -1436,7 +1443,7 @@ set_sndbuf: break; } case SO_INCOMING_CPU: - WRITE_ONCE(sk->sk_incoming_cpu, val); + reuseport_update_incoming_cpu(sk, val); break; case SO_CNX_ADVICE: @@ -1793,7 +1800,8 @@ int sk_getsockopt(struct sock *sk, int level, int optname, break; case SO_PEERSEC: - return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len); + return security_socket_getpeersec_stream(sock, + optval, optlen, len); case SO_MARK: v.val = sk->sk_mark; @@ -2094,6 +2102,9 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority, if (likely(sk->sk_net_refcnt)) { get_net_track(net, &sk->ns_tracker, priority); sock_inuse_add(net, 1); + } else { + __netns_tracker_alloc(net, &sk->ns_tracker, + false, priority); } sock_net_set(sk, net); @@ -2149,6 +2160,9 @@ static void __sk_destruct(struct rcu_head *head) if (likely(sk->sk_net_refcnt)) put_net_track(sock_net(sk), &sk->ns_tracker); + else + __netns_tracker_free(sock_net(sk), &sk->ns_tracker, false); + sk_prot_free(sk->sk_prot_creator, sk); } @@ -2237,6 +2251,14 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) if (likely(newsk->sk_net_refcnt)) { get_net_track(sock_net(newsk), &newsk->ns_tracker, priority); sock_inuse_add(sock_net(newsk), 1); + } else { + /* Kernel sockets are not elevating the struct net refcount. + * Instead, use a tracker to more easily detect if a layer + * is not properly dismantling its kernel sockets at netns + * destroy time. + */ + __netns_tracker_alloc(sock_net(newsk), &newsk->ns_tracker, + false, priority); } sk_node_init(&newsk->sk_node); sock_lock_init(newsk); @@ -2730,7 +2752,7 @@ failure: } EXPORT_SYMBOL(sock_alloc_send_pskb); -int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg, +int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg, struct sockcm_cookie *sockc) { u32 tsflags; @@ -2784,7 +2806,7 @@ int sock_cmsg_send(struct sock *sk, struct msghdr *msg, return -EINVAL; if (cmsg->cmsg_level != SOL_SOCKET) continue; - ret = __sock_cmsg_send(sk, msg, cmsg, sockc); + ret = __sock_cmsg_send(sk, cmsg, sockc); if (ret) return ret; } @@ -3368,6 +3390,7 @@ void sock_init_data(struct socket *sock, struct sock *sk) sk->sk_rcvbuf = READ_ONCE(sysctl_rmem_default); sk->sk_sndbuf = READ_ONCE(sysctl_wmem_default); sk->sk_state = TCP_CLOSE; + sk->sk_use_task_frag = true; sk_set_socket(sk, sock); sock_set_flag(sk, SOCK_ZAPPED); diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c index f7cf74cdd3db..b1e29e18d1d6 100644 --- a/net/core/sock_diag.c +++ b/net/core/sock_diag.c @@ -25,14 +25,17 @@ DEFINE_COOKIE(sock_cookie); u64 __sock_gen_cookie(struct sock *sk) { - while (1) { - u64 res = atomic64_read(&sk->sk_cookie); + u64 res = atomic64_read(&sk->sk_cookie); - if (res) - return res; - res = gen_cookie_next(&sock_cookie); - atomic64_cmpxchg(&sk->sk_cookie, 0, res); + if (!res) { + u64 new = gen_cookie_next(&sock_cookie); + + atomic64_cmpxchg(&sk->sk_cookie, res, new); + + /* Another thread might have changed sk_cookie before us. */ + res = atomic64_read(&sk->sk_cookie); } + return res; } int sock_diag_check_cookie(struct sock *sk, const __u32 *cookie) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 81beb16ab1eb..a68a7290a3b2 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -349,11 +349,13 @@ static void sock_map_free(struct bpf_map *map) sk = xchg(psk, NULL); if (sk) { + sock_hold(sk); lock_sock(sk); rcu_read_lock(); sock_map_unref(sk, psk); rcu_read_unlock(); release_sock(sk); + sock_put(sk); } } @@ -1567,15 +1569,16 @@ void sock_map_unhash(struct sock *sk) psock = sk_psock(sk); if (unlikely(!psock)) { rcu_read_unlock(); - if (sk->sk_prot->unhash) - sk->sk_prot->unhash(sk); - return; + saved_unhash = READ_ONCE(sk->sk_prot)->unhash; + } else { + saved_unhash = psock->saved_unhash; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); } - - saved_unhash = psock->saved_unhash; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - saved_unhash(sk); + if (WARN_ON_ONCE(saved_unhash == sock_map_unhash)) + return; + if (saved_unhash) + saved_unhash(sk); } EXPORT_SYMBOL_GPL(sock_map_unhash); @@ -1588,17 +1591,18 @@ void sock_map_destroy(struct sock *sk) psock = sk_psock_get(sk); if (unlikely(!psock)) { rcu_read_unlock(); - if (sk->sk_prot->destroy) - sk->sk_prot->destroy(sk); - return; + saved_destroy = READ_ONCE(sk->sk_prot)->destroy; + } else { + saved_destroy = psock->saved_destroy; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock); + sk_psock_put(sk, psock); } - - saved_destroy = psock->saved_destroy; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - sk_psock_stop(psock); - sk_psock_put(sk, psock); - saved_destroy(sk); + if (WARN_ON_ONCE(saved_destroy == sock_map_destroy)) + return; + if (saved_destroy) + saved_destroy(sk); } EXPORT_SYMBOL_GPL(sock_map_destroy); @@ -1613,16 +1617,21 @@ void sock_map_close(struct sock *sk, long timeout) if (unlikely(!psock)) { rcu_read_unlock(); release_sock(sk); - return sk->sk_prot->close(sk, timeout); + saved_close = READ_ONCE(sk->sk_prot)->close; + } else { + saved_close = psock->saved_close; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock); + release_sock(sk); + cancel_work_sync(&psock->work); + sk_psock_put(sk, psock); } - - saved_close = psock->saved_close; - sock_map_remove_links(sk, psock); - rcu_read_unlock(); - sk_psock_stop(psock); - release_sock(sk); - cancel_work_sync(&psock->work); - sk_psock_put(sk, psock); + /* Make sure we do not recurse. This is a bug. + * Leak the socket instead of crashing on a stack overflow. + */ + if (WARN_ON_ONCE(saved_close == sock_map_close)) + return; saved_close(sk, timeout); } EXPORT_SYMBOL_GPL(sock_map_close); diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c index fb90e1e00773..5a165286e4d8 100644 --- a/net/core/sock_reuseport.c +++ b/net/core/sock_reuseport.c @@ -37,6 +37,70 @@ void reuseport_has_conns_set(struct sock *sk) } EXPORT_SYMBOL(reuseport_has_conns_set); +static void __reuseport_get_incoming_cpu(struct sock_reuseport *reuse) +{ + /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). */ + WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu + 1); +} + +static void __reuseport_put_incoming_cpu(struct sock_reuseport *reuse) +{ + /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). */ + WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu - 1); +} + +static void reuseport_get_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse) +{ + if (sk->sk_incoming_cpu >= 0) + __reuseport_get_incoming_cpu(reuse); +} + +static void reuseport_put_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse) +{ + if (sk->sk_incoming_cpu >= 0) + __reuseport_put_incoming_cpu(reuse); +} + +void reuseport_update_incoming_cpu(struct sock *sk, int val) +{ + struct sock_reuseport *reuse; + int old_sk_incoming_cpu; + + if (unlikely(!rcu_access_pointer(sk->sk_reuseport_cb))) { + /* Paired with REAE_ONCE() in sk_incoming_cpu_update() + * and compute_score(). + */ + WRITE_ONCE(sk->sk_incoming_cpu, val); + return; + } + + spin_lock_bh(&reuseport_lock); + + /* This must be done under reuseport_lock to avoid a race with + * reuseport_grow(), which accesses sk->sk_incoming_cpu without + * lock_sock() when detaching a shutdown()ed sk. + * + * Paired with READ_ONCE() in reuseport_select_sock_by_hash(). + */ + old_sk_incoming_cpu = sk->sk_incoming_cpu; + WRITE_ONCE(sk->sk_incoming_cpu, val); + + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + + /* reuseport_grow() has detached a closed sk. */ + if (!reuse) + goto out; + + if (old_sk_incoming_cpu < 0 && val >= 0) + __reuseport_get_incoming_cpu(reuse); + else if (old_sk_incoming_cpu >= 0 && val < 0) + __reuseport_put_incoming_cpu(reuse); + +out: + spin_unlock_bh(&reuseport_lock); +} + static int reuseport_sock_index(struct sock *sk, const struct sock_reuseport *reuse, bool closed) @@ -64,6 +128,7 @@ static void __reuseport_add_sock(struct sock *sk, /* paired with smp_rmb() in reuseport_(select|migrate)_sock() */ smp_wmb(); reuse->num_socks++; + reuseport_get_incoming_cpu(sk, reuse); } static bool __reuseport_detach_sock(struct sock *sk, @@ -76,6 +141,7 @@ static bool __reuseport_detach_sock(struct sock *sk, reuse->socks[i] = reuse->socks[reuse->num_socks - 1]; reuse->num_socks--; + reuseport_put_incoming_cpu(sk, reuse); return true; } @@ -86,6 +152,7 @@ static void __reuseport_add_closed_sock(struct sock *sk, reuse->socks[reuse->max_socks - reuse->num_closed_socks - 1] = sk; /* paired with READ_ONCE() in inet_csk_bind_conflict() */ WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks + 1); + reuseport_get_incoming_cpu(sk, reuse); } static bool __reuseport_detach_closed_sock(struct sock *sk, @@ -99,6 +166,7 @@ static bool __reuseport_detach_closed_sock(struct sock *sk, reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks]; /* paired with READ_ONCE() in inet_csk_bind_conflict() */ WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks - 1); + reuseport_put_incoming_cpu(sk, reuse); return true; } @@ -166,6 +234,7 @@ int reuseport_alloc(struct sock *sk, bool bind_inany) reuse->bind_inany = bind_inany; reuse->socks[0] = sk; reuse->num_socks = 1; + reuseport_get_incoming_cpu(sk, reuse); rcu_assign_pointer(sk->sk_reuseport_cb, reuse); out: @@ -209,6 +278,7 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse) more_reuse->reuseport_id = reuse->reuseport_id; more_reuse->bind_inany = reuse->bind_inany; more_reuse->has_conns = reuse->has_conns; + more_reuse->incoming_cpu = reuse->incoming_cpu; memcpy(more_reuse->socks, reuse->socks, reuse->num_socks * sizeof(struct sock *)); @@ -458,18 +528,32 @@ static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks, static struct sock *reuseport_select_sock_by_hash(struct sock_reuseport *reuse, u32 hash, u16 num_socks) { + struct sock *first_valid_sk = NULL; int i, j; i = j = reciprocal_scale(hash, num_socks); - while (reuse->socks[i]->sk_state == TCP_ESTABLISHED) { + do { + struct sock *sk = reuse->socks[i]; + + if (sk->sk_state != TCP_ESTABLISHED) { + /* Paired with WRITE_ONCE() in __reuseport_(get|put)_incoming_cpu(). */ + if (!READ_ONCE(reuse->incoming_cpu)) + return sk; + + /* Paired with WRITE_ONCE() in reuseport_update_incoming_cpu(). */ + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + return sk; + + if (!first_valid_sk) + first_valid_sk = sk; + } + i++; if (i >= num_socks) i = 0; - if (i == j) - return NULL; - } + } while (i != j); - return reuse->socks[i]; + return first_valid_sk; } /** diff --git a/net/core/stream.c b/net/core/stream.c index 75fded8495f5..cd06750dd329 100644 --- a/net/core/stream.c +++ b/net/core/stream.c @@ -123,7 +123,7 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p) DEFINE_WAIT_FUNC(wait, woken_wake_function); if (sk_stream_memory_free(sk)) - current_timeo = vm_wait = prandom_u32_max(HZ / 5) + 2; + current_timeo = vm_wait = get_random_u32_below(HZ / 5) + 2; add_wait_queue(sk_sleep(sk), &wait); @@ -196,6 +196,12 @@ void sk_stream_kill_queues(struct sock *sk) /* First the read buffer. */ __skb_queue_purge(&sk->sk_receive_queue); + /* Next, the error queue. + * We need to use queue lock, because other threads might + * add packets to the queue without socket lock being held. + */ + skb_queue_purge(&sk->sk_error_queue); + /* Next, the write queue. */ WARN_ON_ONCE(!skb_queue_empty(&sk->sk_write_queue)); diff --git a/net/core/tso.c b/net/core/tso.c index 4148f6d48953..e00796e3b146 100644 --- a/net/core/tso.c +++ b/net/core/tso.c @@ -5,14 +5,6 @@ #include <net/tso.h> #include <asm/unaligned.h> -/* Calculate expected number of TX descriptors */ -int tso_count_descs(const struct sk_buff *skb) -{ - /* The Marvell Way */ - return skb_shinfo(skb)->gso_segs * 2 + skb_shinfo(skb)->nr_frags; -} -EXPORT_SYMBOL(tso_count_descs); - void tso_build_hdr(const struct sk_buff *skb, char *hdr, struct tso_t *tso, int size, bool is_last) { diff --git a/net/core/utils.c b/net/core/utils.c index 938495bc1d34..c994e95172ac 100644 --- a/net/core/utils.c +++ b/net/core/utils.c @@ -302,7 +302,7 @@ static int inet4_pton(const char *src, u16 port_num, struct sockaddr_storage *addr) { struct sockaddr_in *addr4 = (struct sockaddr_in *)addr; - int srclen = strlen(src); + size_t srclen = strlen(src); if (srclen > INET_ADDRSTRLEN) return -EINVAL; @@ -322,7 +322,7 @@ static int inet6_pton(struct net *net, const char *src, u16 port_num, { struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)addr; const char *scope_delim; - int srclen = strlen(src); + size_t srclen = strlen(src); if (srclen > INET6_ADDRSTRLEN) return -EINVAL; diff --git a/net/dcb/dcbnl.c b/net/dcb/dcbnl.c index dc4fb699b56c..f9949e051f49 100644 --- a/net/dcb/dcbnl.c +++ b/net/dcb/dcbnl.c @@ -166,6 +166,7 @@ static const struct nla_policy dcbnl_ieee_policy[DCB_ATTR_IEEE_MAX + 1] = { [DCB_ATTR_IEEE_QCN] = {.len = sizeof(struct ieee_qcn)}, [DCB_ATTR_IEEE_QCN_STATS] = {.len = sizeof(struct ieee_qcn_stats)}, [DCB_ATTR_DCB_BUFFER] = {.len = sizeof(struct dcbnl_buffer)}, + [DCB_ATTR_DCB_APP_TRUST_TABLE] = {.type = NLA_NESTED}, }; /* DCB number of traffic classes nested attributes. */ @@ -179,6 +180,38 @@ static const struct nla_policy dcbnl_featcfg_nest[DCB_FEATCFG_ATTR_MAX + 1] = { static LIST_HEAD(dcb_app_list); static DEFINE_SPINLOCK(dcb_lock); +static enum ieee_attrs_app dcbnl_app_attr_type_get(u8 selector) +{ + switch (selector) { + case IEEE_8021QAZ_APP_SEL_ETHERTYPE: + case IEEE_8021QAZ_APP_SEL_STREAM: + case IEEE_8021QAZ_APP_SEL_DGRAM: + case IEEE_8021QAZ_APP_SEL_ANY: + case IEEE_8021QAZ_APP_SEL_DSCP: + return DCB_ATTR_IEEE_APP; + case DCB_APP_SEL_PCP: + return DCB_ATTR_DCB_APP; + default: + return DCB_ATTR_IEEE_APP_UNSPEC; + } +} + +static bool dcbnl_app_attr_type_validate(enum ieee_attrs_app type) +{ + switch (type) { + case DCB_ATTR_IEEE_APP: + case DCB_ATTR_DCB_APP: + return true; + default: + return false; + } +} + +static bool dcbnl_app_selector_validate(enum ieee_attrs_app type, u8 selector) +{ + return dcbnl_app_attr_type_get(selector) == type; +} + static struct sk_buff *dcbnl_newmsg(int type, u8 cmd, u32 port, u32 seq, u32 flags, struct nlmsghdr **nlhp) { @@ -1027,12 +1060,51 @@ nla_put_failure: return err; } +static int dcbnl_getapptrust(struct net_device *netdev, struct sk_buff *skb) +{ + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + enum ieee_attrs_app type; + struct nlattr *apptrust; + int nselectors, err, i; + u8 *selectors; + + selectors = kzalloc(IEEE_8021QAZ_APP_SEL_MAX + 1, GFP_KERNEL); + if (!selectors) + return -ENOMEM; + + err = ops->dcbnl_getapptrust(netdev, selectors, &nselectors); + if (err) { + err = 0; + goto out; + } + + apptrust = nla_nest_start(skb, DCB_ATTR_DCB_APP_TRUST_TABLE); + if (!apptrust) { + err = -EMSGSIZE; + goto out; + } + + for (i = 0; i < nselectors; i++) { + type = dcbnl_app_attr_type_get(selectors[i]); + err = nla_put_u8(skb, type, selectors[i]); + if (err) { + nla_nest_cancel(skb, apptrust); + goto out; + } + } + nla_nest_end(skb, apptrust); + +out: + kfree(selectors); + return err; +} + /* Handle IEEE 802.1Qaz/802.1Qau/802.1Qbb GET commands. */ static int dcbnl_ieee_fill(struct sk_buff *skb, struct net_device *netdev) { + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; struct nlattr *ieee, *app; struct dcb_app_type *itr; - const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; int dcbx; int err; @@ -1116,8 +1188,9 @@ static int dcbnl_ieee_fill(struct sk_buff *skb, struct net_device *netdev) spin_lock_bh(&dcb_lock); list_for_each_entry(itr, &dcb_app_list, list) { if (itr->ifindex == netdev->ifindex) { - err = nla_put(skb, DCB_ATTR_IEEE_APP, sizeof(itr->app), - &itr->app); + enum ieee_attrs_app type = + dcbnl_app_attr_type_get(itr->app.selector); + err = nla_put(skb, type, sizeof(itr->app), &itr->app); if (err) { spin_unlock_bh(&dcb_lock); return -EMSGSIZE; @@ -1133,6 +1206,12 @@ static int dcbnl_ieee_fill(struct sk_buff *skb, struct net_device *netdev) spin_unlock_bh(&dcb_lock); nla_nest_end(skb, app); + if (ops->dcbnl_getapptrust) { + err = dcbnl_getapptrust(netdev, skb); + if (err) + return err; + } + /* get peer info if available */ if (ops->ieee_peer_getets) { struct ieee_ets ets; @@ -1493,9 +1572,10 @@ static int dcbnl_ieee_set(struct net_device *netdev, struct nlmsghdr *nlh, int rem; nla_for_each_nested(attr, ieee[DCB_ATTR_IEEE_APP_TABLE], rem) { + enum ieee_attrs_app type = nla_type(attr); struct dcb_app *app_data; - if (nla_type(attr) != DCB_ATTR_IEEE_APP) + if (!dcbnl_app_attr_type_validate(type)) continue; if (nla_len(attr) < sizeof(struct dcb_app)) { @@ -1504,6 +1584,13 @@ static int dcbnl_ieee_set(struct net_device *netdev, struct nlmsghdr *nlh, } app_data = nla_data(attr); + + if (!dcbnl_app_selector_validate(type, + app_data->selector)) { + err = -EINVAL; + goto err; + } + if (ops->ieee_setapp) err = ops->ieee_setapp(netdev, app_data); else @@ -1513,6 +1600,53 @@ static int dcbnl_ieee_set(struct net_device *netdev, struct nlmsghdr *nlh, } } + if (ieee[DCB_ATTR_DCB_APP_TRUST_TABLE]) { + u8 selectors[IEEE_8021QAZ_APP_SEL_MAX + 1] = {0}; + struct nlattr *attr; + int nselectors = 0; + int rem; + + if (!ops->dcbnl_setapptrust) { + err = -EOPNOTSUPP; + goto err; + } + + nla_for_each_nested(attr, ieee[DCB_ATTR_DCB_APP_TRUST_TABLE], + rem) { + enum ieee_attrs_app type = nla_type(attr); + u8 selector; + int i; + + if (!dcbnl_app_attr_type_validate(type) || + nla_len(attr) != 1 || + nselectors >= sizeof(selectors)) { + err = -EINVAL; + goto err; + } + + selector = nla_get_u8(attr); + + if (!dcbnl_app_selector_validate(type, selector)) { + err = -EINVAL; + goto err; + } + + /* Duplicate selector ? */ + for (i = 0; i < nselectors; i++) { + if (selectors[i] == selector) { + err = -EINVAL; + goto err; + } + } + + selectors[nselectors++] = selector; + } + + err = ops->dcbnl_setapptrust(netdev, selectors, nselectors); + if (err) + goto err; + } + err: err = nla_put_u8(skb, DCB_ATTR_IEEE, err); dcbnl_ieee_notify(netdev, RTM_SETDCB, DCB_CMD_IEEE_SET, seq, 0); @@ -1554,11 +1688,20 @@ static int dcbnl_ieee_del(struct net_device *netdev, struct nlmsghdr *nlh, int rem; nla_for_each_nested(attr, ieee[DCB_ATTR_IEEE_APP_TABLE], rem) { + enum ieee_attrs_app type = nla_type(attr); struct dcb_app *app_data; - if (nla_type(attr) != DCB_ATTR_IEEE_APP) + if (!dcbnl_app_attr_type_validate(type)) continue; + app_data = nla_data(attr); + + if (!dcbnl_app_selector_validate(type, + app_data->selector)) { + err = -EINVAL; + goto err; + } + if (ops->ieee_delapp) err = ops->ieee_delapp(netdev, app_data); else diff --git a/net/dccp/dccp.h b/net/dccp/dccp.h index 7dfc00c9fb32..9ddc3a9e89e4 100644 --- a/net/dccp/dccp.h +++ b/net/dccp/dccp.h @@ -278,6 +278,7 @@ int dccp_rcv_state_process(struct sock *sk, struct sk_buff *skb, int dccp_rcv_established(struct sock *sk, struct sk_buff *skb, const struct dccp_hdr *dh, const unsigned int len); +void dccp_destruct_common(struct sock *sk); int dccp_init_sock(struct sock *sk, const __u8 ctl_sock_initialized); void dccp_destroy_sock(struct sock *sk); diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c index 713b7b8dad7e..b780827f5e0a 100644 --- a/net/dccp/ipv4.c +++ b/net/dccp/ipv4.c @@ -45,11 +45,10 @@ static unsigned int dccp_v4_pernet_id __read_mostly; int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; - struct inet_bind_hashbucket *prev_addr_hashbucket = NULL; - __be32 daddr, nexthop, prev_sk_rcv_saddr; struct inet_sock *inet = inet_sk(sk); struct dccp_sock *dp = dccp_sk(sk); __be16 orig_sport, orig_dport; + __be32 daddr, nexthop; struct flowi4 *fl4; struct rtable *rt; int err; @@ -91,26 +90,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) daddr = fl4->daddr; if (inet->inet_saddr == 0) { - if (inet_csk(sk)->icsk_bind2_hash) { - prev_addr_hashbucket = - inet_bhashfn_portaddr(&dccp_hashinfo, sk, - sock_net(sk), - inet->inet_num); - prev_sk_rcv_saddr = sk->sk_rcv_saddr; - } - inet->inet_saddr = fl4->saddr; - } - - sk_rcv_saddr_set(sk, inet->inet_saddr); - - if (prev_addr_hashbucket) { - err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); + err = inet_bhash2_update_saddr(sk, &fl4->saddr, AF_INET); if (err) { - inet->inet_saddr = 0; - sk_rcv_saddr_set(sk, prev_sk_rcv_saddr); ip_rt_put(rt); return err; } + } else { + sk_rcv_saddr_set(sk, inet->inet_saddr); } inet->inet_dport = usin->sin_port; @@ -157,6 +143,7 @@ failure: * This unhashes the socket and releases the local port, if necessary. */ dccp_set_state(sk, DCCP_CLOSED); + inet_bhash2_reset_saddr(sk); ip_rt_put(rt); sk->sk_route_caps = 0; inet->inet_dport = 0; diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c index e57b43006074..4260fe466993 100644 --- a/net/dccp/ipv6.c +++ b/net/dccp/ipv6.c @@ -934,26 +934,11 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr, } if (saddr == NULL) { - struct inet_bind_hashbucket *prev_addr_hashbucket = NULL; - struct in6_addr prev_v6_rcv_saddr; - - if (icsk->icsk_bind2_hash) { - prev_addr_hashbucket = inet_bhashfn_portaddr(&dccp_hashinfo, - sk, sock_net(sk), - inet->inet_num); - prev_v6_rcv_saddr = sk->sk_v6_rcv_saddr; - } - saddr = &fl6.saddr; - sk->sk_v6_rcv_saddr = *saddr; - - if (prev_addr_hashbucket) { - err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); - if (err) { - sk->sk_v6_rcv_saddr = prev_v6_rcv_saddr; - goto failure; - } - } + + err = inet_bhash2_update_saddr(sk, saddr, AF_INET6); + if (err) + goto failure; } /* set the source address */ @@ -985,6 +970,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr, late_failure: dccp_set_state(sk, DCCP_CLOSED); + inet_bhash2_reset_saddr(sk); __sk_dst_reset(sk); failure: inet->inet_dport = 0; @@ -1021,6 +1007,12 @@ static const struct inet_connection_sock_af_ops dccp_ipv6_mapped = { .sockaddr_len = sizeof(struct sockaddr_in6), }; +static void dccp_v6_sk_destruct(struct sock *sk) +{ + dccp_destruct_common(sk); + inet6_sock_destruct(sk); +} + /* NOTE: A lot of things set to zero explicitly by call to * sk_alloc() so need not be done here. */ @@ -1033,17 +1025,12 @@ static int dccp_v6_init_sock(struct sock *sk) if (unlikely(!dccp_v6_ctl_sock_initialized)) dccp_v6_ctl_sock_initialized = 1; inet_csk(sk)->icsk_af_ops = &dccp_ipv6_af_ops; + sk->sk_destruct = dccp_v6_sk_destruct; } return err; } -static void dccp_v6_destroy_sock(struct sock *sk) -{ - dccp_destroy_sock(sk); - inet6_destroy_sock(sk); -} - static struct timewait_sock_ops dccp6_timewait_sock_ops = { .twsk_obj_size = sizeof(struct dccp6_timewait_sock), }; @@ -1066,7 +1053,7 @@ static struct proto dccp_v6_prot = { .accept = inet_csk_accept, .get_port = inet_csk_get_port, .shutdown = dccp_shutdown, - .destroy = dccp_v6_destroy_sock, + .destroy = dccp_destroy_sock, .orphan_count = &dccp_orphan_count, .max_header = MAX_DCCP_HEADER, .obj_size = sizeof(struct dccp6_sock), diff --git a/net/dccp/proto.c b/net/dccp/proto.c index c548ca3e9b0e..a06b5641287a 100644 --- a/net/dccp/proto.c +++ b/net/dccp/proto.c @@ -171,12 +171,18 @@ const char *dccp_packet_name(const int type) EXPORT_SYMBOL_GPL(dccp_packet_name); -static void dccp_sk_destruct(struct sock *sk) +void dccp_destruct_common(struct sock *sk) { struct dccp_sock *dp = dccp_sk(sk); ccid_hc_tx_delete(dp->dccps_hc_tx_ccid, sk); dp->dccps_hc_tx_ccid = NULL; +} +EXPORT_SYMBOL_GPL(dccp_destruct_common); + +static void dccp_sk_destruct(struct sock *sk) +{ + dccp_destruct_common(sk); inet_sock_destruct(sk); } @@ -279,8 +285,7 @@ int dccp_disconnect(struct sock *sk, int flags) inet->inet_dport = 0; - if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) - inet_reset_saddr(sk); + inet_bhash2_reset_saddr(sk); sk->sk_shutdown = 0; sock_reset_flag(sk, SOCK_DONE); diff --git a/net/dns_resolver/dns_key.c b/net/dns_resolver/dns_key.c index 3aced951d5ab..01e54b46ae0b 100644 --- a/net/dns_resolver/dns_key.c +++ b/net/dns_resolver/dns_key.c @@ -337,7 +337,7 @@ static int __init init_dns_resolver(void) * this is used to prevent malicious redirections from being installed * with add_key(). */ - cred = prepare_kernel_cred(NULL); + cred = prepare_kernel_cred(&init_task); if (!cred) return -ENOMEM; diff --git a/net/dsa/Kconfig b/net/dsa/Kconfig index 3eef72ce99a4..8e698bea99a3 100644 --- a/net/dsa/Kconfig +++ b/net/dsa/Kconfig @@ -18,6 +18,12 @@ if NET_DSA # Drivers must select the appropriate tagging format(s) +config NET_DSA_TAG_NONE + tristate "No-op tag driver" + help + Say Y or M if you want to enable support for switches which don't tag + frames over the CPU port. + config NET_DSA_TAG_AR9331 tristate "Tag driver for Atheros AR9331 SoC with built-in switch" help diff --git a/net/dsa/Makefile b/net/dsa/Makefile index bf57ef3bce2a..cc7e93a562fe 100644 --- a/net/dsa/Makefile +++ b/net/dsa/Makefile @@ -2,13 +2,14 @@ # the core obj-$(CONFIG_NET_DSA) += dsa_core.o dsa_core-y += \ + devlink.o \ dsa.o \ - dsa2.o \ master.o \ netlink.o \ port.o \ slave.o \ switch.o \ + tag.o \ tag_8021q.o # tagging formats @@ -20,6 +21,7 @@ obj-$(CONFIG_NET_DSA_TAG_HELLCREEK) += tag_hellcreek.o obj-$(CONFIG_NET_DSA_TAG_KSZ) += tag_ksz.o obj-$(CONFIG_NET_DSA_TAG_LAN9303) += tag_lan9303.o obj-$(CONFIG_NET_DSA_TAG_MTK) += tag_mtk.o +obj-$(CONFIG_NET_DSA_TAG_NONE) += tag_none.o obj-$(CONFIG_NET_DSA_TAG_OCELOT) += tag_ocelot.o obj-$(CONFIG_NET_DSA_TAG_OCELOT_8021Q) += tag_ocelot_8021q.o obj-$(CONFIG_NET_DSA_TAG_QCA) += tag_qca.o diff --git a/net/dsa/devlink.c b/net/dsa/devlink.c new file mode 100644 index 000000000000..431bf52290a1 --- /dev/null +++ b/net/dsa/devlink.c @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * DSA devlink handling + */ + +#include <net/dsa.h> +#include <net/devlink.h> + +#include "devlink.h" + +static int dsa_devlink_info_get(struct devlink *dl, + struct devlink_info_req *req, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (ds->ops->devlink_info_get) + return ds->ops->devlink_info_get(ds, req, extack); + + return -EOPNOTSUPP; +} + +static int dsa_devlink_sb_pool_get(struct devlink *dl, + unsigned int sb_index, u16 pool_index, + struct devlink_sb_pool_info *pool_info) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_pool_get(ds, sb_index, pool_index, + pool_info); +} + +static int dsa_devlink_sb_pool_set(struct devlink *dl, unsigned int sb_index, + u16 pool_index, u32 size, + enum devlink_sb_threshold_type threshold_type, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_pool_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_pool_set(ds, sb_index, pool_index, size, + threshold_type, extack); +} + +static int dsa_devlink_sb_port_pool_get(struct devlink_port *dlp, + unsigned int sb_index, u16 pool_index, + u32 *p_threshold) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_port_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_port_pool_get(ds, port, sb_index, + pool_index, p_threshold); +} + +static int dsa_devlink_sb_port_pool_set(struct devlink_port *dlp, + unsigned int sb_index, u16 pool_index, + u32 threshold, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_port_pool_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_port_pool_set(ds, port, sb_index, + pool_index, threshold, extack); +} + +static int +dsa_devlink_sb_tc_pool_bind_get(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u16 *p_pool_index, u32 *p_threshold) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_tc_pool_bind_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_tc_pool_bind_get(ds, port, sb_index, + tc_index, pool_type, + p_pool_index, p_threshold); +} + +static int +dsa_devlink_sb_tc_pool_bind_set(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u16 pool_index, u32 threshold, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_tc_pool_bind_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_tc_pool_bind_set(ds, port, sb_index, + tc_index, pool_type, + pool_index, threshold, + extack); +} + +static int dsa_devlink_sb_occ_snapshot(struct devlink *dl, + unsigned int sb_index) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_occ_snapshot) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_snapshot(ds, sb_index); +} + +static int dsa_devlink_sb_occ_max_clear(struct devlink *dl, + unsigned int sb_index) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_occ_max_clear) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_max_clear(ds, sb_index); +} + +static int dsa_devlink_sb_occ_port_pool_get(struct devlink_port *dlp, + unsigned int sb_index, + u16 pool_index, u32 *p_cur, + u32 *p_max) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_occ_port_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_port_pool_get(ds, port, sb_index, + pool_index, p_cur, p_max); +} + +static int +dsa_devlink_sb_occ_tc_port_bind_get(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u32 *p_cur, u32 *p_max) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_occ_tc_port_bind_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_tc_port_bind_get(ds, port, + sb_index, tc_index, + pool_type, p_cur, + p_max); +} + +static const struct devlink_ops dsa_devlink_ops = { + .info_get = dsa_devlink_info_get, + .sb_pool_get = dsa_devlink_sb_pool_get, + .sb_pool_set = dsa_devlink_sb_pool_set, + .sb_port_pool_get = dsa_devlink_sb_port_pool_get, + .sb_port_pool_set = dsa_devlink_sb_port_pool_set, + .sb_tc_pool_bind_get = dsa_devlink_sb_tc_pool_bind_get, + .sb_tc_pool_bind_set = dsa_devlink_sb_tc_pool_bind_set, + .sb_occ_snapshot = dsa_devlink_sb_occ_snapshot, + .sb_occ_max_clear = dsa_devlink_sb_occ_max_clear, + .sb_occ_port_pool_get = dsa_devlink_sb_occ_port_pool_get, + .sb_occ_tc_port_bind_get = dsa_devlink_sb_occ_tc_port_bind_get, +}; + +int dsa_devlink_param_get(struct devlink *dl, u32 id, + struct devlink_param_gset_ctx *ctx) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_param_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_param_get(ds, id, ctx); +} +EXPORT_SYMBOL_GPL(dsa_devlink_param_get); + +int dsa_devlink_param_set(struct devlink *dl, u32 id, + struct devlink_param_gset_ctx *ctx) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_param_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_param_set(ds, id, ctx); +} +EXPORT_SYMBOL_GPL(dsa_devlink_param_set); + +int dsa_devlink_params_register(struct dsa_switch *ds, + const struct devlink_param *params, + size_t params_count) +{ + return devlink_params_register(ds->devlink, params, params_count); +} +EXPORT_SYMBOL_GPL(dsa_devlink_params_register); + +void dsa_devlink_params_unregister(struct dsa_switch *ds, + const struct devlink_param *params, + size_t params_count) +{ + devlink_params_unregister(ds->devlink, params, params_count); +} +EXPORT_SYMBOL_GPL(dsa_devlink_params_unregister); + +int dsa_devlink_resource_register(struct dsa_switch *ds, + const char *resource_name, + u64 resource_size, + u64 resource_id, + u64 parent_resource_id, + const struct devlink_resource_size_params *size_params) +{ + return devlink_resource_register(ds->devlink, resource_name, + resource_size, resource_id, + parent_resource_id, + size_params); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_register); + +void dsa_devlink_resources_unregister(struct dsa_switch *ds) +{ + devlink_resources_unregister(ds->devlink); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resources_unregister); + +void dsa_devlink_resource_occ_get_register(struct dsa_switch *ds, + u64 resource_id, + devlink_resource_occ_get_t *occ_get, + void *occ_get_priv) +{ + return devlink_resource_occ_get_register(ds->devlink, resource_id, + occ_get, occ_get_priv); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_register); + +void dsa_devlink_resource_occ_get_unregister(struct dsa_switch *ds, + u64 resource_id) +{ + devlink_resource_occ_get_unregister(ds->devlink, resource_id); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_unregister); + +struct devlink_region * +dsa_devlink_region_create(struct dsa_switch *ds, + const struct devlink_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + return devlink_region_create(ds->devlink, ops, region_max_snapshots, + region_size); +} +EXPORT_SYMBOL_GPL(dsa_devlink_region_create); + +struct devlink_region * +dsa_devlink_port_region_create(struct dsa_switch *ds, + int port, + const struct devlink_port_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + + return devlink_port_region_create(&dp->devlink_port, ops, + region_max_snapshots, + region_size); +} +EXPORT_SYMBOL_GPL(dsa_devlink_port_region_create); + +void dsa_devlink_region_destroy(struct devlink_region *region) +{ + devlink_region_destroy(region); +} +EXPORT_SYMBOL_GPL(dsa_devlink_region_destroy); + +int dsa_port_devlink_setup(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + struct dsa_switch_tree *dst = dp->ds->dst; + struct devlink_port_attrs attrs = {}; + struct devlink *dl = dp->ds->devlink; + struct dsa_switch *ds = dp->ds; + const unsigned char *id; + unsigned char len; + int err; + + memset(dlp, 0, sizeof(*dlp)); + devlink_port_init(dl, dlp); + + if (ds->ops->port_setup) { + err = ds->ops->port_setup(ds, dp->index); + if (err) + return err; + } + + id = (const unsigned char *)&dst->index; + len = sizeof(dst->index); + + attrs.phys.port_number = dp->index; + memcpy(attrs.switch_id.id, id, len); + attrs.switch_id.id_len = len; + + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + attrs.flavour = DEVLINK_PORT_FLAVOUR_UNUSED; + break; + case DSA_PORT_TYPE_CPU: + attrs.flavour = DEVLINK_PORT_FLAVOUR_CPU; + break; + case DSA_PORT_TYPE_DSA: + attrs.flavour = DEVLINK_PORT_FLAVOUR_DSA; + break; + case DSA_PORT_TYPE_USER: + attrs.flavour = DEVLINK_PORT_FLAVOUR_PHYSICAL; + break; + } + + devlink_port_attrs_set(dlp, &attrs); + err = devlink_port_register(dl, dlp, dp->index); + if (err) { + if (ds->ops->port_teardown) + ds->ops->port_teardown(ds, dp->index); + return err; + } + + return 0; +} + +void dsa_port_devlink_teardown(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + struct dsa_switch *ds = dp->ds; + + devlink_port_unregister(dlp); + + if (ds->ops->port_teardown) + ds->ops->port_teardown(ds, dp->index); + + devlink_port_fini(dlp); +} + +void dsa_switch_devlink_register(struct dsa_switch *ds) +{ + devlink_register(ds->devlink); +} + +void dsa_switch_devlink_unregister(struct dsa_switch *ds) +{ + devlink_unregister(ds->devlink); +} + +int dsa_switch_devlink_alloc(struct dsa_switch *ds) +{ + struct dsa_devlink_priv *dl_priv; + struct devlink *dl; + + /* Add the switch to devlink before calling setup, so that setup can + * add dpipe tables + */ + dl = devlink_alloc(&dsa_devlink_ops, sizeof(*dl_priv), ds->dev); + if (!dl) + return -ENOMEM; + + ds->devlink = dl; + + dl_priv = devlink_priv(ds->devlink); + dl_priv->ds = ds; + + return 0; +} + +void dsa_switch_devlink_free(struct dsa_switch *ds) +{ + devlink_free(ds->devlink); + ds->devlink = NULL; +} diff --git a/net/dsa/devlink.h b/net/dsa/devlink.h new file mode 100644 index 000000000000..4d9f4f23705b --- /dev/null +++ b/net/dsa/devlink.h @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_DEVLINK_H +#define __DSA_DEVLINK_H + +struct dsa_port; +struct dsa_switch; + +int dsa_port_devlink_setup(struct dsa_port *dp); +void dsa_port_devlink_teardown(struct dsa_port *dp); +void dsa_switch_devlink_register(struct dsa_switch *ds); +void dsa_switch_devlink_unregister(struct dsa_switch *ds); +int dsa_switch_devlink_alloc(struct dsa_switch *ds); +void dsa_switch_devlink_free(struct dsa_switch *ds); + +#endif diff --git a/net/dsa/dsa.c b/net/dsa/dsa.c index 64b14f655b23..e5f156940c67 100644 --- a/net/dsa/dsa.c +++ b/net/dsa/dsa.c @@ -1,453 +1,1637 @@ // SPDX-License-Identifier: GPL-2.0-or-later /* - * net/dsa/dsa.c - Hardware switch handling + * DSA topology and switch handling + * * Copyright (c) 2008-2009 Marvell Semiconductor * Copyright (c) 2013 Florian Fainelli <florian@openwrt.org> + * Copyright (c) 2016 Andrew Lunn <andrew@lunn.ch> */ #include <linux/device.h> +#include <linux/err.h> #include <linux/list.h> #include <linux/module.h> #include <linux/netdevice.h> -#include <linux/sysfs.h> -#include <linux/ptp_classify.h> +#include <linux/slab.h> +#include <linux/rtnetlink.h> +#include <linux/of.h> +#include <linux/of_mdio.h> +#include <linux/of_net.h> +#include <net/sch_generic.h> + +#include "devlink.h" +#include "dsa.h" +#include "master.h" +#include "netlink.h" +#include "port.h" +#include "slave.h" +#include "switch.h" +#include "tag.h" + +#define DSA_MAX_NUM_OFFLOADING_BRIDGES BITS_PER_LONG + +static DEFINE_MUTEX(dsa2_mutex); +LIST_HEAD(dsa_tree_list); -#include "dsa_priv.h" +static struct workqueue_struct *dsa_owq; -static LIST_HEAD(dsa_tag_drivers_list); -static DEFINE_MUTEX(dsa_tag_drivers_lock); +/* Track the bridges with forwarding offload enabled */ +static unsigned long dsa_fwd_offloading_bridges; -static struct sk_buff *dsa_slave_notag_xmit(struct sk_buff *skb, - struct net_device *dev) +bool dsa_schedule_work(struct work_struct *work) { - /* Just return the original SKB */ - return skb; + return queue_work(dsa_owq, work); } -static const struct dsa_device_ops none_ops = { - .name = "none", - .proto = DSA_TAG_PROTO_NONE, - .xmit = dsa_slave_notag_xmit, - .rcv = NULL, -}; +void dsa_flush_workqueue(void) +{ + flush_workqueue(dsa_owq); +} +EXPORT_SYMBOL_GPL(dsa_flush_workqueue); -DSA_TAG_DRIVER(none_ops); +/** + * dsa_lag_map() - Map LAG structure to a linear LAG array + * @dst: Tree in which to record the mapping. + * @lag: LAG structure that is to be mapped to the tree's array. + * + * dsa_lag_id/dsa_lag_by_id can then be used to translate between the + * two spaces. The size of the mapping space is determined by the + * driver by setting ds->num_lag_ids. It is perfectly legal to leave + * it unset if it is not needed, in which case these functions become + * no-ops. + */ +void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag) +{ + unsigned int id; + + for (id = 1; id <= dst->lags_len; id++) { + if (!dsa_lag_by_id(dst, id)) { + dst->lags[id - 1] = lag; + lag->id = id; + return; + } + } -static void dsa_tag_driver_register(struct dsa_tag_driver *dsa_tag_driver, - struct module *owner) + /* No IDs left, which is OK. Some drivers do not need it. The + * ones that do, e.g. mv88e6xxx, will discover that dsa_lag_id + * returns an error for this device when joining the LAG. The + * driver can then return -EOPNOTSUPP back to DSA, which will + * fall back to a software LAG. + */ +} + +/** + * dsa_lag_unmap() - Remove a LAG ID mapping + * @dst: Tree in which the mapping is recorded. + * @lag: LAG structure that was mapped. + * + * As there may be multiple users of the mapping, it is only removed + * if there are no other references to it. + */ +void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag) { - dsa_tag_driver->owner = owner; + unsigned int id; - mutex_lock(&dsa_tag_drivers_lock); - list_add_tail(&dsa_tag_driver->list, &dsa_tag_drivers_list); - mutex_unlock(&dsa_tag_drivers_lock); + dsa_lags_foreach_id(id, dst) { + if (dsa_lag_by_id(dst, id) == lag) { + dst->lags[id - 1] = NULL; + lag->id = 0; + break; + } + } } -void dsa_tag_drivers_register(struct dsa_tag_driver *dsa_tag_driver_array[], - unsigned int count, struct module *owner) +struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, + const struct net_device *lag_dev) { - unsigned int i; + struct dsa_port *dp; - for (i = 0; i < count; i++) - dsa_tag_driver_register(dsa_tag_driver_array[i], owner); + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_lag_dev_get(dp) == lag_dev) + return dp->lag; + + return NULL; } -static void dsa_tag_driver_unregister(struct dsa_tag_driver *dsa_tag_driver) +struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, + const struct net_device *br) { - mutex_lock(&dsa_tag_drivers_lock); - list_del(&dsa_tag_driver->list); - mutex_unlock(&dsa_tag_drivers_lock); + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_bridge_dev_get(dp) == br) + return dp->bridge; + + return NULL; } -EXPORT_SYMBOL_GPL(dsa_tag_drivers_register); -void dsa_tag_drivers_unregister(struct dsa_tag_driver *dsa_tag_driver_array[], - unsigned int count) +static int dsa_bridge_num_find(const struct net_device *bridge_dev) { - unsigned int i; + struct dsa_switch_tree *dst; - for (i = 0; i < count; i++) - dsa_tag_driver_unregister(dsa_tag_driver_array[i]); + list_for_each_entry(dst, &dsa_tree_list, list) { + struct dsa_bridge *bridge; + + bridge = dsa_tree_bridge_find(dst, bridge_dev); + if (bridge) + return bridge->num; + } + + return 0; } -EXPORT_SYMBOL_GPL(dsa_tag_drivers_unregister); -const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops) +unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max) { - return ops->name; -}; + unsigned int bridge_num = dsa_bridge_num_find(bridge_dev); + + /* Switches without FDB isolation support don't get unique + * bridge numbering + */ + if (!max) + return 0; + + if (!bridge_num) { + /* First port that requests FDB isolation or TX forwarding + * offload for this bridge + */ + bridge_num = find_next_zero_bit(&dsa_fwd_offloading_bridges, + DSA_MAX_NUM_OFFLOADING_BRIDGES, + 1); + if (bridge_num >= max) + return 0; + + set_bit(bridge_num, &dsa_fwd_offloading_bridges); + } + + return bridge_num; +} + +void dsa_bridge_num_put(const struct net_device *bridge_dev, + unsigned int bridge_num) +{ + /* Since we refcount bridges, we know that when we call this function + * it is no longer in use, so we can just go ahead and remove it from + * the bit mask. + */ + clear_bit(bridge_num, &dsa_fwd_offloading_bridges); +} + +struct dsa_switch *dsa_switch_find(int tree_index, int sw_index) +{ + struct dsa_switch_tree *dst; + struct dsa_port *dp; + + list_for_each_entry(dst, &dsa_tree_list, list) { + if (dst->index != tree_index) + continue; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->ds->index != sw_index) + continue; + + return dp->ds; + } + } + + return NULL; +} +EXPORT_SYMBOL_GPL(dsa_switch_find); + +static struct dsa_switch_tree *dsa_tree_find(int index) +{ + struct dsa_switch_tree *dst; + + list_for_each_entry(dst, &dsa_tree_list, list) + if (dst->index == index) + return dst; + + return NULL; +} + +static struct dsa_switch_tree *dsa_tree_alloc(int index) +{ + struct dsa_switch_tree *dst; + + dst = kzalloc(sizeof(*dst), GFP_KERNEL); + if (!dst) + return NULL; + + dst->index = index; + + INIT_LIST_HEAD(&dst->rtable); + + INIT_LIST_HEAD(&dst->ports); + + INIT_LIST_HEAD(&dst->list); + list_add_tail(&dst->list, &dsa_tree_list); + + kref_init(&dst->refcount); -/* Function takes a reference on the module owning the tagger, - * so dsa_tag_driver_put must be called afterwards. + return dst; +} + +static void dsa_tree_free(struct dsa_switch_tree *dst) +{ + if (dst->tag_ops) + dsa_tag_driver_put(dst->tag_ops); + list_del(&dst->list); + kfree(dst); +} + +static struct dsa_switch_tree *dsa_tree_get(struct dsa_switch_tree *dst) +{ + if (dst) + kref_get(&dst->refcount); + + return dst; +} + +static struct dsa_switch_tree *dsa_tree_touch(int index) +{ + struct dsa_switch_tree *dst; + + dst = dsa_tree_find(index); + if (dst) + return dsa_tree_get(dst); + else + return dsa_tree_alloc(index); +} + +static void dsa_tree_release(struct kref *ref) +{ + struct dsa_switch_tree *dst; + + dst = container_of(ref, struct dsa_switch_tree, refcount); + + dsa_tree_free(dst); +} + +static void dsa_tree_put(struct dsa_switch_tree *dst) +{ + if (dst) + kref_put(&dst->refcount, dsa_tree_release); +} + +static struct dsa_port *dsa_tree_find_port_by_node(struct dsa_switch_tree *dst, + struct device_node *dn) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dp->dn == dn) + return dp; + + return NULL; +} + +static struct dsa_link *dsa_link_touch(struct dsa_port *dp, + struct dsa_port *link_dp) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst; + struct dsa_link *dl; + + dst = ds->dst; + + list_for_each_entry(dl, &dst->rtable, list) + if (dl->dp == dp && dl->link_dp == link_dp) + return dl; + + dl = kzalloc(sizeof(*dl), GFP_KERNEL); + if (!dl) + return NULL; + + dl->dp = dp; + dl->link_dp = link_dp; + + INIT_LIST_HEAD(&dl->list); + list_add_tail(&dl->list, &dst->rtable); + + return dl; +} + +static bool dsa_port_setup_routing_table(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst = ds->dst; + struct device_node *dn = dp->dn; + struct of_phandle_iterator it; + struct dsa_port *link_dp; + struct dsa_link *dl; + int err; + + of_for_each_phandle(&it, err, dn, "link", NULL, 0) { + link_dp = dsa_tree_find_port_by_node(dst, it.node); + if (!link_dp) { + of_node_put(it.node); + return false; + } + + dl = dsa_link_touch(dp, link_dp); + if (!dl) { + of_node_put(it.node); + return false; + } + } + + return true; +} + +static bool dsa_tree_setup_routing_table(struct dsa_switch_tree *dst) +{ + bool complete = true; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_dsa(dp)) { + complete = dsa_port_setup_routing_table(dp); + if (!complete) + break; + } + } + + return complete; +} + +static struct dsa_port *dsa_tree_find_first_cpu(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_cpu(dp)) + return dp; + + return NULL; +} + +struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst) +{ + struct device_node *ethernet; + struct net_device *master; + struct dsa_port *cpu_dp; + + cpu_dp = dsa_tree_find_first_cpu(dst); + ethernet = of_parse_phandle(cpu_dp->dn, "ethernet", 0); + master = of_find_net_device_by_node(ethernet); + of_node_put(ethernet); + + return master; +} + +/* Assign the default CPU port (the first one in the tree) to all ports of the + * fabric which don't already have one as part of their own switch. */ -const struct dsa_device_ops *dsa_find_tagger_by_name(const char *buf) +static int dsa_tree_setup_default_cpu(struct dsa_switch_tree *dst) { - const struct dsa_device_ops *ops = ERR_PTR(-ENOPROTOOPT); - struct dsa_tag_driver *dsa_tag_driver; + struct dsa_port *cpu_dp, *dp; - mutex_lock(&dsa_tag_drivers_lock); - list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { - const struct dsa_device_ops *tmp = dsa_tag_driver->ops; + cpu_dp = dsa_tree_find_first_cpu(dst); + if (!cpu_dp) { + pr_err("DSA: tree %d has no CPU port\n", dst->index); + return -EINVAL; + } - if (!sysfs_streq(buf, tmp->name)) + list_for_each_entry(dp, &dst->ports, list) { + if (dp->cpu_dp) continue; - if (!try_module_get(dsa_tag_driver->owner)) - break; + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = cpu_dp; + } - ops = tmp; - break; + return 0; +} + +/* Perform initial assignment of CPU ports to user ports and DSA links in the + * fabric, giving preference to CPU ports local to each switch. Default to + * using the first CPU port in the switch tree if the port does not have a CPU + * port local to this switch. + */ +static int dsa_tree_setup_cpu_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp, *dp; + + list_for_each_entry(cpu_dp, &dst->ports, list) { + if (!dsa_port_is_cpu(cpu_dp)) + continue; + + /* Prefer a local CPU port */ + dsa_switch_for_each_port(dp, cpu_dp->ds) { + /* Prefer the first local CPU port found */ + if (dp->cpu_dp) + continue; + + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = cpu_dp; + } } - mutex_unlock(&dsa_tag_drivers_lock); - return ops; + return dsa_tree_setup_default_cpu(dst); } -const struct dsa_device_ops *dsa_tag_driver_get(int tag_protocol) +static void dsa_tree_teardown_cpu_ports(struct dsa_switch_tree *dst) { - struct dsa_tag_driver *dsa_tag_driver; - const struct dsa_device_ops *ops; - bool found = false; + struct dsa_port *dp; - request_module("%s%d", DSA_TAG_DRIVER_ALIAS, tag_protocol); + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = NULL; +} - mutex_lock(&dsa_tag_drivers_lock); - list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { - ops = dsa_tag_driver->ops; - if (ops->proto == tag_protocol) { - found = true; +static int dsa_port_setup(struct dsa_port *dp) +{ + bool dsa_port_link_registered = false; + struct dsa_switch *ds = dp->ds; + bool dsa_port_enabled = false; + int err = 0; + + if (dp->setup) + return 0; + + err = dsa_port_devlink_setup(dp); + if (err) + return err; + + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + dsa_port_disable(dp); + break; + case DSA_PORT_TYPE_CPU: + if (dp->dn) { + err = dsa_shared_port_link_register_of(dp); + if (err) + break; + dsa_port_link_registered = true; + } else { + dev_warn(ds->dev, + "skipping link registration for CPU port %d\n", + dp->index); + } + + err = dsa_port_enable(dp, NULL); + if (err) break; + dsa_port_enabled = true; + + break; + case DSA_PORT_TYPE_DSA: + if (dp->dn) { + err = dsa_shared_port_link_register_of(dp); + if (err) + break; + dsa_port_link_registered = true; + } else { + dev_warn(ds->dev, + "skipping link registration for DSA port %d\n", + dp->index); } + + err = dsa_port_enable(dp, NULL); + if (err) + break; + dsa_port_enabled = true; + + break; + case DSA_PORT_TYPE_USER: + of_get_mac_address(dp->dn, dp->mac); + err = dsa_slave_create(dp); + break; } - if (found) { - if (!try_module_get(dsa_tag_driver->owner)) - ops = ERR_PTR(-ENOPROTOOPT); - } else { - ops = ERR_PTR(-ENOPROTOOPT); + if (err && dsa_port_enabled) + dsa_port_disable(dp); + if (err && dsa_port_link_registered) + dsa_shared_port_link_unregister_of(dp); + if (err) { + dsa_port_devlink_teardown(dp); + return err; } - mutex_unlock(&dsa_tag_drivers_lock); + dp->setup = true; - return ops; + return 0; } -void dsa_tag_driver_put(const struct dsa_device_ops *ops) +static void dsa_port_teardown(struct dsa_port *dp) { - struct dsa_tag_driver *dsa_tag_driver; + if (!dp->setup) + return; - mutex_lock(&dsa_tag_drivers_lock); - list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { - if (dsa_tag_driver->ops == ops) { - module_put(dsa_tag_driver->owner); - break; + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + break; + case DSA_PORT_TYPE_CPU: + dsa_port_disable(dp); + if (dp->dn) + dsa_shared_port_link_unregister_of(dp); + break; + case DSA_PORT_TYPE_DSA: + dsa_port_disable(dp); + if (dp->dn) + dsa_shared_port_link_unregister_of(dp); + break; + case DSA_PORT_TYPE_USER: + if (dp->slave) { + dsa_slave_destroy(dp->slave); + dp->slave = NULL; } + break; } - mutex_unlock(&dsa_tag_drivers_lock); + + dsa_port_devlink_teardown(dp); + + dp->setup = false; } -static int dev_is_class(struct device *dev, void *class) +static int dsa_port_setup_as_unused(struct dsa_port *dp) { - if (dev->class != NULL && !strcmp(dev->class->name, class)) - return 1; + dp->type = DSA_PORT_TYPE_UNUSED; + return dsa_port_setup(dp); +} + +static int dsa_switch_setup_tag_protocol(struct dsa_switch *ds) +{ + const struct dsa_device_ops *tag_ops = ds->dst->tag_ops; + struct dsa_switch_tree *dst = ds->dst; + int err; + + if (tag_ops->proto == dst->default_proto) + goto connect; + + rtnl_lock(); + err = ds->ops->change_tag_protocol(ds, tag_ops->proto); + rtnl_unlock(); + if (err) { + dev_err(ds->dev, "Unable to use tag protocol \"%s\": %pe\n", + tag_ops->name, ERR_PTR(err)); + return err; + } + +connect: + if (tag_ops->connect) { + err = tag_ops->connect(ds); + if (err) + return err; + } + + if (ds->ops->connect_tag_protocol) { + err = ds->ops->connect_tag_protocol(ds, tag_ops->proto); + if (err) { + dev_err(ds->dev, + "Unable to connect to tag protocol \"%s\": %pe\n", + tag_ops->name, ERR_PTR(err)); + goto disconnect; + } + } return 0; + +disconnect: + if (tag_ops->disconnect) + tag_ops->disconnect(ds); + + return err; } -static struct device *dev_find_class(struct device *parent, char *class) +static void dsa_switch_teardown_tag_protocol(struct dsa_switch *ds) { - if (dev_is_class(parent, class)) { - get_device(parent); - return parent; - } + const struct dsa_device_ops *tag_ops = ds->dst->tag_ops; - return device_find_child(parent, class, dev_is_class); + if (tag_ops->disconnect) + tag_ops->disconnect(ds); } -struct net_device *dsa_dev_to_net_device(struct device *dev) +static int dsa_switch_setup(struct dsa_switch *ds) { - struct device *d; + struct device_node *dn; + int err; - d = dev_find_class(dev, "net"); - if (d != NULL) { - struct net_device *nd; + if (ds->setup) + return 0; - nd = to_net_dev(d); - dev_hold(nd); - put_device(d); + /* Initialize ds->phys_mii_mask before registering the slave MDIO bus + * driver and before ops->setup() has run, since the switch drivers and + * the slave MDIO bus driver rely on these values for probing PHY + * devices or not + */ + ds->phys_mii_mask |= dsa_user_ports(ds); - return nd; + err = dsa_switch_devlink_alloc(ds); + if (err) + return err; + + err = dsa_switch_register_notifier(ds); + if (err) + goto devlink_free; + + ds->configure_vlan_while_not_filtering = true; + + err = ds->ops->setup(ds); + if (err < 0) + goto unregister_notifier; + + err = dsa_switch_setup_tag_protocol(ds); + if (err) + goto teardown; + + if (!ds->slave_mii_bus && ds->ops->phy_read) { + ds->slave_mii_bus = mdiobus_alloc(); + if (!ds->slave_mii_bus) { + err = -ENOMEM; + goto teardown; + } + + dsa_slave_mii_bus_init(ds); + + dn = of_get_child_by_name(ds->dev->of_node, "mdio"); + + err = of_mdiobus_register(ds->slave_mii_bus, dn); + of_node_put(dn); + if (err < 0) + goto free_slave_mii_bus; } - return NULL; + dsa_switch_devlink_register(ds); + + ds->setup = true; + return 0; + +free_slave_mii_bus: + if (ds->slave_mii_bus && ds->ops->phy_read) + mdiobus_free(ds->slave_mii_bus); +teardown: + if (ds->ops->teardown) + ds->ops->teardown(ds); +unregister_notifier: + dsa_switch_unregister_notifier(ds); +devlink_free: + dsa_switch_devlink_free(ds); + return err; } -EXPORT_SYMBOL_GPL(dsa_dev_to_net_device); -/* Determine if we should defer delivery of skb until we have a rx timestamp. - * - * Called from dsa_switch_rcv. For now, this will only work if tagging is - * enabled on the switch. Normally the MAC driver would retrieve the hardware - * timestamp when it reads the packet out of the hardware. However in a DSA - * switch, the DSA driver owning the interface to which the packet is - * delivered is never notified unless we do so here. +static void dsa_switch_teardown(struct dsa_switch *ds) +{ + if (!ds->setup) + return; + + dsa_switch_devlink_unregister(ds); + + if (ds->slave_mii_bus && ds->ops->phy_read) { + mdiobus_unregister(ds->slave_mii_bus); + mdiobus_free(ds->slave_mii_bus); + ds->slave_mii_bus = NULL; + } + + dsa_switch_teardown_tag_protocol(ds); + + if (ds->ops->teardown) + ds->ops->teardown(ds); + + dsa_switch_unregister_notifier(ds); + + dsa_switch_devlink_free(ds); + + ds->setup = false; +} + +/* First tear down the non-shared, then the shared ports. This ensures that + * all work items scheduled by our switchdev handlers for user ports have + * completed before we destroy the refcounting kept on the shared ports. */ -static bool dsa_skb_defer_rx_timestamp(struct dsa_slave_priv *p, - struct sk_buff *skb) +static void dsa_tree_teardown_ports(struct dsa_switch_tree *dst) { - struct dsa_switch *ds = p->dp->ds; - unsigned int type; + struct dsa_port *dp; - if (skb_headroom(skb) < ETH_HLEN) - return false; + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) + dsa_port_teardown(dp); + + dsa_flush_workqueue(); + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) + dsa_port_teardown(dp); +} - __skb_push(skb, ETH_HLEN); +static void dsa_tree_teardown_switches(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; - type = ptp_classify_raw(skb); + list_for_each_entry(dp, &dst->ports, list) + dsa_switch_teardown(dp->ds); +} - __skb_pull(skb, ETH_HLEN); +/* Bring shared ports up first, then non-shared ports */ +static int dsa_tree_setup_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + int err = 0; - if (type == PTP_CLASS_NONE) - return false; + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) { + err = dsa_port_setup(dp); + if (err) + goto teardown; + } + } - if (likely(ds->ops->port_rxtstamp)) - return ds->ops->port_rxtstamp(ds, p->dp->index, skb, type); + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) { + err = dsa_port_setup(dp); + if (err) { + err = dsa_port_setup_as_unused(dp); + if (err) + goto teardown; + } + } + } - return false; + return 0; + +teardown: + dsa_tree_teardown_ports(dst); + + return err; } -static int dsa_switch_rcv(struct sk_buff *skb, struct net_device *dev, - struct packet_type *pt, struct net_device *unused) +static int dsa_tree_setup_switches(struct dsa_switch_tree *dst) { - struct dsa_port *cpu_dp = dev->dsa_ptr; - struct sk_buff *nskb = NULL; - struct dsa_slave_priv *p; + struct dsa_port *dp; + int err = 0; - if (unlikely(!cpu_dp)) { - kfree_skb(skb); - return 0; + list_for_each_entry(dp, &dst->ports, list) { + err = dsa_switch_setup(dp->ds); + if (err) { + dsa_tree_teardown_switches(dst); + break; + } } - skb = skb_unshare(skb, GFP_ATOMIC); - if (!skb) - return 0; + return err; +} - nskb = cpu_dp->rcv(skb, dev); - if (!nskb) { - kfree_skb(skb); - return 0; +static int dsa_tree_setup_master(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp; + int err = 0; + + rtnl_lock(); + + dsa_tree_for_each_cpu_port(cpu_dp, dst) { + struct net_device *master = cpu_dp->master; + bool admin_up = (master->flags & IFF_UP) && + !qdisc_tx_is_noop(master); + + err = dsa_master_setup(master, cpu_dp); + if (err) + break; + + /* Replay master state event */ + dsa_tree_master_admin_state_change(dst, master, admin_up); + dsa_tree_master_oper_state_change(dst, master, + netif_oper_up(master)); } - skb = nskb; - skb_push(skb, ETH_HLEN); - skb->pkt_type = PACKET_HOST; - skb->protocol = eth_type_trans(skb, skb->dev); + rtnl_unlock(); + + return err; +} + +static void dsa_tree_teardown_master(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp; - if (unlikely(!dsa_slave_dev_check(skb->dev))) { - /* Packet is to be injected directly on an upper - * device, e.g. a team/bond, so skip all DSA-port - * specific actions. + rtnl_lock(); + + dsa_tree_for_each_cpu_port(cpu_dp, dst) { + struct net_device *master = cpu_dp->master; + + /* Synthesizing an "admin down" state is sufficient for + * the switches to get a notification if the master is + * currently up and running. */ - netif_rx(skb); - return 0; + dsa_tree_master_admin_state_change(dst, master, false); + + dsa_master_teardown(master); } - p = netdev_priv(skb->dev); + rtnl_unlock(); +} - if (unlikely(cpu_dp->ds->untag_bridge_pvid)) { - nskb = dsa_untag_bridge_pvid(skb); - if (!nskb) { - kfree_skb(skb); - return 0; - } - skb = nskb; +static int dsa_tree_setup_lags(struct dsa_switch_tree *dst) +{ + unsigned int len = 0; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->ds->num_lag_ids > len) + len = dp->ds->num_lag_ids; } - dev_sw_netstats_rx_add(skb->dev, skb->len); + if (!len) + return 0; + + dst->lags = kcalloc(len, sizeof(*dst->lags), GFP_KERNEL); + if (!dst->lags) + return -ENOMEM; + + dst->lags_len = len; + return 0; +} + +static void dsa_tree_teardown_lags(struct dsa_switch_tree *dst) +{ + kfree(dst->lags); +} + +static int dsa_tree_setup(struct dsa_switch_tree *dst) +{ + bool complete; + int err; + + if (dst->setup) { + pr_err("DSA: tree %d already setup! Disjoint trees?\n", + dst->index); + return -EEXIST; + } - if (dsa_skb_defer_rx_timestamp(p, skb)) + complete = dsa_tree_setup_routing_table(dst); + if (!complete) return 0; - gro_cells_receive(&p->gcells, skb); + err = dsa_tree_setup_cpu_ports(dst); + if (err) + return err; + + err = dsa_tree_setup_switches(dst); + if (err) + goto teardown_cpu_ports; + + err = dsa_tree_setup_ports(dst); + if (err) + goto teardown_switches; + + err = dsa_tree_setup_master(dst); + if (err) + goto teardown_ports; + + err = dsa_tree_setup_lags(dst); + if (err) + goto teardown_master; + + dst->setup = true; + + pr_info("DSA: tree %d setup\n", dst->index); return 0; + +teardown_master: + dsa_tree_teardown_master(dst); +teardown_ports: + dsa_tree_teardown_ports(dst); +teardown_switches: + dsa_tree_teardown_switches(dst); +teardown_cpu_ports: + dsa_tree_teardown_cpu_ports(dst); + + return err; } -#ifdef CONFIG_PM_SLEEP -static bool dsa_port_is_initialized(const struct dsa_port *dp) +static void dsa_tree_teardown(struct dsa_switch_tree *dst) { - return dp->type == DSA_PORT_TYPE_USER && dp->slave; + struct dsa_link *dl, *next; + + if (!dst->setup) + return; + + dsa_tree_teardown_lags(dst); + + dsa_tree_teardown_master(dst); + + dsa_tree_teardown_ports(dst); + + dsa_tree_teardown_switches(dst); + + dsa_tree_teardown_cpu_ports(dst); + + list_for_each_entry_safe(dl, next, &dst->rtable, list) { + list_del(&dl->list); + kfree(dl); + } + + pr_info("DSA: tree %d torn down\n", dst->index); + + dst->setup = false; } -int dsa_switch_suspend(struct dsa_switch *ds) +static int dsa_tree_bind_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops) +{ + const struct dsa_device_ops *old_tag_ops = dst->tag_ops; + struct dsa_notifier_tag_proto_info info; + int err; + + dst->tag_ops = tag_ops; + + /* Notify the switches from this tree about the connection + * to the new tagger + */ + info.tag_ops = tag_ops; + err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_CONNECT, &info); + if (err && err != -EOPNOTSUPP) + goto out_disconnect; + + /* Notify the old tagger about the disconnection from this tree */ + info.tag_ops = old_tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); + + return 0; + +out_disconnect: + info.tag_ops = tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); + dst->tag_ops = old_tag_ops; + + return err; +} + +/* Since the dsa/tagging sysfs device attribute is per master, the assumption + * is that all DSA switches within a tree share the same tagger, otherwise + * they would have formed disjoint trees (different "dsa,member" values). + */ +int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops, + const struct dsa_device_ops *old_tag_ops) { + struct dsa_notifier_tag_proto_info info; struct dsa_port *dp; - int ret = 0; + int err = -EBUSY; + + if (!rtnl_trylock()) + return restart_syscall(); + + /* At the moment we don't allow changing the tag protocol under + * traffic. The rtnl_mutex also happens to serialize concurrent + * attempts to change the tagging protocol. If we ever lift the IFF_UP + * restriction, there needs to be another mutex which serializes this. + */ + dsa_tree_for_each_user_port(dp, dst) { + if (dsa_port_to_master(dp)->flags & IFF_UP) + goto out_unlock; + + if (dp->slave->flags & IFF_UP) + goto out_unlock; + } - /* Suspend slave network devices */ - dsa_switch_for_each_port(dp, ds) { - if (!dsa_port_is_initialized(dp)) - continue; + /* Notify the tag protocol change */ + info.tag_ops = tag_ops; + err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); + if (err) + goto out_unwind_tagger; - ret = dsa_slave_suspend(dp->slave); - if (ret) - return ret; + err = dsa_tree_bind_tag_proto(dst, tag_ops); + if (err) + goto out_unwind_tagger; + + rtnl_unlock(); + + return 0; + +out_unwind_tagger: + info.tag_ops = old_tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); +out_unlock: + rtnl_unlock(); + return err; +} + +static void dsa_tree_master_state_change(struct dsa_switch_tree *dst, + struct net_device *master) +{ + struct dsa_notifier_master_state_info info; + struct dsa_port *cpu_dp = master->dsa_ptr; + + info.master = master; + info.operational = dsa_port_master_is_operational(cpu_dp); + + dsa_tree_notify(dst, DSA_NOTIFIER_MASTER_STATE_CHANGE, &info); +} + +void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + bool notify = false; + + /* Don't keep track of admin state on LAG DSA masters, + * but rather just of physical DSA masters + */ + if (netif_is_lag_master(master)) + return; + + if ((dsa_port_master_is_operational(cpu_dp)) != + (up && cpu_dp->master_oper_up)) + notify = true; + + cpu_dp->master_admin_up = up; + + if (notify) + dsa_tree_master_state_change(dst, master); +} + +void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + bool notify = false; + + /* Don't keep track of oper state on LAG DSA masters, + * but rather just of physical DSA masters + */ + if (netif_is_lag_master(master)) + return; + + if ((dsa_port_master_is_operational(cpu_dp)) != + (cpu_dp->master_admin_up && up)) + notify = true; + + cpu_dp->master_oper_up = up; + + if (notify) + dsa_tree_master_state_change(dst, master); +} + +static struct dsa_port *dsa_port_touch(struct dsa_switch *ds, int index) +{ + struct dsa_switch_tree *dst = ds->dst; + struct dsa_port *dp; + + dsa_switch_for_each_port(dp, ds) + if (dp->index == index) + return dp; + + dp = kzalloc(sizeof(*dp), GFP_KERNEL); + if (!dp) + return NULL; + + dp->ds = ds; + dp->index = index; + + mutex_init(&dp->addr_lists_lock); + mutex_init(&dp->vlans_lock); + INIT_LIST_HEAD(&dp->fdbs); + INIT_LIST_HEAD(&dp->mdbs); + INIT_LIST_HEAD(&dp->vlans); + INIT_LIST_HEAD(&dp->list); + list_add_tail(&dp->list, &dst->ports); + + return dp; +} + +static int dsa_port_parse_user(struct dsa_port *dp, const char *name) +{ + dp->type = DSA_PORT_TYPE_USER; + dp->name = name; + + return 0; +} + +static int dsa_port_parse_dsa(struct dsa_port *dp) +{ + dp->type = DSA_PORT_TYPE_DSA; + + return 0; +} + +static enum dsa_tag_protocol dsa_get_tag_protocol(struct dsa_port *dp, + struct net_device *master) +{ + enum dsa_tag_protocol tag_protocol = DSA_TAG_PROTO_NONE; + struct dsa_switch *mds, *ds = dp->ds; + unsigned int mdp_upstream; + struct dsa_port *mdp; + + /* It is possible to stack DSA switches onto one another when that + * happens the switch driver may want to know if its tagging protocol + * is going to work in such a configuration. + */ + if (dsa_slave_dev_check(master)) { + mdp = dsa_slave_to_port(master); + mds = mdp->ds; + mdp_upstream = dsa_upstream_port(mds, mdp->index); + tag_protocol = mds->ops->get_tag_protocol(mds, mdp_upstream, + DSA_TAG_PROTO_NONE); } - if (ds->ops->suspend) - ret = ds->ops->suspend(ds); + /* If the master device is not itself a DSA slave in a disjoint DSA + * tree, then return immediately. + */ + return ds->ops->get_tag_protocol(ds, dp->index, tag_protocol); +} - return ret; +static int dsa_port_parse_cpu(struct dsa_port *dp, struct net_device *master, + const char *user_protocol) +{ + const struct dsa_device_ops *tag_ops = NULL; + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst = ds->dst; + enum dsa_tag_protocol default_proto; + + /* Find out which protocol the switch would prefer. */ + default_proto = dsa_get_tag_protocol(dp, master); + if (dst->default_proto) { + if (dst->default_proto != default_proto) { + dev_err(ds->dev, + "A DSA switch tree can have only one tagging protocol\n"); + return -EINVAL; + } + } else { + dst->default_proto = default_proto; + } + + /* See if the user wants to override that preference. */ + if (user_protocol) { + if (!ds->ops->change_tag_protocol) { + dev_err(ds->dev, "Tag protocol cannot be modified\n"); + return -EINVAL; + } + + tag_ops = dsa_tag_driver_get_by_name(user_protocol); + if (IS_ERR(tag_ops)) { + dev_warn(ds->dev, + "Failed to find a tagging driver for protocol %s, using default\n", + user_protocol); + tag_ops = NULL; + } + } + + if (!tag_ops) + tag_ops = dsa_tag_driver_get_by_id(default_proto); + + if (IS_ERR(tag_ops)) { + if (PTR_ERR(tag_ops) == -ENOPROTOOPT) + return -EPROBE_DEFER; + + dev_warn(ds->dev, "No tagger for this switch\n"); + return PTR_ERR(tag_ops); + } + + if (dst->tag_ops) { + if (dst->tag_ops != tag_ops) { + dev_err(ds->dev, + "A DSA switch tree can have only one tagging protocol\n"); + + dsa_tag_driver_put(tag_ops); + return -EINVAL; + } + + /* In the case of multiple CPU ports per switch, the tagging + * protocol is still reference-counted only per switch tree. + */ + dsa_tag_driver_put(tag_ops); + } else { + dst->tag_ops = tag_ops; + } + + dp->master = master; + dp->type = DSA_PORT_TYPE_CPU; + dsa_port_set_tag_protocol(dp, dst->tag_ops); + dp->dst = dst; + + /* At this point, the tree may be configured to use a different + * tagger than the one chosen by the switch driver during + * .setup, in the case when a user selects a custom protocol + * through the DT. + * + * This is resolved by syncing the driver with the tree in + * dsa_switch_setup_tag_protocol once .setup has run and the + * driver is ready to accept calls to .change_tag_protocol. If + * the driver does not support the custom protocol at that + * point, the tree is wholly rejected, thereby ensuring that the + * tree and driver are always in agreement on the protocol to + * use. + */ + return 0; } -EXPORT_SYMBOL_GPL(dsa_switch_suspend); -int dsa_switch_resume(struct dsa_switch *ds) +static int dsa_port_parse_of(struct dsa_port *dp, struct device_node *dn) +{ + struct device_node *ethernet = of_parse_phandle(dn, "ethernet", 0); + const char *name = of_get_property(dn, "label", NULL); + bool link = of_property_read_bool(dn, "link"); + + dp->dn = dn; + + if (ethernet) { + struct net_device *master; + const char *user_protocol; + + master = of_find_net_device_by_node(ethernet); + of_node_put(ethernet); + if (!master) + return -EPROBE_DEFER; + + user_protocol = of_get_property(dn, "dsa-tag-protocol", NULL); + return dsa_port_parse_cpu(dp, master, user_protocol); + } + + if (link) + return dsa_port_parse_dsa(dp); + + return dsa_port_parse_user(dp, name); +} + +static int dsa_switch_parse_ports_of(struct dsa_switch *ds, + struct device_node *dn) { + struct device_node *ports, *port; struct dsa_port *dp; - int ret = 0; + int err = 0; + u32 reg; + + ports = of_get_child_by_name(dn, "ports"); + if (!ports) { + /* The second possibility is "ethernet-ports" */ + ports = of_get_child_by_name(dn, "ethernet-ports"); + if (!ports) { + dev_err(ds->dev, "no ports child node found\n"); + return -EINVAL; + } + } - if (ds->ops->resume) - ret = ds->ops->resume(ds); + for_each_available_child_of_node(ports, port) { + err = of_property_read_u32(port, "reg", ®); + if (err) { + of_node_put(port); + goto out_put_node; + } - if (ret) - return ret; + if (reg >= ds->num_ports) { + dev_err(ds->dev, "port %pOF index %u exceeds num_ports (%u)\n", + port, reg, ds->num_ports); + of_node_put(port); + err = -EINVAL; + goto out_put_node; + } - /* Resume slave network devices */ - dsa_switch_for_each_port(dp, ds) { - if (!dsa_port_is_initialized(dp)) - continue; + dp = dsa_to_port(ds, reg); - ret = dsa_slave_resume(dp->slave); - if (ret) - return ret; + err = dsa_port_parse_of(dp, port); + if (err) { + of_node_put(port); + goto out_put_node; + } + } + +out_put_node: + of_node_put(ports); + return err; +} + +static int dsa_switch_parse_member_of(struct dsa_switch *ds, + struct device_node *dn) +{ + u32 m[2] = { 0, 0 }; + int sz; + + /* Don't error out if this optional property isn't found */ + sz = of_property_read_variable_u32_array(dn, "dsa,member", m, 2, 2); + if (sz < 0 && sz != -EINVAL) + return sz; + + ds->index = m[1]; + + ds->dst = dsa_tree_touch(m[0]); + if (!ds->dst) + return -ENOMEM; + + if (dsa_switch_find(ds->dst->index, ds->index)) { + dev_err(ds->dev, + "A DSA switch with index %d already exists in tree %d\n", + ds->index, ds->dst->index); + return -EEXIST; } + if (ds->dst->last_switch < ds->index) + ds->dst->last_switch = ds->index; + return 0; } -EXPORT_SYMBOL_GPL(dsa_switch_resume); -#endif -static struct packet_type dsa_pack_type __read_mostly = { - .type = cpu_to_be16(ETH_P_XDSA), - .func = dsa_switch_rcv, -}; +static int dsa_switch_touch_ports(struct dsa_switch *ds) +{ + struct dsa_port *dp; + int port; -static struct workqueue_struct *dsa_owq; + for (port = 0; port < ds->num_ports; port++) { + dp = dsa_port_touch(ds, port); + if (!dp) + return -ENOMEM; + } -bool dsa_schedule_work(struct work_struct *work) + return 0; +} + +static int dsa_switch_parse_of(struct dsa_switch *ds, struct device_node *dn) { - return queue_work(dsa_owq, work); + int err; + + err = dsa_switch_parse_member_of(ds, dn); + if (err) + return err; + + err = dsa_switch_touch_ports(ds); + if (err) + return err; + + return dsa_switch_parse_ports_of(ds, dn); } -void dsa_flush_workqueue(void) +static int dev_is_class(struct device *dev, void *class) { - flush_workqueue(dsa_owq); + if (dev->class != NULL && !strcmp(dev->class->name, class)) + return 1; + + return 0; } -EXPORT_SYMBOL_GPL(dsa_flush_workqueue); -int dsa_devlink_param_get(struct devlink *dl, u32 id, - struct devlink_param_gset_ctx *ctx) +static struct device *dev_find_class(struct device *parent, char *class) { - struct dsa_switch *ds = dsa_devlink_to_ds(dl); + if (dev_is_class(parent, class)) { + get_device(parent); + return parent; + } - if (!ds->ops->devlink_param_get) - return -EOPNOTSUPP; + return device_find_child(parent, class, dev_is_class); +} + +static struct net_device *dsa_dev_to_net_device(struct device *dev) +{ + struct device *d; + + d = dev_find_class(dev, "net"); + if (d != NULL) { + struct net_device *nd; + + nd = to_net_dev(d); + dev_hold(nd); + put_device(d); - return ds->ops->devlink_param_get(ds, id, ctx); + return nd; + } + + return NULL; } -EXPORT_SYMBOL_GPL(dsa_devlink_param_get); -int dsa_devlink_param_set(struct devlink *dl, u32 id, - struct devlink_param_gset_ctx *ctx) +static int dsa_port_parse(struct dsa_port *dp, const char *name, + struct device *dev) { - struct dsa_switch *ds = dsa_devlink_to_ds(dl); + if (!strcmp(name, "cpu")) { + struct net_device *master; + + master = dsa_dev_to_net_device(dev); + if (!master) + return -EPROBE_DEFER; + + dev_put(master); - if (!ds->ops->devlink_param_set) - return -EOPNOTSUPP; + return dsa_port_parse_cpu(dp, master, NULL); + } + + if (!strcmp(name, "dsa")) + return dsa_port_parse_dsa(dp); - return ds->ops->devlink_param_set(ds, id, ctx); + return dsa_port_parse_user(dp, name); } -EXPORT_SYMBOL_GPL(dsa_devlink_param_set); -int dsa_devlink_params_register(struct dsa_switch *ds, - const struct devlink_param *params, - size_t params_count) +static int dsa_switch_parse_ports(struct dsa_switch *ds, + struct dsa_chip_data *cd) { - return devlink_params_register(ds->devlink, params, params_count); + bool valid_name_found = false; + struct dsa_port *dp; + struct device *dev; + const char *name; + unsigned int i; + int err; + + for (i = 0; i < DSA_MAX_PORTS; i++) { + name = cd->port_names[i]; + dev = cd->netdev[i]; + dp = dsa_to_port(ds, i); + + if (!name) + continue; + + err = dsa_port_parse(dp, name, dev); + if (err) + return err; + + valid_name_found = true; + } + + if (!valid_name_found && i == DSA_MAX_PORTS) + return -EINVAL; + + return 0; } -EXPORT_SYMBOL_GPL(dsa_devlink_params_register); -void dsa_devlink_params_unregister(struct dsa_switch *ds, - const struct devlink_param *params, - size_t params_count) +static int dsa_switch_parse(struct dsa_switch *ds, struct dsa_chip_data *cd) { - devlink_params_unregister(ds->devlink, params, params_count); + int err; + + ds->cd = cd; + + /* We don't support interconnected switches nor multiple trees via + * platform data, so this is the unique switch of the tree. + */ + ds->index = 0; + ds->dst = dsa_tree_touch(0); + if (!ds->dst) + return -ENOMEM; + + err = dsa_switch_touch_ports(ds); + if (err) + return err; + + return dsa_switch_parse_ports(ds, cd); } -EXPORT_SYMBOL_GPL(dsa_devlink_params_unregister); -int dsa_devlink_resource_register(struct dsa_switch *ds, - const char *resource_name, - u64 resource_size, - u64 resource_id, - u64 parent_resource_id, - const struct devlink_resource_size_params *size_params) +static void dsa_switch_release_ports(struct dsa_switch *ds) { - return devlink_resource_register(ds->devlink, resource_name, - resource_size, resource_id, - parent_resource_id, - size_params); + struct dsa_port *dp, *next; + + dsa_switch_for_each_port_safe(dp, next, ds) { + WARN_ON(!list_empty(&dp->fdbs)); + WARN_ON(!list_empty(&dp->mdbs)); + WARN_ON(!list_empty(&dp->vlans)); + list_del(&dp->list); + kfree(dp); + } } -EXPORT_SYMBOL_GPL(dsa_devlink_resource_register); -void dsa_devlink_resources_unregister(struct dsa_switch *ds) +static int dsa_switch_probe(struct dsa_switch *ds) { - devlink_resources_unregister(ds->devlink); + struct dsa_switch_tree *dst; + struct dsa_chip_data *pdata; + struct device_node *np; + int err; + + if (!ds->dev) + return -ENODEV; + + pdata = ds->dev->platform_data; + np = ds->dev->of_node; + + if (!ds->num_ports) + return -EINVAL; + + if (np) { + err = dsa_switch_parse_of(ds, np); + if (err) + dsa_switch_release_ports(ds); + } else if (pdata) { + err = dsa_switch_parse(ds, pdata); + if (err) + dsa_switch_release_ports(ds); + } else { + err = -ENODEV; + } + + if (err) + return err; + + dst = ds->dst; + dsa_tree_get(dst); + err = dsa_tree_setup(dst); + if (err) { + dsa_switch_release_ports(ds); + dsa_tree_put(dst); + } + + return err; } -EXPORT_SYMBOL_GPL(dsa_devlink_resources_unregister); -void dsa_devlink_resource_occ_get_register(struct dsa_switch *ds, - u64 resource_id, - devlink_resource_occ_get_t *occ_get, - void *occ_get_priv) +int dsa_register_switch(struct dsa_switch *ds) { - return devlink_resource_occ_get_register(ds->devlink, resource_id, - occ_get, occ_get_priv); + int err; + + mutex_lock(&dsa2_mutex); + err = dsa_switch_probe(ds); + dsa_tree_put(ds->dst); + mutex_unlock(&dsa2_mutex); + + return err; } -EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_register); +EXPORT_SYMBOL_GPL(dsa_register_switch); -void dsa_devlink_resource_occ_get_unregister(struct dsa_switch *ds, - u64 resource_id) +static void dsa_switch_remove(struct dsa_switch *ds) { - devlink_resource_occ_get_unregister(ds->devlink, resource_id); + struct dsa_switch_tree *dst = ds->dst; + + dsa_tree_teardown(dst); + dsa_switch_release_ports(ds); + dsa_tree_put(dst); } -EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_unregister); -struct devlink_region * -dsa_devlink_region_create(struct dsa_switch *ds, - const struct devlink_region_ops *ops, - u32 region_max_snapshots, u64 region_size) +void dsa_unregister_switch(struct dsa_switch *ds) { - return devlink_region_create(ds->devlink, ops, region_max_snapshots, - region_size); + mutex_lock(&dsa2_mutex); + dsa_switch_remove(ds); + mutex_unlock(&dsa2_mutex); } -EXPORT_SYMBOL_GPL(dsa_devlink_region_create); +EXPORT_SYMBOL_GPL(dsa_unregister_switch); -struct devlink_region * -dsa_devlink_port_region_create(struct dsa_switch *ds, - int port, - const struct devlink_port_region_ops *ops, - u32 region_max_snapshots, u64 region_size) +/* If the DSA master chooses to unregister its net_device on .shutdown, DSA is + * blocking that operation from completion, due to the dev_hold taken inside + * netdev_upper_dev_link. Unlink the DSA slave interfaces from being uppers of + * the DSA master, so that the system can reboot successfully. + */ +void dsa_switch_shutdown(struct dsa_switch *ds) { - struct dsa_port *dp = dsa_to_port(ds, port); + struct net_device *master, *slave_dev; + struct dsa_port *dp; + + mutex_lock(&dsa2_mutex); + + if (!ds->setup) + goto out; + + rtnl_lock(); + + dsa_switch_for_each_user_port(dp, ds) { + master = dsa_port_to_master(dp); + slave_dev = dp->slave; + + netdev_upper_dev_unlink(master, slave_dev); + } + + /* Disconnect from further netdevice notifiers on the master, + * since netdev_uses_dsa() will now return false. + */ + dsa_switch_for_each_cpu_port(dp, ds) + dp->master->dsa_ptr = NULL; + + rtnl_unlock(); +out: + mutex_unlock(&dsa2_mutex); +} +EXPORT_SYMBOL_GPL(dsa_switch_shutdown); + +#ifdef CONFIG_PM_SLEEP +static bool dsa_port_is_initialized(const struct dsa_port *dp) +{ + return dp->type == DSA_PORT_TYPE_USER && dp->slave; +} + +int dsa_switch_suspend(struct dsa_switch *ds) +{ + struct dsa_port *dp; + int ret = 0; + + /* Suspend slave network devices */ + dsa_switch_for_each_port(dp, ds) { + if (!dsa_port_is_initialized(dp)) + continue; + + ret = dsa_slave_suspend(dp->slave); + if (ret) + return ret; + } + + if (ds->ops->suspend) + ret = ds->ops->suspend(ds); - return devlink_port_region_create(&dp->devlink_port, ops, - region_max_snapshots, - region_size); + return ret; } -EXPORT_SYMBOL_GPL(dsa_devlink_port_region_create); +EXPORT_SYMBOL_GPL(dsa_switch_suspend); -void dsa_devlink_region_destroy(struct devlink_region *region) +int dsa_switch_resume(struct dsa_switch *ds) { - devlink_region_destroy(region); + struct dsa_port *dp; + int ret = 0; + + if (ds->ops->resume) + ret = ds->ops->resume(ds); + + if (ret) + return ret; + + /* Resume slave network devices */ + dsa_switch_for_each_port(dp, ds) { + if (!dsa_port_is_initialized(dp)) + continue; + + ret = dsa_slave_resume(dp->slave); + if (ret) + return ret; + } + + return 0; } -EXPORT_SYMBOL_GPL(dsa_devlink_region_destroy); +EXPORT_SYMBOL_GPL(dsa_switch_resume); +#endif struct dsa_port *dsa_port_from_netdev(struct net_device *netdev) { @@ -533,9 +1717,6 @@ static int __init dsa_init_module(void) dev_add_pack(&dsa_pack_type); - dsa_tag_driver_register(&DSA_TAG_DRIVER_NAME(none_ops), - THIS_MODULE); - rc = rtnl_link_register(&dsa_link_ops); if (rc) goto netlink_register_fail; @@ -543,7 +1724,6 @@ static int __init dsa_init_module(void) return 0; netlink_register_fail: - dsa_tag_driver_unregister(&DSA_TAG_DRIVER_NAME(none_ops)); dsa_slave_unregister_notifier(); dev_remove_pack(&dsa_pack_type); register_notifier_fail: @@ -556,7 +1736,6 @@ module_init(dsa_init_module); static void __exit dsa_cleanup_module(void) { rtnl_link_unregister(&dsa_link_ops); - dsa_tag_driver_unregister(&DSA_TAG_DRIVER_NAME(none_ops)); dsa_slave_unregister_notifier(); dev_remove_pack(&dsa_pack_type); diff --git a/net/dsa/dsa.h b/net/dsa/dsa.h new file mode 100644 index 000000000000..b7e17ae1094d --- /dev/null +++ b/net/dsa/dsa.h @@ -0,0 +1,40 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_H +#define __DSA_H + +#include <linux/list.h> +#include <linux/types.h> + +struct dsa_db; +struct dsa_device_ops; +struct dsa_lag; +struct dsa_switch_tree; +struct net_device; +struct work_struct; + +extern struct list_head dsa_tree_list; + +bool dsa_db_equal(const struct dsa_db *a, const struct dsa_db *b); +bool dsa_schedule_work(struct work_struct *work); +void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag); +void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag); +struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, + const struct net_device *lag_dev); +struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst); +int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops, + const struct dsa_device_ops *old_tag_ops); +void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up); +void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up); +unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max); +void dsa_bridge_num_put(const struct net_device *bridge_dev, + unsigned int bridge_num); +struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, + const struct net_device *br); + +#endif diff --git a/net/dsa/dsa2.c b/net/dsa/dsa2.c deleted file mode 100644 index e504a18fc125..000000000000 --- a/net/dsa/dsa2.c +++ /dev/null @@ -1,1819 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0-or-later -/* - * net/dsa/dsa2.c - Hardware switch handling, binding version 2 - * Copyright (c) 2008-2009 Marvell Semiconductor - * Copyright (c) 2013 Florian Fainelli <florian@openwrt.org> - * Copyright (c) 2016 Andrew Lunn <andrew@lunn.ch> - */ - -#include <linux/device.h> -#include <linux/err.h> -#include <linux/list.h> -#include <linux/netdevice.h> -#include <linux/slab.h> -#include <linux/rtnetlink.h> -#include <linux/of.h> -#include <linux/of_mdio.h> -#include <linux/of_net.h> -#include <net/devlink.h> -#include <net/sch_generic.h> - -#include "dsa_priv.h" - -static DEFINE_MUTEX(dsa2_mutex); -LIST_HEAD(dsa_tree_list); - -/* Track the bridges with forwarding offload enabled */ -static unsigned long dsa_fwd_offloading_bridges; - -/** - * dsa_tree_notify - Execute code for all switches in a DSA switch tree. - * @dst: collection of struct dsa_switch devices to notify. - * @e: event, must be of type DSA_NOTIFIER_* - * @v: event-specific value. - * - * Given a struct dsa_switch_tree, this can be used to run a function once for - * each member DSA switch. The other alternative of traversing the tree is only - * through its ports list, which does not uniquely list the switches. - */ -int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v) -{ - struct raw_notifier_head *nh = &dst->nh; - int err; - - err = raw_notifier_call_chain(nh, e, v); - - return notifier_to_errno(err); -} - -/** - * dsa_broadcast - Notify all DSA trees in the system. - * @e: event, must be of type DSA_NOTIFIER_* - * @v: event-specific value. - * - * Can be used to notify the switching fabric of events such as cross-chip - * bridging between disjoint trees (such as islands of tagger-compatible - * switches bridged by an incompatible middle switch). - * - * WARNING: this function is not reliable during probe time, because probing - * between trees is asynchronous and not all DSA trees might have probed. - */ -int dsa_broadcast(unsigned long e, void *v) -{ - struct dsa_switch_tree *dst; - int err = 0; - - list_for_each_entry(dst, &dsa_tree_list, list) { - err = dsa_tree_notify(dst, e, v); - if (err) - break; - } - - return err; -} - -/** - * dsa_lag_map() - Map LAG structure to a linear LAG array - * @dst: Tree in which to record the mapping. - * @lag: LAG structure that is to be mapped to the tree's array. - * - * dsa_lag_id/dsa_lag_by_id can then be used to translate between the - * two spaces. The size of the mapping space is determined by the - * driver by setting ds->num_lag_ids. It is perfectly legal to leave - * it unset if it is not needed, in which case these functions become - * no-ops. - */ -void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag) -{ - unsigned int id; - - for (id = 1; id <= dst->lags_len; id++) { - if (!dsa_lag_by_id(dst, id)) { - dst->lags[id - 1] = lag; - lag->id = id; - return; - } - } - - /* No IDs left, which is OK. Some drivers do not need it. The - * ones that do, e.g. mv88e6xxx, will discover that dsa_lag_id - * returns an error for this device when joining the LAG. The - * driver can then return -EOPNOTSUPP back to DSA, which will - * fall back to a software LAG. - */ -} - -/** - * dsa_lag_unmap() - Remove a LAG ID mapping - * @dst: Tree in which the mapping is recorded. - * @lag: LAG structure that was mapped. - * - * As there may be multiple users of the mapping, it is only removed - * if there are no other references to it. - */ -void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag) -{ - unsigned int id; - - dsa_lags_foreach_id(id, dst) { - if (dsa_lag_by_id(dst, id) == lag) { - dst->lags[id - 1] = NULL; - lag->id = 0; - break; - } - } -} - -struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, - const struct net_device *lag_dev) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_lag_dev_get(dp) == lag_dev) - return dp->lag; - - return NULL; -} - -struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, - const struct net_device *br) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_bridge_dev_get(dp) == br) - return dp->bridge; - - return NULL; -} - -static int dsa_bridge_num_find(const struct net_device *bridge_dev) -{ - struct dsa_switch_tree *dst; - - list_for_each_entry(dst, &dsa_tree_list, list) { - struct dsa_bridge *bridge; - - bridge = dsa_tree_bridge_find(dst, bridge_dev); - if (bridge) - return bridge->num; - } - - return 0; -} - -unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max) -{ - unsigned int bridge_num = dsa_bridge_num_find(bridge_dev); - - /* Switches without FDB isolation support don't get unique - * bridge numbering - */ - if (!max) - return 0; - - if (!bridge_num) { - /* First port that requests FDB isolation or TX forwarding - * offload for this bridge - */ - bridge_num = find_next_zero_bit(&dsa_fwd_offloading_bridges, - DSA_MAX_NUM_OFFLOADING_BRIDGES, - 1); - if (bridge_num >= max) - return 0; - - set_bit(bridge_num, &dsa_fwd_offloading_bridges); - } - - return bridge_num; -} - -void dsa_bridge_num_put(const struct net_device *bridge_dev, - unsigned int bridge_num) -{ - /* Since we refcount bridges, we know that when we call this function - * it is no longer in use, so we can just go ahead and remove it from - * the bit mask. - */ - clear_bit(bridge_num, &dsa_fwd_offloading_bridges); -} - -struct dsa_switch *dsa_switch_find(int tree_index, int sw_index) -{ - struct dsa_switch_tree *dst; - struct dsa_port *dp; - - list_for_each_entry(dst, &dsa_tree_list, list) { - if (dst->index != tree_index) - continue; - - list_for_each_entry(dp, &dst->ports, list) { - if (dp->ds->index != sw_index) - continue; - - return dp->ds; - } - } - - return NULL; -} -EXPORT_SYMBOL_GPL(dsa_switch_find); - -static struct dsa_switch_tree *dsa_tree_find(int index) -{ - struct dsa_switch_tree *dst; - - list_for_each_entry(dst, &dsa_tree_list, list) - if (dst->index == index) - return dst; - - return NULL; -} - -static struct dsa_switch_tree *dsa_tree_alloc(int index) -{ - struct dsa_switch_tree *dst; - - dst = kzalloc(sizeof(*dst), GFP_KERNEL); - if (!dst) - return NULL; - - dst->index = index; - - INIT_LIST_HEAD(&dst->rtable); - - INIT_LIST_HEAD(&dst->ports); - - INIT_LIST_HEAD(&dst->list); - list_add_tail(&dst->list, &dsa_tree_list); - - kref_init(&dst->refcount); - - return dst; -} - -static void dsa_tree_free(struct dsa_switch_tree *dst) -{ - if (dst->tag_ops) - dsa_tag_driver_put(dst->tag_ops); - list_del(&dst->list); - kfree(dst); -} - -static struct dsa_switch_tree *dsa_tree_get(struct dsa_switch_tree *dst) -{ - if (dst) - kref_get(&dst->refcount); - - return dst; -} - -static struct dsa_switch_tree *dsa_tree_touch(int index) -{ - struct dsa_switch_tree *dst; - - dst = dsa_tree_find(index); - if (dst) - return dsa_tree_get(dst); - else - return dsa_tree_alloc(index); -} - -static void dsa_tree_release(struct kref *ref) -{ - struct dsa_switch_tree *dst; - - dst = container_of(ref, struct dsa_switch_tree, refcount); - - dsa_tree_free(dst); -} - -static void dsa_tree_put(struct dsa_switch_tree *dst) -{ - if (dst) - kref_put(&dst->refcount, dsa_tree_release); -} - -static struct dsa_port *dsa_tree_find_port_by_node(struct dsa_switch_tree *dst, - struct device_node *dn) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dp->dn == dn) - return dp; - - return NULL; -} - -static struct dsa_link *dsa_link_touch(struct dsa_port *dp, - struct dsa_port *link_dp) -{ - struct dsa_switch *ds = dp->ds; - struct dsa_switch_tree *dst; - struct dsa_link *dl; - - dst = ds->dst; - - list_for_each_entry(dl, &dst->rtable, list) - if (dl->dp == dp && dl->link_dp == link_dp) - return dl; - - dl = kzalloc(sizeof(*dl), GFP_KERNEL); - if (!dl) - return NULL; - - dl->dp = dp; - dl->link_dp = link_dp; - - INIT_LIST_HEAD(&dl->list); - list_add_tail(&dl->list, &dst->rtable); - - return dl; -} - -static bool dsa_port_setup_routing_table(struct dsa_port *dp) -{ - struct dsa_switch *ds = dp->ds; - struct dsa_switch_tree *dst = ds->dst; - struct device_node *dn = dp->dn; - struct of_phandle_iterator it; - struct dsa_port *link_dp; - struct dsa_link *dl; - int err; - - of_for_each_phandle(&it, err, dn, "link", NULL, 0) { - link_dp = dsa_tree_find_port_by_node(dst, it.node); - if (!link_dp) { - of_node_put(it.node); - return false; - } - - dl = dsa_link_touch(dp, link_dp); - if (!dl) { - of_node_put(it.node); - return false; - } - } - - return true; -} - -static bool dsa_tree_setup_routing_table(struct dsa_switch_tree *dst) -{ - bool complete = true; - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) { - if (dsa_port_is_dsa(dp)) { - complete = dsa_port_setup_routing_table(dp); - if (!complete) - break; - } - } - - return complete; -} - -static struct dsa_port *dsa_tree_find_first_cpu(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_is_cpu(dp)) - return dp; - - return NULL; -} - -struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst) -{ - struct device_node *ethernet; - struct net_device *master; - struct dsa_port *cpu_dp; - - cpu_dp = dsa_tree_find_first_cpu(dst); - ethernet = of_parse_phandle(cpu_dp->dn, "ethernet", 0); - master = of_find_net_device_by_node(ethernet); - of_node_put(ethernet); - - return master; -} - -/* Assign the default CPU port (the first one in the tree) to all ports of the - * fabric which don't already have one as part of their own switch. - */ -static int dsa_tree_setup_default_cpu(struct dsa_switch_tree *dst) -{ - struct dsa_port *cpu_dp, *dp; - - cpu_dp = dsa_tree_find_first_cpu(dst); - if (!cpu_dp) { - pr_err("DSA: tree %d has no CPU port\n", dst->index); - return -EINVAL; - } - - list_for_each_entry(dp, &dst->ports, list) { - if (dp->cpu_dp) - continue; - - if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) - dp->cpu_dp = cpu_dp; - } - - return 0; -} - -/* Perform initial assignment of CPU ports to user ports and DSA links in the - * fabric, giving preference to CPU ports local to each switch. Default to - * using the first CPU port in the switch tree if the port does not have a CPU - * port local to this switch. - */ -static int dsa_tree_setup_cpu_ports(struct dsa_switch_tree *dst) -{ - struct dsa_port *cpu_dp, *dp; - - list_for_each_entry(cpu_dp, &dst->ports, list) { - if (!dsa_port_is_cpu(cpu_dp)) - continue; - - /* Prefer a local CPU port */ - dsa_switch_for_each_port(dp, cpu_dp->ds) { - /* Prefer the first local CPU port found */ - if (dp->cpu_dp) - continue; - - if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) - dp->cpu_dp = cpu_dp; - } - } - - return dsa_tree_setup_default_cpu(dst); -} - -static void dsa_tree_teardown_cpu_ports(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) - dp->cpu_dp = NULL; -} - -static int dsa_port_devlink_setup(struct dsa_port *dp) -{ - struct devlink_port *dlp = &dp->devlink_port; - struct dsa_switch_tree *dst = dp->ds->dst; - struct devlink_port_attrs attrs = {}; - struct devlink *dl = dp->ds->devlink; - struct dsa_switch *ds = dp->ds; - const unsigned char *id; - unsigned char len; - int err; - - memset(dlp, 0, sizeof(*dlp)); - devlink_port_init(dl, dlp); - - if (ds->ops->port_setup) { - err = ds->ops->port_setup(ds, dp->index); - if (err) - return err; - } - - id = (const unsigned char *)&dst->index; - len = sizeof(dst->index); - - attrs.phys.port_number = dp->index; - memcpy(attrs.switch_id.id, id, len); - attrs.switch_id.id_len = len; - - switch (dp->type) { - case DSA_PORT_TYPE_UNUSED: - attrs.flavour = DEVLINK_PORT_FLAVOUR_UNUSED; - break; - case DSA_PORT_TYPE_CPU: - attrs.flavour = DEVLINK_PORT_FLAVOUR_CPU; - break; - case DSA_PORT_TYPE_DSA: - attrs.flavour = DEVLINK_PORT_FLAVOUR_DSA; - break; - case DSA_PORT_TYPE_USER: - attrs.flavour = DEVLINK_PORT_FLAVOUR_PHYSICAL; - break; - } - - devlink_port_attrs_set(dlp, &attrs); - err = devlink_port_register(dl, dlp, dp->index); - if (err) { - if (ds->ops->port_teardown) - ds->ops->port_teardown(ds, dp->index); - return err; - } - - return 0; -} - -static void dsa_port_devlink_teardown(struct dsa_port *dp) -{ - struct devlink_port *dlp = &dp->devlink_port; - struct dsa_switch *ds = dp->ds; - - devlink_port_unregister(dlp); - - if (ds->ops->port_teardown) - ds->ops->port_teardown(ds, dp->index); - - devlink_port_fini(dlp); -} - -static int dsa_port_setup(struct dsa_port *dp) -{ - struct devlink_port *dlp = &dp->devlink_port; - bool dsa_port_link_registered = false; - struct dsa_switch *ds = dp->ds; - bool dsa_port_enabled = false; - int err = 0; - - if (dp->setup) - return 0; - - err = dsa_port_devlink_setup(dp); - if (err) - return err; - - switch (dp->type) { - case DSA_PORT_TYPE_UNUSED: - dsa_port_disable(dp); - break; - case DSA_PORT_TYPE_CPU: - if (dp->dn) { - err = dsa_shared_port_link_register_of(dp); - if (err) - break; - dsa_port_link_registered = true; - } else { - dev_warn(ds->dev, - "skipping link registration for CPU port %d\n", - dp->index); - } - - err = dsa_port_enable(dp, NULL); - if (err) - break; - dsa_port_enabled = true; - - break; - case DSA_PORT_TYPE_DSA: - if (dp->dn) { - err = dsa_shared_port_link_register_of(dp); - if (err) - break; - dsa_port_link_registered = true; - } else { - dev_warn(ds->dev, - "skipping link registration for DSA port %d\n", - dp->index); - } - - err = dsa_port_enable(dp, NULL); - if (err) - break; - dsa_port_enabled = true; - - break; - case DSA_PORT_TYPE_USER: - of_get_mac_address(dp->dn, dp->mac); - err = dsa_slave_create(dp); - if (err) - break; - - devlink_port_type_eth_set(dlp, dp->slave); - break; - } - - if (err && dsa_port_enabled) - dsa_port_disable(dp); - if (err && dsa_port_link_registered) - dsa_shared_port_link_unregister_of(dp); - if (err) { - dsa_port_devlink_teardown(dp); - return err; - } - - dp->setup = true; - - return 0; -} - -static void dsa_port_teardown(struct dsa_port *dp) -{ - struct devlink_port *dlp = &dp->devlink_port; - - if (!dp->setup) - return; - - devlink_port_type_clear(dlp); - - switch (dp->type) { - case DSA_PORT_TYPE_UNUSED: - break; - case DSA_PORT_TYPE_CPU: - dsa_port_disable(dp); - if (dp->dn) - dsa_shared_port_link_unregister_of(dp); - break; - case DSA_PORT_TYPE_DSA: - dsa_port_disable(dp); - if (dp->dn) - dsa_shared_port_link_unregister_of(dp); - break; - case DSA_PORT_TYPE_USER: - if (dp->slave) { - dsa_slave_destroy(dp->slave); - dp->slave = NULL; - } - break; - } - - dsa_port_devlink_teardown(dp); - - dp->setup = false; -} - -static int dsa_port_setup_as_unused(struct dsa_port *dp) -{ - dp->type = DSA_PORT_TYPE_UNUSED; - return dsa_port_setup(dp); -} - -static int dsa_devlink_info_get(struct devlink *dl, - struct devlink_info_req *req, - struct netlink_ext_ack *extack) -{ - struct dsa_switch *ds = dsa_devlink_to_ds(dl); - - if (ds->ops->devlink_info_get) - return ds->ops->devlink_info_get(ds, req, extack); - - return -EOPNOTSUPP; -} - -static int dsa_devlink_sb_pool_get(struct devlink *dl, - unsigned int sb_index, u16 pool_index, - struct devlink_sb_pool_info *pool_info) -{ - struct dsa_switch *ds = dsa_devlink_to_ds(dl); - - if (!ds->ops->devlink_sb_pool_get) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_pool_get(ds, sb_index, pool_index, - pool_info); -} - -static int dsa_devlink_sb_pool_set(struct devlink *dl, unsigned int sb_index, - u16 pool_index, u32 size, - enum devlink_sb_threshold_type threshold_type, - struct netlink_ext_ack *extack) -{ - struct dsa_switch *ds = dsa_devlink_to_ds(dl); - - if (!ds->ops->devlink_sb_pool_set) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_pool_set(ds, sb_index, pool_index, size, - threshold_type, extack); -} - -static int dsa_devlink_sb_port_pool_get(struct devlink_port *dlp, - unsigned int sb_index, u16 pool_index, - u32 *p_threshold) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_port_pool_get) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_port_pool_get(ds, port, sb_index, - pool_index, p_threshold); -} - -static int dsa_devlink_sb_port_pool_set(struct devlink_port *dlp, - unsigned int sb_index, u16 pool_index, - u32 threshold, - struct netlink_ext_ack *extack) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_port_pool_set) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_port_pool_set(ds, port, sb_index, - pool_index, threshold, extack); -} - -static int -dsa_devlink_sb_tc_pool_bind_get(struct devlink_port *dlp, - unsigned int sb_index, u16 tc_index, - enum devlink_sb_pool_type pool_type, - u16 *p_pool_index, u32 *p_threshold) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_tc_pool_bind_get) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_tc_pool_bind_get(ds, port, sb_index, - tc_index, pool_type, - p_pool_index, p_threshold); -} - -static int -dsa_devlink_sb_tc_pool_bind_set(struct devlink_port *dlp, - unsigned int sb_index, u16 tc_index, - enum devlink_sb_pool_type pool_type, - u16 pool_index, u32 threshold, - struct netlink_ext_ack *extack) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_tc_pool_bind_set) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_tc_pool_bind_set(ds, port, sb_index, - tc_index, pool_type, - pool_index, threshold, - extack); -} - -static int dsa_devlink_sb_occ_snapshot(struct devlink *dl, - unsigned int sb_index) -{ - struct dsa_switch *ds = dsa_devlink_to_ds(dl); - - if (!ds->ops->devlink_sb_occ_snapshot) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_occ_snapshot(ds, sb_index); -} - -static int dsa_devlink_sb_occ_max_clear(struct devlink *dl, - unsigned int sb_index) -{ - struct dsa_switch *ds = dsa_devlink_to_ds(dl); - - if (!ds->ops->devlink_sb_occ_max_clear) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_occ_max_clear(ds, sb_index); -} - -static int dsa_devlink_sb_occ_port_pool_get(struct devlink_port *dlp, - unsigned int sb_index, - u16 pool_index, u32 *p_cur, - u32 *p_max) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_occ_port_pool_get) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_occ_port_pool_get(ds, port, sb_index, - pool_index, p_cur, p_max); -} - -static int -dsa_devlink_sb_occ_tc_port_bind_get(struct devlink_port *dlp, - unsigned int sb_index, u16 tc_index, - enum devlink_sb_pool_type pool_type, - u32 *p_cur, u32 *p_max) -{ - struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); - int port = dsa_devlink_port_to_port(dlp); - - if (!ds->ops->devlink_sb_occ_tc_port_bind_get) - return -EOPNOTSUPP; - - return ds->ops->devlink_sb_occ_tc_port_bind_get(ds, port, - sb_index, tc_index, - pool_type, p_cur, - p_max); -} - -static const struct devlink_ops dsa_devlink_ops = { - .info_get = dsa_devlink_info_get, - .sb_pool_get = dsa_devlink_sb_pool_get, - .sb_pool_set = dsa_devlink_sb_pool_set, - .sb_port_pool_get = dsa_devlink_sb_port_pool_get, - .sb_port_pool_set = dsa_devlink_sb_port_pool_set, - .sb_tc_pool_bind_get = dsa_devlink_sb_tc_pool_bind_get, - .sb_tc_pool_bind_set = dsa_devlink_sb_tc_pool_bind_set, - .sb_occ_snapshot = dsa_devlink_sb_occ_snapshot, - .sb_occ_max_clear = dsa_devlink_sb_occ_max_clear, - .sb_occ_port_pool_get = dsa_devlink_sb_occ_port_pool_get, - .sb_occ_tc_port_bind_get = dsa_devlink_sb_occ_tc_port_bind_get, -}; - -static int dsa_switch_setup_tag_protocol(struct dsa_switch *ds) -{ - const struct dsa_device_ops *tag_ops = ds->dst->tag_ops; - struct dsa_switch_tree *dst = ds->dst; - int err; - - if (tag_ops->proto == dst->default_proto) - goto connect; - - rtnl_lock(); - err = ds->ops->change_tag_protocol(ds, tag_ops->proto); - rtnl_unlock(); - if (err) { - dev_err(ds->dev, "Unable to use tag protocol \"%s\": %pe\n", - tag_ops->name, ERR_PTR(err)); - return err; - } - -connect: - if (tag_ops->connect) { - err = tag_ops->connect(ds); - if (err) - return err; - } - - if (ds->ops->connect_tag_protocol) { - err = ds->ops->connect_tag_protocol(ds, tag_ops->proto); - if (err) { - dev_err(ds->dev, - "Unable to connect to tag protocol \"%s\": %pe\n", - tag_ops->name, ERR_PTR(err)); - goto disconnect; - } - } - - return 0; - -disconnect: - if (tag_ops->disconnect) - tag_ops->disconnect(ds); - - return err; -} - -static int dsa_switch_setup(struct dsa_switch *ds) -{ - struct dsa_devlink_priv *dl_priv; - struct device_node *dn; - int err; - - if (ds->setup) - return 0; - - /* Initialize ds->phys_mii_mask before registering the slave MDIO bus - * driver and before ops->setup() has run, since the switch drivers and - * the slave MDIO bus driver rely on these values for probing PHY - * devices or not - */ - ds->phys_mii_mask |= dsa_user_ports(ds); - - /* Add the switch to devlink before calling setup, so that setup can - * add dpipe tables - */ - ds->devlink = - devlink_alloc(&dsa_devlink_ops, sizeof(*dl_priv), ds->dev); - if (!ds->devlink) - return -ENOMEM; - dl_priv = devlink_priv(ds->devlink); - dl_priv->ds = ds; - - err = dsa_switch_register_notifier(ds); - if (err) - goto devlink_free; - - ds->configure_vlan_while_not_filtering = true; - - err = ds->ops->setup(ds); - if (err < 0) - goto unregister_notifier; - - err = dsa_switch_setup_tag_protocol(ds); - if (err) - goto teardown; - - if (!ds->slave_mii_bus && ds->ops->phy_read) { - ds->slave_mii_bus = mdiobus_alloc(); - if (!ds->slave_mii_bus) { - err = -ENOMEM; - goto teardown; - } - - dsa_slave_mii_bus_init(ds); - - dn = of_get_child_by_name(ds->dev->of_node, "mdio"); - - err = of_mdiobus_register(ds->slave_mii_bus, dn); - of_node_put(dn); - if (err < 0) - goto free_slave_mii_bus; - } - - ds->setup = true; - devlink_register(ds->devlink); - return 0; - -free_slave_mii_bus: - if (ds->slave_mii_bus && ds->ops->phy_read) - mdiobus_free(ds->slave_mii_bus); -teardown: - if (ds->ops->teardown) - ds->ops->teardown(ds); -unregister_notifier: - dsa_switch_unregister_notifier(ds); -devlink_free: - devlink_free(ds->devlink); - ds->devlink = NULL; - return err; -} - -static void dsa_switch_teardown(struct dsa_switch *ds) -{ - if (!ds->setup) - return; - - if (ds->devlink) - devlink_unregister(ds->devlink); - - if (ds->slave_mii_bus && ds->ops->phy_read) { - mdiobus_unregister(ds->slave_mii_bus); - mdiobus_free(ds->slave_mii_bus); - ds->slave_mii_bus = NULL; - } - - if (ds->ops->teardown) - ds->ops->teardown(ds); - - dsa_switch_unregister_notifier(ds); - - if (ds->devlink) { - devlink_free(ds->devlink); - ds->devlink = NULL; - } - - ds->setup = false; -} - -/* First tear down the non-shared, then the shared ports. This ensures that - * all work items scheduled by our switchdev handlers for user ports have - * completed before we destroy the refcounting kept on the shared ports. - */ -static void dsa_tree_teardown_ports(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) - dsa_port_teardown(dp); - - dsa_flush_workqueue(); - - list_for_each_entry(dp, &dst->ports, list) - if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) - dsa_port_teardown(dp); -} - -static void dsa_tree_teardown_switches(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - dsa_switch_teardown(dp->ds); -} - -/* Bring shared ports up first, then non-shared ports */ -static int dsa_tree_setup_ports(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - int err = 0; - - list_for_each_entry(dp, &dst->ports, list) { - if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) { - err = dsa_port_setup(dp); - if (err) - goto teardown; - } - } - - list_for_each_entry(dp, &dst->ports, list) { - if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) { - err = dsa_port_setup(dp); - if (err) { - err = dsa_port_setup_as_unused(dp); - if (err) - goto teardown; - } - } - } - - return 0; - -teardown: - dsa_tree_teardown_ports(dst); - - return err; -} - -static int dsa_tree_setup_switches(struct dsa_switch_tree *dst) -{ - struct dsa_port *dp; - int err = 0; - - list_for_each_entry(dp, &dst->ports, list) { - err = dsa_switch_setup(dp->ds); - if (err) { - dsa_tree_teardown_switches(dst); - break; - } - } - - return err; -} - -static int dsa_tree_setup_master(struct dsa_switch_tree *dst) -{ - struct dsa_port *cpu_dp; - int err = 0; - - rtnl_lock(); - - dsa_tree_for_each_cpu_port(cpu_dp, dst) { - struct net_device *master = cpu_dp->master; - bool admin_up = (master->flags & IFF_UP) && - !qdisc_tx_is_noop(master); - - err = dsa_master_setup(master, cpu_dp); - if (err) - break; - - /* Replay master state event */ - dsa_tree_master_admin_state_change(dst, master, admin_up); - dsa_tree_master_oper_state_change(dst, master, - netif_oper_up(master)); - } - - rtnl_unlock(); - - return err; -} - -static void dsa_tree_teardown_master(struct dsa_switch_tree *dst) -{ - struct dsa_port *cpu_dp; - - rtnl_lock(); - - dsa_tree_for_each_cpu_port(cpu_dp, dst) { - struct net_device *master = cpu_dp->master; - - /* Synthesizing an "admin down" state is sufficient for - * the switches to get a notification if the master is - * currently up and running. - */ - dsa_tree_master_admin_state_change(dst, master, false); - - dsa_master_teardown(master); - } - - rtnl_unlock(); -} - -static int dsa_tree_setup_lags(struct dsa_switch_tree *dst) -{ - unsigned int len = 0; - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) { - if (dp->ds->num_lag_ids > len) - len = dp->ds->num_lag_ids; - } - - if (!len) - return 0; - - dst->lags = kcalloc(len, sizeof(*dst->lags), GFP_KERNEL); - if (!dst->lags) - return -ENOMEM; - - dst->lags_len = len; - return 0; -} - -static void dsa_tree_teardown_lags(struct dsa_switch_tree *dst) -{ - kfree(dst->lags); -} - -static int dsa_tree_setup(struct dsa_switch_tree *dst) -{ - bool complete; - int err; - - if (dst->setup) { - pr_err("DSA: tree %d already setup! Disjoint trees?\n", - dst->index); - return -EEXIST; - } - - complete = dsa_tree_setup_routing_table(dst); - if (!complete) - return 0; - - err = dsa_tree_setup_cpu_ports(dst); - if (err) - return err; - - err = dsa_tree_setup_switches(dst); - if (err) - goto teardown_cpu_ports; - - err = dsa_tree_setup_ports(dst); - if (err) - goto teardown_switches; - - err = dsa_tree_setup_master(dst); - if (err) - goto teardown_ports; - - err = dsa_tree_setup_lags(dst); - if (err) - goto teardown_master; - - dst->setup = true; - - pr_info("DSA: tree %d setup\n", dst->index); - - return 0; - -teardown_master: - dsa_tree_teardown_master(dst); -teardown_ports: - dsa_tree_teardown_ports(dst); -teardown_switches: - dsa_tree_teardown_switches(dst); -teardown_cpu_ports: - dsa_tree_teardown_cpu_ports(dst); - - return err; -} - -static void dsa_tree_teardown(struct dsa_switch_tree *dst) -{ - struct dsa_link *dl, *next; - - if (!dst->setup) - return; - - dsa_tree_teardown_lags(dst); - - dsa_tree_teardown_master(dst); - - dsa_tree_teardown_ports(dst); - - dsa_tree_teardown_switches(dst); - - dsa_tree_teardown_cpu_ports(dst); - - list_for_each_entry_safe(dl, next, &dst->rtable, list) { - list_del(&dl->list); - kfree(dl); - } - - pr_info("DSA: tree %d torn down\n", dst->index); - - dst->setup = false; -} - -static int dsa_tree_bind_tag_proto(struct dsa_switch_tree *dst, - const struct dsa_device_ops *tag_ops) -{ - const struct dsa_device_ops *old_tag_ops = dst->tag_ops; - struct dsa_notifier_tag_proto_info info; - int err; - - dst->tag_ops = tag_ops; - - /* Notify the switches from this tree about the connection - * to the new tagger - */ - info.tag_ops = tag_ops; - err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_CONNECT, &info); - if (err && err != -EOPNOTSUPP) - goto out_disconnect; - - /* Notify the old tagger about the disconnection from this tree */ - info.tag_ops = old_tag_ops; - dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); - - return 0; - -out_disconnect: - info.tag_ops = tag_ops; - dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); - dst->tag_ops = old_tag_ops; - - return err; -} - -/* Since the dsa/tagging sysfs device attribute is per master, the assumption - * is that all DSA switches within a tree share the same tagger, otherwise - * they would have formed disjoint trees (different "dsa,member" values). - */ -int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, - const struct dsa_device_ops *tag_ops, - const struct dsa_device_ops *old_tag_ops) -{ - struct dsa_notifier_tag_proto_info info; - struct dsa_port *dp; - int err = -EBUSY; - - if (!rtnl_trylock()) - return restart_syscall(); - - /* At the moment we don't allow changing the tag protocol under - * traffic. The rtnl_mutex also happens to serialize concurrent - * attempts to change the tagging protocol. If we ever lift the IFF_UP - * restriction, there needs to be another mutex which serializes this. - */ - dsa_tree_for_each_user_port(dp, dst) { - if (dsa_port_to_master(dp)->flags & IFF_UP) - goto out_unlock; - - if (dp->slave->flags & IFF_UP) - goto out_unlock; - } - - /* Notify the tag protocol change */ - info.tag_ops = tag_ops; - err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); - if (err) - goto out_unwind_tagger; - - err = dsa_tree_bind_tag_proto(dst, tag_ops); - if (err) - goto out_unwind_tagger; - - rtnl_unlock(); - - return 0; - -out_unwind_tagger: - info.tag_ops = old_tag_ops; - dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); -out_unlock: - rtnl_unlock(); - return err; -} - -static void dsa_tree_master_state_change(struct dsa_switch_tree *dst, - struct net_device *master) -{ - struct dsa_notifier_master_state_info info; - struct dsa_port *cpu_dp = master->dsa_ptr; - - info.master = master; - info.operational = dsa_port_master_is_operational(cpu_dp); - - dsa_tree_notify(dst, DSA_NOTIFIER_MASTER_STATE_CHANGE, &info); -} - -void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, - struct net_device *master, - bool up) -{ - struct dsa_port *cpu_dp = master->dsa_ptr; - bool notify = false; - - /* Don't keep track of admin state on LAG DSA masters, - * but rather just of physical DSA masters - */ - if (netif_is_lag_master(master)) - return; - - if ((dsa_port_master_is_operational(cpu_dp)) != - (up && cpu_dp->master_oper_up)) - notify = true; - - cpu_dp->master_admin_up = up; - - if (notify) - dsa_tree_master_state_change(dst, master); -} - -void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, - struct net_device *master, - bool up) -{ - struct dsa_port *cpu_dp = master->dsa_ptr; - bool notify = false; - - /* Don't keep track of oper state on LAG DSA masters, - * but rather just of physical DSA masters - */ - if (netif_is_lag_master(master)) - return; - - if ((dsa_port_master_is_operational(cpu_dp)) != - (cpu_dp->master_admin_up && up)) - notify = true; - - cpu_dp->master_oper_up = up; - - if (notify) - dsa_tree_master_state_change(dst, master); -} - -static struct dsa_port *dsa_port_touch(struct dsa_switch *ds, int index) -{ - struct dsa_switch_tree *dst = ds->dst; - struct dsa_port *dp; - - dsa_switch_for_each_port(dp, ds) - if (dp->index == index) - return dp; - - dp = kzalloc(sizeof(*dp), GFP_KERNEL); - if (!dp) - return NULL; - - dp->ds = ds; - dp->index = index; - - mutex_init(&dp->addr_lists_lock); - mutex_init(&dp->vlans_lock); - INIT_LIST_HEAD(&dp->fdbs); - INIT_LIST_HEAD(&dp->mdbs); - INIT_LIST_HEAD(&dp->vlans); - INIT_LIST_HEAD(&dp->list); - list_add_tail(&dp->list, &dst->ports); - - return dp; -} - -static int dsa_port_parse_user(struct dsa_port *dp, const char *name) -{ - if (!name) - name = "eth%d"; - - dp->type = DSA_PORT_TYPE_USER; - dp->name = name; - - return 0; -} - -static int dsa_port_parse_dsa(struct dsa_port *dp) -{ - dp->type = DSA_PORT_TYPE_DSA; - - return 0; -} - -static enum dsa_tag_protocol dsa_get_tag_protocol(struct dsa_port *dp, - struct net_device *master) -{ - enum dsa_tag_protocol tag_protocol = DSA_TAG_PROTO_NONE; - struct dsa_switch *mds, *ds = dp->ds; - unsigned int mdp_upstream; - struct dsa_port *mdp; - - /* It is possible to stack DSA switches onto one another when that - * happens the switch driver may want to know if its tagging protocol - * is going to work in such a configuration. - */ - if (dsa_slave_dev_check(master)) { - mdp = dsa_slave_to_port(master); - mds = mdp->ds; - mdp_upstream = dsa_upstream_port(mds, mdp->index); - tag_protocol = mds->ops->get_tag_protocol(mds, mdp_upstream, - DSA_TAG_PROTO_NONE); - } - - /* If the master device is not itself a DSA slave in a disjoint DSA - * tree, then return immediately. - */ - return ds->ops->get_tag_protocol(ds, dp->index, tag_protocol); -} - -static int dsa_port_parse_cpu(struct dsa_port *dp, struct net_device *master, - const char *user_protocol) -{ - const struct dsa_device_ops *tag_ops = NULL; - struct dsa_switch *ds = dp->ds; - struct dsa_switch_tree *dst = ds->dst; - enum dsa_tag_protocol default_proto; - - /* Find out which protocol the switch would prefer. */ - default_proto = dsa_get_tag_protocol(dp, master); - if (dst->default_proto) { - if (dst->default_proto != default_proto) { - dev_err(ds->dev, - "A DSA switch tree can have only one tagging protocol\n"); - return -EINVAL; - } - } else { - dst->default_proto = default_proto; - } - - /* See if the user wants to override that preference. */ - if (user_protocol) { - if (!ds->ops->change_tag_protocol) { - dev_err(ds->dev, "Tag protocol cannot be modified\n"); - return -EINVAL; - } - - tag_ops = dsa_find_tagger_by_name(user_protocol); - if (IS_ERR(tag_ops)) { - dev_warn(ds->dev, - "Failed to find a tagging driver for protocol %s, using default\n", - user_protocol); - tag_ops = NULL; - } - } - - if (!tag_ops) - tag_ops = dsa_tag_driver_get(default_proto); - - if (IS_ERR(tag_ops)) { - if (PTR_ERR(tag_ops) == -ENOPROTOOPT) - return -EPROBE_DEFER; - - dev_warn(ds->dev, "No tagger for this switch\n"); - return PTR_ERR(tag_ops); - } - - if (dst->tag_ops) { - if (dst->tag_ops != tag_ops) { - dev_err(ds->dev, - "A DSA switch tree can have only one tagging protocol\n"); - - dsa_tag_driver_put(tag_ops); - return -EINVAL; - } - - /* In the case of multiple CPU ports per switch, the tagging - * protocol is still reference-counted only per switch tree. - */ - dsa_tag_driver_put(tag_ops); - } else { - dst->tag_ops = tag_ops; - } - - dp->master = master; - dp->type = DSA_PORT_TYPE_CPU; - dsa_port_set_tag_protocol(dp, dst->tag_ops); - dp->dst = dst; - - /* At this point, the tree may be configured to use a different - * tagger than the one chosen by the switch driver during - * .setup, in the case when a user selects a custom protocol - * through the DT. - * - * This is resolved by syncing the driver with the tree in - * dsa_switch_setup_tag_protocol once .setup has run and the - * driver is ready to accept calls to .change_tag_protocol. If - * the driver does not support the custom protocol at that - * point, the tree is wholly rejected, thereby ensuring that the - * tree and driver are always in agreement on the protocol to - * use. - */ - return 0; -} - -static int dsa_port_parse_of(struct dsa_port *dp, struct device_node *dn) -{ - struct device_node *ethernet = of_parse_phandle(dn, "ethernet", 0); - const char *name = of_get_property(dn, "label", NULL); - bool link = of_property_read_bool(dn, "link"); - - dp->dn = dn; - - if (ethernet) { - struct net_device *master; - const char *user_protocol; - - master = of_find_net_device_by_node(ethernet); - of_node_put(ethernet); - if (!master) - return -EPROBE_DEFER; - - user_protocol = of_get_property(dn, "dsa-tag-protocol", NULL); - return dsa_port_parse_cpu(dp, master, user_protocol); - } - - if (link) - return dsa_port_parse_dsa(dp); - - return dsa_port_parse_user(dp, name); -} - -static int dsa_switch_parse_ports_of(struct dsa_switch *ds, - struct device_node *dn) -{ - struct device_node *ports, *port; - struct dsa_port *dp; - int err = 0; - u32 reg; - - ports = of_get_child_by_name(dn, "ports"); - if (!ports) { - /* The second possibility is "ethernet-ports" */ - ports = of_get_child_by_name(dn, "ethernet-ports"); - if (!ports) { - dev_err(ds->dev, "no ports child node found\n"); - return -EINVAL; - } - } - - for_each_available_child_of_node(ports, port) { - err = of_property_read_u32(port, "reg", ®); - if (err) { - of_node_put(port); - goto out_put_node; - } - - if (reg >= ds->num_ports) { - dev_err(ds->dev, "port %pOF index %u exceeds num_ports (%u)\n", - port, reg, ds->num_ports); - of_node_put(port); - err = -EINVAL; - goto out_put_node; - } - - dp = dsa_to_port(ds, reg); - - err = dsa_port_parse_of(dp, port); - if (err) { - of_node_put(port); - goto out_put_node; - } - } - -out_put_node: - of_node_put(ports); - return err; -} - -static int dsa_switch_parse_member_of(struct dsa_switch *ds, - struct device_node *dn) -{ - u32 m[2] = { 0, 0 }; - int sz; - - /* Don't error out if this optional property isn't found */ - sz = of_property_read_variable_u32_array(dn, "dsa,member", m, 2, 2); - if (sz < 0 && sz != -EINVAL) - return sz; - - ds->index = m[1]; - - ds->dst = dsa_tree_touch(m[0]); - if (!ds->dst) - return -ENOMEM; - - if (dsa_switch_find(ds->dst->index, ds->index)) { - dev_err(ds->dev, - "A DSA switch with index %d already exists in tree %d\n", - ds->index, ds->dst->index); - return -EEXIST; - } - - if (ds->dst->last_switch < ds->index) - ds->dst->last_switch = ds->index; - - return 0; -} - -static int dsa_switch_touch_ports(struct dsa_switch *ds) -{ - struct dsa_port *dp; - int port; - - for (port = 0; port < ds->num_ports; port++) { - dp = dsa_port_touch(ds, port); - if (!dp) - return -ENOMEM; - } - - return 0; -} - -static int dsa_switch_parse_of(struct dsa_switch *ds, struct device_node *dn) -{ - int err; - - err = dsa_switch_parse_member_of(ds, dn); - if (err) - return err; - - err = dsa_switch_touch_ports(ds); - if (err) - return err; - - return dsa_switch_parse_ports_of(ds, dn); -} - -static int dsa_port_parse(struct dsa_port *dp, const char *name, - struct device *dev) -{ - if (!strcmp(name, "cpu")) { - struct net_device *master; - - master = dsa_dev_to_net_device(dev); - if (!master) - return -EPROBE_DEFER; - - dev_put(master); - - return dsa_port_parse_cpu(dp, master, NULL); - } - - if (!strcmp(name, "dsa")) - return dsa_port_parse_dsa(dp); - - return dsa_port_parse_user(dp, name); -} - -static int dsa_switch_parse_ports(struct dsa_switch *ds, - struct dsa_chip_data *cd) -{ - bool valid_name_found = false; - struct dsa_port *dp; - struct device *dev; - const char *name; - unsigned int i; - int err; - - for (i = 0; i < DSA_MAX_PORTS; i++) { - name = cd->port_names[i]; - dev = cd->netdev[i]; - dp = dsa_to_port(ds, i); - - if (!name) - continue; - - err = dsa_port_parse(dp, name, dev); - if (err) - return err; - - valid_name_found = true; - } - - if (!valid_name_found && i == DSA_MAX_PORTS) - return -EINVAL; - - return 0; -} - -static int dsa_switch_parse(struct dsa_switch *ds, struct dsa_chip_data *cd) -{ - int err; - - ds->cd = cd; - - /* We don't support interconnected switches nor multiple trees via - * platform data, so this is the unique switch of the tree. - */ - ds->index = 0; - ds->dst = dsa_tree_touch(0); - if (!ds->dst) - return -ENOMEM; - - err = dsa_switch_touch_ports(ds); - if (err) - return err; - - return dsa_switch_parse_ports(ds, cd); -} - -static void dsa_switch_release_ports(struct dsa_switch *ds) -{ - struct dsa_port *dp, *next; - - dsa_switch_for_each_port_safe(dp, next, ds) { - WARN_ON(!list_empty(&dp->fdbs)); - WARN_ON(!list_empty(&dp->mdbs)); - WARN_ON(!list_empty(&dp->vlans)); - list_del(&dp->list); - kfree(dp); - } -} - -static int dsa_switch_probe(struct dsa_switch *ds) -{ - struct dsa_switch_tree *dst; - struct dsa_chip_data *pdata; - struct device_node *np; - int err; - - if (!ds->dev) - return -ENODEV; - - pdata = ds->dev->platform_data; - np = ds->dev->of_node; - - if (!ds->num_ports) - return -EINVAL; - - if (np) { - err = dsa_switch_parse_of(ds, np); - if (err) - dsa_switch_release_ports(ds); - } else if (pdata) { - err = dsa_switch_parse(ds, pdata); - if (err) - dsa_switch_release_ports(ds); - } else { - err = -ENODEV; - } - - if (err) - return err; - - dst = ds->dst; - dsa_tree_get(dst); - err = dsa_tree_setup(dst); - if (err) { - dsa_switch_release_ports(ds); - dsa_tree_put(dst); - } - - return err; -} - -int dsa_register_switch(struct dsa_switch *ds) -{ - int err; - - mutex_lock(&dsa2_mutex); - err = dsa_switch_probe(ds); - dsa_tree_put(ds->dst); - mutex_unlock(&dsa2_mutex); - - return err; -} -EXPORT_SYMBOL_GPL(dsa_register_switch); - -static void dsa_switch_remove(struct dsa_switch *ds) -{ - struct dsa_switch_tree *dst = ds->dst; - - dsa_tree_teardown(dst); - dsa_switch_release_ports(ds); - dsa_tree_put(dst); -} - -void dsa_unregister_switch(struct dsa_switch *ds) -{ - mutex_lock(&dsa2_mutex); - dsa_switch_remove(ds); - mutex_unlock(&dsa2_mutex); -} -EXPORT_SYMBOL_GPL(dsa_unregister_switch); - -/* If the DSA master chooses to unregister its net_device on .shutdown, DSA is - * blocking that operation from completion, due to the dev_hold taken inside - * netdev_upper_dev_link. Unlink the DSA slave interfaces from being uppers of - * the DSA master, so that the system can reboot successfully. - */ -void dsa_switch_shutdown(struct dsa_switch *ds) -{ - struct net_device *master, *slave_dev; - struct dsa_port *dp; - - mutex_lock(&dsa2_mutex); - - if (!ds->setup) - goto out; - - rtnl_lock(); - - dsa_switch_for_each_user_port(dp, ds) { - master = dsa_port_to_master(dp); - slave_dev = dp->slave; - - netdev_upper_dev_unlink(master, slave_dev); - } - - /* Disconnect from further netdevice notifiers on the master, - * since netdev_uses_dsa() will now return false. - */ - dsa_switch_for_each_cpu_port(dp, ds) - dp->master->dsa_ptr = NULL; - - rtnl_unlock(); -out: - mutex_unlock(&dsa2_mutex); -} -EXPORT_SYMBOL_GPL(dsa_switch_shutdown); diff --git a/net/dsa/dsa_priv.h b/net/dsa/dsa_priv.h deleted file mode 100644 index 6e65c7ffd6f3..000000000000 --- a/net/dsa/dsa_priv.h +++ /dev/null @@ -1,587 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0-or-later */ -/* - * net/dsa/dsa_priv.h - Hardware switch handling - * Copyright (c) 2008-2009 Marvell Semiconductor - */ - -#ifndef __DSA_PRIV_H -#define __DSA_PRIV_H - -#include <linux/if_bridge.h> -#include <linux/if_vlan.h> -#include <linux/phy.h> -#include <linux/netdevice.h> -#include <linux/netpoll.h> -#include <net/dsa.h> -#include <net/gro_cells.h> - -#define DSA_MAX_NUM_OFFLOADING_BRIDGES BITS_PER_LONG - -enum { - DSA_NOTIFIER_AGEING_TIME, - DSA_NOTIFIER_BRIDGE_JOIN, - DSA_NOTIFIER_BRIDGE_LEAVE, - DSA_NOTIFIER_FDB_ADD, - DSA_NOTIFIER_FDB_DEL, - DSA_NOTIFIER_HOST_FDB_ADD, - DSA_NOTIFIER_HOST_FDB_DEL, - DSA_NOTIFIER_LAG_FDB_ADD, - DSA_NOTIFIER_LAG_FDB_DEL, - DSA_NOTIFIER_LAG_CHANGE, - DSA_NOTIFIER_LAG_JOIN, - DSA_NOTIFIER_LAG_LEAVE, - DSA_NOTIFIER_MDB_ADD, - DSA_NOTIFIER_MDB_DEL, - DSA_NOTIFIER_HOST_MDB_ADD, - DSA_NOTIFIER_HOST_MDB_DEL, - DSA_NOTIFIER_VLAN_ADD, - DSA_NOTIFIER_VLAN_DEL, - DSA_NOTIFIER_HOST_VLAN_ADD, - DSA_NOTIFIER_HOST_VLAN_DEL, - DSA_NOTIFIER_MTU, - DSA_NOTIFIER_TAG_PROTO, - DSA_NOTIFIER_TAG_PROTO_CONNECT, - DSA_NOTIFIER_TAG_PROTO_DISCONNECT, - DSA_NOTIFIER_TAG_8021Q_VLAN_ADD, - DSA_NOTIFIER_TAG_8021Q_VLAN_DEL, - DSA_NOTIFIER_MASTER_STATE_CHANGE, -}; - -/* DSA_NOTIFIER_AGEING_TIME */ -struct dsa_notifier_ageing_time_info { - unsigned int ageing_time; -}; - -/* DSA_NOTIFIER_BRIDGE_* */ -struct dsa_notifier_bridge_info { - const struct dsa_port *dp; - struct dsa_bridge bridge; - bool tx_fwd_offload; - struct netlink_ext_ack *extack; -}; - -/* DSA_NOTIFIER_FDB_* */ -struct dsa_notifier_fdb_info { - const struct dsa_port *dp; - const unsigned char *addr; - u16 vid; - struct dsa_db db; -}; - -/* DSA_NOTIFIER_LAG_FDB_* */ -struct dsa_notifier_lag_fdb_info { - struct dsa_lag *lag; - const unsigned char *addr; - u16 vid; - struct dsa_db db; -}; - -/* DSA_NOTIFIER_MDB_* */ -struct dsa_notifier_mdb_info { - const struct dsa_port *dp; - const struct switchdev_obj_port_mdb *mdb; - struct dsa_db db; -}; - -/* DSA_NOTIFIER_LAG_* */ -struct dsa_notifier_lag_info { - const struct dsa_port *dp; - struct dsa_lag lag; - struct netdev_lag_upper_info *info; - struct netlink_ext_ack *extack; -}; - -/* DSA_NOTIFIER_VLAN_* */ -struct dsa_notifier_vlan_info { - const struct dsa_port *dp; - const struct switchdev_obj_port_vlan *vlan; - struct netlink_ext_ack *extack; -}; - -/* DSA_NOTIFIER_MTU */ -struct dsa_notifier_mtu_info { - const struct dsa_port *dp; - int mtu; -}; - -/* DSA_NOTIFIER_TAG_PROTO_* */ -struct dsa_notifier_tag_proto_info { - const struct dsa_device_ops *tag_ops; -}; - -/* DSA_NOTIFIER_TAG_8021Q_VLAN_* */ -struct dsa_notifier_tag_8021q_vlan_info { - const struct dsa_port *dp; - u16 vid; -}; - -/* DSA_NOTIFIER_MASTER_STATE_CHANGE */ -struct dsa_notifier_master_state_info { - const struct net_device *master; - bool operational; -}; - -struct dsa_switchdev_event_work { - struct net_device *dev; - struct net_device *orig_dev; - struct work_struct work; - unsigned long event; - /* Specific for SWITCHDEV_FDB_ADD_TO_DEVICE and - * SWITCHDEV_FDB_DEL_TO_DEVICE - */ - unsigned char addr[ETH_ALEN]; - u16 vid; - bool host_addr; -}; - -enum dsa_standalone_event { - DSA_UC_ADD, - DSA_UC_DEL, - DSA_MC_ADD, - DSA_MC_DEL, -}; - -struct dsa_standalone_event_work { - struct work_struct work; - struct net_device *dev; - enum dsa_standalone_event event; - unsigned char addr[ETH_ALEN]; - u16 vid; -}; - -struct dsa_slave_priv { - /* Copy of CPU port xmit for faster access in slave transmit hot path */ - struct sk_buff * (*xmit)(struct sk_buff *skb, - struct net_device *dev); - - struct gro_cells gcells; - - /* DSA port data, such as switch, port index, etc. */ - struct dsa_port *dp; - -#ifdef CONFIG_NET_POLL_CONTROLLER - struct netpoll *netpoll; -#endif - - /* TC context */ - struct list_head mall_tc_list; -}; - -/* dsa.c */ -const struct dsa_device_ops *dsa_tag_driver_get(int tag_protocol); -void dsa_tag_driver_put(const struct dsa_device_ops *ops); -const struct dsa_device_ops *dsa_find_tagger_by_name(const char *buf); - -bool dsa_db_equal(const struct dsa_db *a, const struct dsa_db *b); - -bool dsa_schedule_work(struct work_struct *work); -const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops); - -static inline int dsa_tag_protocol_overhead(const struct dsa_device_ops *ops) -{ - return ops->needed_headroom + ops->needed_tailroom; -} - -/* master.c */ -int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp); -void dsa_master_teardown(struct net_device *dev); -int dsa_master_lag_setup(struct net_device *lag_dev, struct dsa_port *cpu_dp, - struct netdev_lag_upper_info *uinfo, - struct netlink_ext_ack *extack); -void dsa_master_lag_teardown(struct net_device *lag_dev, - struct dsa_port *cpu_dp); - -static inline struct net_device *dsa_master_find_slave(struct net_device *dev, - int device, int port) -{ - struct dsa_port *cpu_dp = dev->dsa_ptr; - struct dsa_switch_tree *dst = cpu_dp->dst; - struct dsa_port *dp; - - list_for_each_entry(dp, &dst->ports, list) - if (dp->ds->index == device && dp->index == port && - dp->type == DSA_PORT_TYPE_USER) - return dp->slave; - - return NULL; -} - -/* netlink.c */ -extern struct rtnl_link_ops dsa_link_ops __read_mostly; - -/* port.c */ -void dsa_port_set_tag_protocol(struct dsa_port *cpu_dp, - const struct dsa_device_ops *tag_ops); -int dsa_port_set_state(struct dsa_port *dp, u8 state, bool do_fast_age); -int dsa_port_set_mst_state(struct dsa_port *dp, - const struct switchdev_mst_state *state, - struct netlink_ext_ack *extack); -int dsa_port_enable_rt(struct dsa_port *dp, struct phy_device *phy); -int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy); -void dsa_port_disable_rt(struct dsa_port *dp); -void dsa_port_disable(struct dsa_port *dp); -int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br, - struct netlink_ext_ack *extack); -void dsa_port_pre_bridge_leave(struct dsa_port *dp, struct net_device *br); -void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br); -int dsa_port_lag_change(struct dsa_port *dp, - struct netdev_lag_lower_state_info *linfo); -int dsa_port_lag_join(struct dsa_port *dp, struct net_device *lag_dev, - struct netdev_lag_upper_info *uinfo, - struct netlink_ext_ack *extack); -void dsa_port_pre_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); -void dsa_port_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); -int dsa_port_vlan_filtering(struct dsa_port *dp, bool vlan_filtering, - struct netlink_ext_ack *extack); -bool dsa_port_skip_vlan_configuration(struct dsa_port *dp); -int dsa_port_ageing_time(struct dsa_port *dp, clock_t ageing_clock); -int dsa_port_mst_enable(struct dsa_port *dp, bool on, - struct netlink_ext_ack *extack); -int dsa_port_vlan_msti(struct dsa_port *dp, - const struct switchdev_vlan_msti *msti); -int dsa_port_mtu_change(struct dsa_port *dp, int new_mtu); -int dsa_port_fdb_add(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_fdb_del(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_standalone_host_fdb_add(struct dsa_port *dp, - const unsigned char *addr, u16 vid); -int dsa_port_standalone_host_fdb_del(struct dsa_port *dp, - const unsigned char *addr, u16 vid); -int dsa_port_bridge_host_fdb_add(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_bridge_host_fdb_del(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_lag_fdb_add(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_lag_fdb_del(struct dsa_port *dp, const unsigned char *addr, - u16 vid); -int dsa_port_fdb_dump(struct dsa_port *dp, dsa_fdb_dump_cb_t *cb, void *data); -int dsa_port_mdb_add(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_mdb_del(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_standalone_host_mdb_add(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_standalone_host_mdb_del(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_bridge_host_mdb_add(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_bridge_host_mdb_del(const struct dsa_port *dp, - const struct switchdev_obj_port_mdb *mdb); -int dsa_port_pre_bridge_flags(const struct dsa_port *dp, - struct switchdev_brport_flags flags, - struct netlink_ext_ack *extack); -int dsa_port_bridge_flags(struct dsa_port *dp, - struct switchdev_brport_flags flags, - struct netlink_ext_ack *extack); -int dsa_port_vlan_add(struct dsa_port *dp, - const struct switchdev_obj_port_vlan *vlan, - struct netlink_ext_ack *extack); -int dsa_port_vlan_del(struct dsa_port *dp, - const struct switchdev_obj_port_vlan *vlan); -int dsa_port_host_vlan_add(struct dsa_port *dp, - const struct switchdev_obj_port_vlan *vlan, - struct netlink_ext_ack *extack); -int dsa_port_host_vlan_del(struct dsa_port *dp, - const struct switchdev_obj_port_vlan *vlan); -int dsa_port_mrp_add(const struct dsa_port *dp, - const struct switchdev_obj_mrp *mrp); -int dsa_port_mrp_del(const struct dsa_port *dp, - const struct switchdev_obj_mrp *mrp); -int dsa_port_mrp_add_ring_role(const struct dsa_port *dp, - const struct switchdev_obj_ring_role_mrp *mrp); -int dsa_port_mrp_del_ring_role(const struct dsa_port *dp, - const struct switchdev_obj_ring_role_mrp *mrp); -int dsa_port_phylink_create(struct dsa_port *dp); -void dsa_port_phylink_destroy(struct dsa_port *dp); -int dsa_shared_port_link_register_of(struct dsa_port *dp); -void dsa_shared_port_link_unregister_of(struct dsa_port *dp); -int dsa_port_hsr_join(struct dsa_port *dp, struct net_device *hsr); -void dsa_port_hsr_leave(struct dsa_port *dp, struct net_device *hsr); -int dsa_port_tag_8021q_vlan_add(struct dsa_port *dp, u16 vid, bool broadcast); -void dsa_port_tag_8021q_vlan_del(struct dsa_port *dp, u16 vid, bool broadcast); -void dsa_port_set_host_flood(struct dsa_port *dp, bool uc, bool mc); -int dsa_port_change_master(struct dsa_port *dp, struct net_device *master, - struct netlink_ext_ack *extack); - -/* slave.c */ -extern const struct dsa_device_ops notag_netdev_ops; -extern struct notifier_block dsa_slave_switchdev_notifier; -extern struct notifier_block dsa_slave_switchdev_blocking_notifier; - -void dsa_slave_mii_bus_init(struct dsa_switch *ds); -int dsa_slave_create(struct dsa_port *dp); -void dsa_slave_destroy(struct net_device *slave_dev); -int dsa_slave_suspend(struct net_device *slave_dev); -int dsa_slave_resume(struct net_device *slave_dev); -int dsa_slave_register_notifier(void); -void dsa_slave_unregister_notifier(void); -void dsa_slave_sync_ha(struct net_device *dev); -void dsa_slave_unsync_ha(struct net_device *dev); -void dsa_slave_setup_tagger(struct net_device *slave); -int dsa_slave_change_mtu(struct net_device *dev, int new_mtu); -int dsa_slave_change_master(struct net_device *dev, struct net_device *master, - struct netlink_ext_ack *extack); -int dsa_slave_manage_vlan_filtering(struct net_device *dev, - bool vlan_filtering); - -static inline struct dsa_port *dsa_slave_to_port(const struct net_device *dev) -{ - struct dsa_slave_priv *p = netdev_priv(dev); - - return p->dp; -} - -static inline struct net_device * -dsa_slave_to_master(const struct net_device *dev) -{ - struct dsa_port *dp = dsa_slave_to_port(dev); - - return dsa_port_to_master(dp); -} - -/* If under a bridge with vlan_filtering=0, make sure to send pvid-tagged - * frames as untagged, since the bridge will not untag them. - */ -static inline struct sk_buff *dsa_untag_bridge_pvid(struct sk_buff *skb) -{ - struct dsa_port *dp = dsa_slave_to_port(skb->dev); - struct net_device *br = dsa_port_bridge_dev_get(dp); - struct net_device *dev = skb->dev; - struct net_device *upper_dev; - u16 vid, pvid, proto; - int err; - - if (!br || br_vlan_enabled(br)) - return skb; - - err = br_vlan_get_proto(br, &proto); - if (err) - return skb; - - /* Move VLAN tag from data to hwaccel */ - if (!skb_vlan_tag_present(skb) && skb->protocol == htons(proto)) { - skb = skb_vlan_untag(skb); - if (!skb) - return NULL; - } - - if (!skb_vlan_tag_present(skb)) - return skb; - - vid = skb_vlan_tag_get_id(skb); - - /* We already run under an RCU read-side critical section since - * we are called from netif_receive_skb_list_internal(). - */ - err = br_vlan_get_pvid_rcu(dev, &pvid); - if (err) - return skb; - - if (vid != pvid) - return skb; - - /* The sad part about attempting to untag from DSA is that we - * don't know, unless we check, if the skb will end up in - * the bridge's data path - br_allowed_ingress() - or not. - * For example, there might be an 8021q upper for the - * default_pvid of the bridge, which will steal VLAN-tagged traffic - * from the bridge's data path. This is a configuration that DSA - * supports because vlan_filtering is 0. In that case, we should - * definitely keep the tag, to make sure it keeps working. - */ - upper_dev = __vlan_find_dev_deep_rcu(br, htons(proto), vid); - if (upper_dev) - return skb; - - __vlan_hwaccel_clear_tag(skb); - - return skb; -} - -/* For switches without hardware support for DSA tagging to be able - * to support termination through the bridge. - */ -static inline struct net_device * -dsa_find_designated_bridge_port_by_vid(struct net_device *master, u16 vid) -{ - struct dsa_port *cpu_dp = master->dsa_ptr; - struct dsa_switch_tree *dst = cpu_dp->dst; - struct bridge_vlan_info vinfo; - struct net_device *slave; - struct dsa_port *dp; - int err; - - list_for_each_entry(dp, &dst->ports, list) { - if (dp->type != DSA_PORT_TYPE_USER) - continue; - - if (!dp->bridge) - continue; - - if (dp->stp_state != BR_STATE_LEARNING && - dp->stp_state != BR_STATE_FORWARDING) - continue; - - /* Since the bridge might learn this packet, keep the CPU port - * affinity with the port that will be used for the reply on - * xmit. - */ - if (dp->cpu_dp != cpu_dp) - continue; - - slave = dp->slave; - - err = br_vlan_get_info_rcu(slave, vid, &vinfo); - if (err) - continue; - - return slave; - } - - return NULL; -} - -/* If the ingress port offloads the bridge, we mark the frame as autonomously - * forwarded by hardware, so the software bridge doesn't forward in twice, back - * to us, because we already did. However, if we're in fallback mode and we do - * software bridging, we are not offloading it, therefore the dp->bridge - * pointer is not populated, and flooding needs to be done by software (we are - * effectively operating in standalone ports mode). - */ -static inline void dsa_default_offload_fwd_mark(struct sk_buff *skb) -{ - struct dsa_port *dp = dsa_slave_to_port(skb->dev); - - skb->offload_fwd_mark = !!(dp->bridge); -} - -/* Helper for removing DSA header tags from packets in the RX path. - * Must not be called before skb_pull(len). - * skb->data - * | - * v - * | | | | | | | | | | | | | | | | | | | - * +-----------------------+-----------------------+---------------+-------+ - * | Destination MAC | Source MAC | DSA header | EType | - * +-----------------------+-----------------------+---------------+-------+ - * | | - * <----- len -----> <----- len -----> - * | - * >>>>>>> v - * >>>>>>> | | | | | | | | | | | | | | | - * >>>>>>> +-----------------------+-----------------------+-------+ - * >>>>>>> | Destination MAC | Source MAC | EType | - * +-----------------------+-----------------------+-------+ - * ^ - * | - * skb->data - */ -static inline void dsa_strip_etype_header(struct sk_buff *skb, int len) -{ - memmove(skb->data - ETH_HLEN, skb->data - ETH_HLEN - len, 2 * ETH_ALEN); -} - -/* Helper for creating space for DSA header tags in TX path packets. - * Must not be called before skb_push(len). - * - * Before: - * - * <<<<<<< | | | | | | | | | | | | | | | - * ^ <<<<<<< +-----------------------+-----------------------+-------+ - * | <<<<<<< | Destination MAC | Source MAC | EType | - * | +-----------------------+-----------------------+-------+ - * <----- len -----> - * | - * | - * skb->data - * - * After: - * - * | | | | | | | | | | | | | | | | | | | - * +-----------------------+-----------------------+---------------+-------+ - * | Destination MAC | Source MAC | DSA header | EType | - * +-----------------------+-----------------------+---------------+-------+ - * ^ | | - * | <----- len -----> - * skb->data - */ -static inline void dsa_alloc_etype_header(struct sk_buff *skb, int len) -{ - memmove(skb->data, skb->data + len, 2 * ETH_ALEN); -} - -/* On RX, eth_type_trans() on the DSA master pulls ETH_HLEN bytes starting from - * skb_mac_header(skb), which leaves skb->data pointing at the first byte after - * what the DSA master perceives as the EtherType (the beginning of the L3 - * protocol). Since DSA EtherType header taggers treat the EtherType as part of - * the DSA tag itself, and the EtherType is 2 bytes in length, the DSA header - * is located 2 bytes behind skb->data. Note that EtherType in this context - * means the first 2 bytes of the DSA header, not the encapsulated EtherType - * that will become visible after the DSA header is stripped. - */ -static inline void *dsa_etype_header_pos_rx(struct sk_buff *skb) -{ - return skb->data - 2; -} - -/* On TX, skb->data points to skb_mac_header(skb), which means that EtherType - * header taggers start exactly where the EtherType is (the EtherType is - * treated as part of the DSA header). - */ -static inline void *dsa_etype_header_pos_tx(struct sk_buff *skb) -{ - return skb->data + 2 * ETH_ALEN; -} - -/* switch.c */ -int dsa_switch_register_notifier(struct dsa_switch *ds); -void dsa_switch_unregister_notifier(struct dsa_switch *ds); - -static inline bool dsa_switch_supports_uc_filtering(struct dsa_switch *ds) -{ - return ds->ops->port_fdb_add && ds->ops->port_fdb_del && - ds->fdb_isolation && !ds->vlan_filtering_is_global && - !ds->needs_standalone_vlan_filtering; -} - -static inline bool dsa_switch_supports_mc_filtering(struct dsa_switch *ds) -{ - return ds->ops->port_mdb_add && ds->ops->port_mdb_del && - ds->fdb_isolation && !ds->vlan_filtering_is_global && - !ds->needs_standalone_vlan_filtering; -} - -/* dsa2.c */ -void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag); -void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag); -struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, - const struct net_device *lag_dev); -struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst); -int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v); -int dsa_broadcast(unsigned long e, void *v); -int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, - const struct dsa_device_ops *tag_ops, - const struct dsa_device_ops *old_tag_ops); -void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, - struct net_device *master, - bool up); -void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, - struct net_device *master, - bool up); -unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max); -void dsa_bridge_num_put(const struct net_device *bridge_dev, - unsigned int bridge_num); -struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, - const struct net_device *br); - -/* tag_8021q.c */ -int dsa_switch_tag_8021q_vlan_add(struct dsa_switch *ds, - struct dsa_notifier_tag_8021q_vlan_info *info); -int dsa_switch_tag_8021q_vlan_del(struct dsa_switch *ds, - struct dsa_notifier_tag_8021q_vlan_info *info); - -extern struct list_head dsa_tree_list; - -#endif diff --git a/net/dsa/master.c b/net/dsa/master.c index 40367ab41cf8..26d90140d271 100644 --- a/net/dsa/master.c +++ b/net/dsa/master.c @@ -6,7 +6,15 @@ * Vivien Didelot <vivien.didelot@savoirfairelinux.com> */ -#include "dsa_priv.h" +#include <linux/ethtool.h> +#include <linux/netdevice.h> +#include <linux/netlink.h> +#include <net/dsa.h> + +#include "dsa.h" +#include "master.h" +#include "port.h" +#include "tag.h" static int dsa_master_get_regs_len(struct net_device *dev) { @@ -204,8 +212,7 @@ static int dsa_master_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) * switch in the tree that is PTP capable. */ list_for_each_entry(dp, &dst->ports, list) - if (dp->ds->ops->port_hwtstamp_get || - dp->ds->ops->port_hwtstamp_set) + if (dsa_port_supports_hwtstamp(dp, ifr)) return -EBUSY; break; } @@ -300,13 +307,24 @@ static ssize_t tagging_store(struct device *d, struct device_attribute *attr, const char *buf, size_t count) { const struct dsa_device_ops *new_tag_ops, *old_tag_ops; + const char *end = strchrnul(buf, '\n'), *name; struct net_device *dev = to_net_dev(d); struct dsa_port *cpu_dp = dev->dsa_ptr; + size_t len = end - buf; int err; + /* Empty string passed */ + if (!len) + return -ENOPROTOOPT; + + name = kstrndup(buf, len, GFP_KERNEL); + if (!name) + return -ENOMEM; + old_tag_ops = cpu_dp->tag_ops; - new_tag_ops = dsa_find_tagger_by_name(buf); - /* Bad tagger name, or module is not loaded? */ + new_tag_ops = dsa_tag_driver_get_by_name(name); + kfree(name); + /* Bad tagger name? */ if (IS_ERR(new_tag_ops)) return PTR_ERR(new_tag_ops); diff --git a/net/dsa/master.h b/net/dsa/master.h new file mode 100644 index 000000000000..3fc0e610b5b5 --- /dev/null +++ b/net/dsa/master.h @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_MASTER_H +#define __DSA_MASTER_H + +struct dsa_port; +struct net_device; +struct netdev_lag_upper_info; +struct netlink_ext_ack; + +int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp); +void dsa_master_teardown(struct net_device *dev); +int dsa_master_lag_setup(struct net_device *lag_dev, struct dsa_port *cpu_dp, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack); +void dsa_master_lag_teardown(struct net_device *lag_dev, + struct dsa_port *cpu_dp); + +#endif diff --git a/net/dsa/netlink.c b/net/dsa/netlink.c index ecf9ed1de185..bd4bbaf851de 100644 --- a/net/dsa/netlink.c +++ b/net/dsa/netlink.c @@ -4,7 +4,8 @@ #include <linux/netdevice.h> #include <net/rtnetlink.h> -#include "dsa_priv.h" +#include "netlink.h" +#include "slave.h" static const struct nla_policy dsa_policy[IFLA_DSA_MAX + 1] = { [IFLA_DSA_MASTER] = { .type = NLA_U32 }, diff --git a/net/dsa/netlink.h b/net/dsa/netlink.h new file mode 100644 index 000000000000..7eda2fa15722 --- /dev/null +++ b/net/dsa/netlink.h @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_NETLINK_H +#define __DSA_NETLINK_H + +extern struct rtnl_link_ops dsa_link_ops __read_mostly; + +#endif diff --git a/net/dsa/port.c b/net/dsa/port.c index 208168276995..67ad1adec2a2 100644 --- a/net/dsa/port.c +++ b/net/dsa/port.c @@ -12,7 +12,11 @@ #include <linux/of_mdio.h> #include <linux/of_net.h> -#include "dsa_priv.h" +#include "dsa.h" +#include "port.h" +#include "slave.h" +#include "switch.h" +#include "tag_8021q.h" /** * dsa_port_notify - Notify the switching fabric of changes to a port @@ -110,6 +114,22 @@ static bool dsa_port_can_configure_learning(struct dsa_port *dp) return !err; } +bool dsa_port_supports_hwtstamp(struct dsa_port *dp, struct ifreq *ifr) +{ + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_hwtstamp_get || !ds->ops->port_hwtstamp_set) + return false; + + /* "See through" shim implementations of the "get" method. + * This will clobber the ifreq structure, but we will either return an + * error, or the master will overwrite it with proper values. + */ + err = ds->ops->port_hwtstamp_get(ds, dp->index, ifr); + return err != -EOPNOTSUPP; +} + int dsa_port_set_state(struct dsa_port *dp, u8 state, bool do_fast_age) { struct dsa_switch *ds = dp->ds; @@ -1536,16 +1556,14 @@ static void dsa_port_phylink_validate(struct phylink_config *config, unsigned long *supported, struct phylink_link_state *state) { - struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); - struct dsa_switch *ds = dp->ds; - - if (!ds->ops->phylink_validate) { - if (config->mac_capabilities) - phylink_generic_validate(config, supported, state); - return; - } - - ds->ops->phylink_validate(ds, dp->index, supported, state); + /* Skip call for drivers which don't yet set mac_capabilities, + * since validating in that case would mean their PHY will advertise + * nothing. In turn, skipping validation makes them advertise + * everything that the PHY supports, so those drivers should be + * converted ASAP. + */ + if (config->mac_capabilities) + phylink_generic_validate(config, supported, state); } static void dsa_port_phylink_mac_pcs_get_state(struct phylink_config *config, diff --git a/net/dsa/port.h b/net/dsa/port.h new file mode 100644 index 000000000000..9c218660d223 --- /dev/null +++ b/net/dsa/port.h @@ -0,0 +1,114 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_PORT_H +#define __DSA_PORT_H + +#include <linux/types.h> +#include <net/dsa.h> + +struct ifreq; +struct netdev_lag_lower_state_info; +struct netdev_lag_upper_info; +struct netlink_ext_ack; +struct switchdev_mst_state; +struct switchdev_obj_port_mdb; +struct switchdev_vlan_msti; +struct phy_device; + +bool dsa_port_supports_hwtstamp(struct dsa_port *dp, struct ifreq *ifr); +void dsa_port_set_tag_protocol(struct dsa_port *cpu_dp, + const struct dsa_device_ops *tag_ops); +int dsa_port_set_state(struct dsa_port *dp, u8 state, bool do_fast_age); +int dsa_port_set_mst_state(struct dsa_port *dp, + const struct switchdev_mst_state *state, + struct netlink_ext_ack *extack); +int dsa_port_enable_rt(struct dsa_port *dp, struct phy_device *phy); +int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy); +void dsa_port_disable_rt(struct dsa_port *dp); +void dsa_port_disable(struct dsa_port *dp); +int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br, + struct netlink_ext_ack *extack); +void dsa_port_pre_bridge_leave(struct dsa_port *dp, struct net_device *br); +void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br); +int dsa_port_lag_change(struct dsa_port *dp, + struct netdev_lag_lower_state_info *linfo); +int dsa_port_lag_join(struct dsa_port *dp, struct net_device *lag_dev, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack); +void dsa_port_pre_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); +void dsa_port_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); +int dsa_port_vlan_filtering(struct dsa_port *dp, bool vlan_filtering, + struct netlink_ext_ack *extack); +bool dsa_port_skip_vlan_configuration(struct dsa_port *dp); +int dsa_port_ageing_time(struct dsa_port *dp, clock_t ageing_clock); +int dsa_port_mst_enable(struct dsa_port *dp, bool on, + struct netlink_ext_ack *extack); +int dsa_port_vlan_msti(struct dsa_port *dp, + const struct switchdev_vlan_msti *msti); +int dsa_port_mtu_change(struct dsa_port *dp, int new_mtu); +int dsa_port_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_standalone_host_fdb_add(struct dsa_port *dp, + const unsigned char *addr, u16 vid); +int dsa_port_standalone_host_fdb_del(struct dsa_port *dp, + const unsigned char *addr, u16 vid); +int dsa_port_bridge_host_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_bridge_host_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_lag_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_lag_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_fdb_dump(struct dsa_port *dp, dsa_fdb_dump_cb_t *cb, void *data); +int dsa_port_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_standalone_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_standalone_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_bridge_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_bridge_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_pre_bridge_flags(const struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack); +int dsa_port_bridge_flags(struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack); +int dsa_port_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack); +int dsa_port_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan); +int dsa_port_host_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack); +int dsa_port_host_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan); +int dsa_port_mrp_add(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp); +int dsa_port_mrp_del(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp); +int dsa_port_mrp_add_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp); +int dsa_port_mrp_del_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp); +int dsa_port_phylink_create(struct dsa_port *dp); +void dsa_port_phylink_destroy(struct dsa_port *dp); +int dsa_shared_port_link_register_of(struct dsa_port *dp); +void dsa_shared_port_link_unregister_of(struct dsa_port *dp); +int dsa_port_hsr_join(struct dsa_port *dp, struct net_device *hsr); +void dsa_port_hsr_leave(struct dsa_port *dp, struct net_device *hsr); +int dsa_port_tag_8021q_vlan_add(struct dsa_port *dp, u16 vid, bool broadcast); +void dsa_port_tag_8021q_vlan_del(struct dsa_port *dp, u16 vid, bool broadcast); +void dsa_port_set_host_flood(struct dsa_port *dp, bool uc, bool mc); +int dsa_port_change_master(struct dsa_port *dp, struct net_device *master, + struct netlink_ext_ack *extack); + +#endif diff --git a/net/dsa/slave.c b/net/dsa/slave.c index a9fde48cffd4..aab79c355224 100644 --- a/net/dsa/slave.c +++ b/net/dsa/slave.c @@ -22,7 +22,54 @@ #include <net/dcbnl.h> #include <linux/netpoll.h> -#include "dsa_priv.h" +#include "dsa.h" +#include "port.h" +#include "master.h" +#include "netlink.h" +#include "slave.h" +#include "tag.h" + +struct dsa_switchdev_event_work { + struct net_device *dev; + struct net_device *orig_dev; + struct work_struct work; + unsigned long event; + /* Specific for SWITCHDEV_FDB_ADD_TO_DEVICE and + * SWITCHDEV_FDB_DEL_TO_DEVICE + */ + unsigned char addr[ETH_ALEN]; + u16 vid; + bool host_addr; +}; + +enum dsa_standalone_event { + DSA_UC_ADD, + DSA_UC_DEL, + DSA_MC_ADD, + DSA_MC_DEL, +}; + +struct dsa_standalone_event_work { + struct work_struct work; + struct net_device *dev; + enum dsa_standalone_event event; + unsigned char addr[ETH_ALEN]; + u16 vid; +}; + +static bool dsa_switch_supports_uc_filtering(struct dsa_switch *ds) +{ + return ds->ops->port_fdb_add && ds->ops->port_fdb_del && + ds->fdb_isolation && !ds->vlan_filtering_is_global && + !ds->needs_standalone_vlan_filtering; +} + +static bool dsa_switch_supports_mc_filtering(struct dsa_switch *ds) +{ + return ds->ops->port_mdb_add && ds->ops->port_mdb_del && + ds->fdb_isolation && !ds->vlan_filtering_is_global && + !ds->needs_standalone_vlan_filtering; +} static void dsa_slave_standalone_event_work(struct work_struct *work) { @@ -976,12 +1023,12 @@ static void dsa_slave_get_ethtool_stats(struct net_device *dev, s = per_cpu_ptr(dev->tstats, i); do { - start = u64_stats_fetch_begin_irq(&s->syncp); + start = u64_stats_fetch_begin(&s->syncp); tx_packets = u64_stats_read(&s->tx_packets); tx_bytes = u64_stats_read(&s->tx_bytes); rx_packets = u64_stats_read(&s->rx_packets); rx_bytes = u64_stats_read(&s->rx_bytes); - } while (u64_stats_fetch_retry_irq(&s->syncp, start)); + } while (u64_stats_fetch_retry(&s->syncp, start)); data[0] += tx_packets; data[1] += tx_bytes; data[2] += rx_packets; @@ -2165,13 +2212,6 @@ static const struct dcbnl_rtnl_ops __maybe_unused dsa_slave_dcbnl_ops = { .ieee_delapp = dsa_slave_dcbnl_ieee_delapp, }; -static struct devlink_port *dsa_slave_get_devlink_port(struct net_device *dev) -{ - struct dsa_port *dp = dsa_slave_to_port(dev); - - return &dp->devlink_port; -} - static void dsa_slave_get_stats64(struct net_device *dev, struct rtnl_link_stats64 *s) { @@ -2219,7 +2259,6 @@ static const struct net_device_ops dsa_slave_netdev_ops = { .ndo_get_stats64 = dsa_slave_get_stats64, .ndo_vlan_rx_add_vid = dsa_slave_vlan_rx_add_vid, .ndo_vlan_rx_kill_vid = dsa_slave_vlan_rx_kill_vid, - .ndo_get_devlink_port = dsa_slave_get_devlink_port, .ndo_change_mtu = dsa_slave_change_mtu, .ndo_fill_forward_path = dsa_slave_fill_forward_path, }; @@ -2374,16 +2413,25 @@ int dsa_slave_create(struct dsa_port *port) { struct net_device *master = dsa_port_to_master(port); struct dsa_switch *ds = port->ds; - const char *name = port->name; struct net_device *slave_dev; struct dsa_slave_priv *p; + const char *name; + int assign_type; int ret; if (!ds->num_tx_queues) ds->num_tx_queues = 1; + if (port->name) { + name = port->name; + assign_type = NET_NAME_PREDICTABLE; + } else { + name = "eth%d"; + assign_type = NET_NAME_ENUM; + } + slave_dev = alloc_netdev_mqs(sizeof(struct dsa_slave_priv), name, - NET_NAME_UNKNOWN, ether_setup, + assign_type, ether_setup, ds->num_tx_queues, 1); if (slave_dev == NULL) return -ENOMEM; @@ -2406,6 +2454,7 @@ int dsa_slave_create(struct dsa_port *port) SET_NETDEV_DEVTYPE(slave_dev, &dsa_type); SET_NETDEV_DEV(slave_dev, port->ds->dev); + SET_NETDEV_DEVLINK_PORT(slave_dev, &port->devlink_port); slave_dev->dev.of_node = port->dn; slave_dev->vlan_features = master->vlan_features; diff --git a/net/dsa/slave.h b/net/dsa/slave.h new file mode 100644 index 000000000000..d0abe609e00d --- /dev/null +++ b/net/dsa/slave.h @@ -0,0 +1,69 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_SLAVE_H +#define __DSA_SLAVE_H + +#include <linux/if_bridge.h> +#include <linux/if_vlan.h> +#include <linux/list.h> +#include <linux/netpoll.h> +#include <linux/types.h> +#include <net/dsa.h> +#include <net/gro_cells.h> + +struct net_device; +struct netlink_ext_ack; + +extern struct notifier_block dsa_slave_switchdev_notifier; +extern struct notifier_block dsa_slave_switchdev_blocking_notifier; + +struct dsa_slave_priv { + /* Copy of CPU port xmit for faster access in slave transmit hot path */ + struct sk_buff * (*xmit)(struct sk_buff *skb, + struct net_device *dev); + + struct gro_cells gcells; + + /* DSA port data, such as switch, port index, etc. */ + struct dsa_port *dp; + +#ifdef CONFIG_NET_POLL_CONTROLLER + struct netpoll *netpoll; +#endif + + /* TC context */ + struct list_head mall_tc_list; +}; + +void dsa_slave_mii_bus_init(struct dsa_switch *ds); +int dsa_slave_create(struct dsa_port *dp); +void dsa_slave_destroy(struct net_device *slave_dev); +int dsa_slave_suspend(struct net_device *slave_dev); +int dsa_slave_resume(struct net_device *slave_dev); +int dsa_slave_register_notifier(void); +void dsa_slave_unregister_notifier(void); +void dsa_slave_sync_ha(struct net_device *dev); +void dsa_slave_unsync_ha(struct net_device *dev); +void dsa_slave_setup_tagger(struct net_device *slave); +int dsa_slave_change_mtu(struct net_device *dev, int new_mtu); +int dsa_slave_change_master(struct net_device *dev, struct net_device *master, + struct netlink_ext_ack *extack); +int dsa_slave_manage_vlan_filtering(struct net_device *dev, + bool vlan_filtering); + +static inline struct dsa_port *dsa_slave_to_port(const struct net_device *dev) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + + return p->dp; +} + +static inline struct net_device * +dsa_slave_to_master(const struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return dsa_port_to_master(dp); +} + +#endif diff --git a/net/dsa/switch.c b/net/dsa/switch.c index ce56acdba203..d5bc4bb7310d 100644 --- a/net/dsa/switch.c +++ b/net/dsa/switch.c @@ -12,7 +12,12 @@ #include <linux/if_vlan.h> #include <net/switchdev.h> -#include "dsa_priv.h" +#include "dsa.h" +#include "netlink.h" +#include "port.h" +#include "slave.h" +#include "switch.h" +#include "tag_8021q.h" static unsigned int dsa_switch_fastest_ageing_time(struct dsa_switch *ds, unsigned int ageing_time) @@ -1013,6 +1018,52 @@ static int dsa_switch_event(struct notifier_block *nb, return notifier_from_errno(err); } +/** + * dsa_tree_notify - Execute code for all switches in a DSA switch tree. + * @dst: collection of struct dsa_switch devices to notify. + * @e: event, must be of type DSA_NOTIFIER_* + * @v: event-specific value. + * + * Given a struct dsa_switch_tree, this can be used to run a function once for + * each member DSA switch. The other alternative of traversing the tree is only + * through its ports list, which does not uniquely list the switches. + */ +int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v) +{ + struct raw_notifier_head *nh = &dst->nh; + int err; + + err = raw_notifier_call_chain(nh, e, v); + + return notifier_to_errno(err); +} + +/** + * dsa_broadcast - Notify all DSA trees in the system. + * @e: event, must be of type DSA_NOTIFIER_* + * @v: event-specific value. + * + * Can be used to notify the switching fabric of events such as cross-chip + * bridging between disjoint trees (such as islands of tagger-compatible + * switches bridged by an incompatible middle switch). + * + * WARNING: this function is not reliable during probe time, because probing + * between trees is asynchronous and not all DSA trees might have probed. + */ +int dsa_broadcast(unsigned long e, void *v) +{ + struct dsa_switch_tree *dst; + int err = 0; + + list_for_each_entry(dst, &dsa_tree_list, list) { + err = dsa_tree_notify(dst, e, v); + if (err) + break; + } + + return err; +} + int dsa_switch_register_notifier(struct dsa_switch *ds) { ds->nb.notifier_call = dsa_switch_event; diff --git a/net/dsa/switch.h b/net/dsa/switch.h new file mode 100644 index 000000000000..15e67b95eb6e --- /dev/null +++ b/net/dsa/switch.h @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_SWITCH_H +#define __DSA_SWITCH_H + +#include <net/dsa.h> + +struct netlink_ext_ack; + +enum { + DSA_NOTIFIER_AGEING_TIME, + DSA_NOTIFIER_BRIDGE_JOIN, + DSA_NOTIFIER_BRIDGE_LEAVE, + DSA_NOTIFIER_FDB_ADD, + DSA_NOTIFIER_FDB_DEL, + DSA_NOTIFIER_HOST_FDB_ADD, + DSA_NOTIFIER_HOST_FDB_DEL, + DSA_NOTIFIER_LAG_FDB_ADD, + DSA_NOTIFIER_LAG_FDB_DEL, + DSA_NOTIFIER_LAG_CHANGE, + DSA_NOTIFIER_LAG_JOIN, + DSA_NOTIFIER_LAG_LEAVE, + DSA_NOTIFIER_MDB_ADD, + DSA_NOTIFIER_MDB_DEL, + DSA_NOTIFIER_HOST_MDB_ADD, + DSA_NOTIFIER_HOST_MDB_DEL, + DSA_NOTIFIER_VLAN_ADD, + DSA_NOTIFIER_VLAN_DEL, + DSA_NOTIFIER_HOST_VLAN_ADD, + DSA_NOTIFIER_HOST_VLAN_DEL, + DSA_NOTIFIER_MTU, + DSA_NOTIFIER_TAG_PROTO, + DSA_NOTIFIER_TAG_PROTO_CONNECT, + DSA_NOTIFIER_TAG_PROTO_DISCONNECT, + DSA_NOTIFIER_TAG_8021Q_VLAN_ADD, + DSA_NOTIFIER_TAG_8021Q_VLAN_DEL, + DSA_NOTIFIER_MASTER_STATE_CHANGE, +}; + +/* DSA_NOTIFIER_AGEING_TIME */ +struct dsa_notifier_ageing_time_info { + unsigned int ageing_time; +}; + +/* DSA_NOTIFIER_BRIDGE_* */ +struct dsa_notifier_bridge_info { + const struct dsa_port *dp; + struct dsa_bridge bridge; + bool tx_fwd_offload; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_FDB_* */ +struct dsa_notifier_fdb_info { + const struct dsa_port *dp; + const unsigned char *addr; + u16 vid; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_LAG_FDB_* */ +struct dsa_notifier_lag_fdb_info { + struct dsa_lag *lag; + const unsigned char *addr; + u16 vid; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_MDB_* */ +struct dsa_notifier_mdb_info { + const struct dsa_port *dp; + const struct switchdev_obj_port_mdb *mdb; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_LAG_* */ +struct dsa_notifier_lag_info { + const struct dsa_port *dp; + struct dsa_lag lag; + struct netdev_lag_upper_info *info; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_VLAN_* */ +struct dsa_notifier_vlan_info { + const struct dsa_port *dp; + const struct switchdev_obj_port_vlan *vlan; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_MTU */ +struct dsa_notifier_mtu_info { + const struct dsa_port *dp; + int mtu; +}; + +/* DSA_NOTIFIER_TAG_PROTO_* */ +struct dsa_notifier_tag_proto_info { + const struct dsa_device_ops *tag_ops; +}; + +/* DSA_NOTIFIER_TAG_8021Q_VLAN_* */ +struct dsa_notifier_tag_8021q_vlan_info { + const struct dsa_port *dp; + u16 vid; +}; + +/* DSA_NOTIFIER_MASTER_STATE_CHANGE */ +struct dsa_notifier_master_state_info { + const struct net_device *master; + bool operational; +}; + +int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v); +int dsa_broadcast(unsigned long e, void *v); + +int dsa_switch_register_notifier(struct dsa_switch *ds); +void dsa_switch_unregister_notifier(struct dsa_switch *ds); + +#endif diff --git a/net/dsa/tag.c b/net/dsa/tag.c new file mode 100644 index 000000000000..b2fba1a003ce --- /dev/null +++ b/net/dsa/tag.c @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * DSA tagging protocol handling + * + * Copyright (c) 2008-2009 Marvell Semiconductor + * Copyright (c) 2013 Florian Fainelli <florian@openwrt.org> + * Copyright (c) 2016 Andrew Lunn <andrew@lunn.ch> + */ + +#include <linux/netdevice.h> +#include <linux/ptp_classify.h> +#include <linux/skbuff.h> +#include <net/dsa.h> +#include <net/dst_metadata.h> + +#include "slave.h" +#include "tag.h" + +static LIST_HEAD(dsa_tag_drivers_list); +static DEFINE_MUTEX(dsa_tag_drivers_lock); + +/* Determine if we should defer delivery of skb until we have a rx timestamp. + * + * Called from dsa_switch_rcv. For now, this will only work if tagging is + * enabled on the switch. Normally the MAC driver would retrieve the hardware + * timestamp when it reads the packet out of the hardware. However in a DSA + * switch, the DSA driver owning the interface to which the packet is + * delivered is never notified unless we do so here. + */ +static bool dsa_skb_defer_rx_timestamp(struct dsa_slave_priv *p, + struct sk_buff *skb) +{ + struct dsa_switch *ds = p->dp->ds; + unsigned int type; + + if (!ds->ops->port_rxtstamp) + return false; + + if (skb_headroom(skb) < ETH_HLEN) + return false; + + __skb_push(skb, ETH_HLEN); + + type = ptp_classify_raw(skb); + + __skb_pull(skb, ETH_HLEN); + + if (type == PTP_CLASS_NONE) + return false; + + return ds->ops->port_rxtstamp(ds, p->dp->index, skb, type); +} + +static int dsa_switch_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *unused) +{ + struct metadata_dst *md_dst = skb_metadata_dst(skb); + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct sk_buff *nskb = NULL; + struct dsa_slave_priv *p; + + if (unlikely(!cpu_dp)) { + kfree_skb(skb); + return 0; + } + + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + return 0; + + if (md_dst && md_dst->type == METADATA_HW_PORT_MUX) { + unsigned int port = md_dst->u.port_info.port_id; + + skb_dst_drop(skb); + if (!skb_has_extensions(skb)) + skb->slow_gro = 0; + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (likely(skb->dev)) { + dsa_default_offload_fwd_mark(skb); + nskb = skb; + } + } else { + nskb = cpu_dp->rcv(skb, dev); + } + + if (!nskb) { + kfree_skb(skb); + return 0; + } + + skb = nskb; + skb_push(skb, ETH_HLEN); + skb->pkt_type = PACKET_HOST; + skb->protocol = eth_type_trans(skb, skb->dev); + + if (unlikely(!dsa_slave_dev_check(skb->dev))) { + /* Packet is to be injected directly on an upper + * device, e.g. a team/bond, so skip all DSA-port + * specific actions. + */ + netif_rx(skb); + return 0; + } + + p = netdev_priv(skb->dev); + + if (unlikely(cpu_dp->ds->untag_bridge_pvid)) { + nskb = dsa_untag_bridge_pvid(skb); + if (!nskb) { + kfree_skb(skb); + return 0; + } + skb = nskb; + } + + dev_sw_netstats_rx_add(skb->dev, skb->len); + + if (dsa_skb_defer_rx_timestamp(p, skb)) + return 0; + + gro_cells_receive(&p->gcells, skb); + + return 0; +} + +struct packet_type dsa_pack_type __read_mostly = { + .type = cpu_to_be16(ETH_P_XDSA), + .func = dsa_switch_rcv, +}; + +static void dsa_tag_driver_register(struct dsa_tag_driver *dsa_tag_driver, + struct module *owner) +{ + dsa_tag_driver->owner = owner; + + mutex_lock(&dsa_tag_drivers_lock); + list_add_tail(&dsa_tag_driver->list, &dsa_tag_drivers_list); + mutex_unlock(&dsa_tag_drivers_lock); +} + +void dsa_tag_drivers_register(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count, struct module *owner) +{ + unsigned int i; + + for (i = 0; i < count; i++) + dsa_tag_driver_register(dsa_tag_driver_array[i], owner); +} + +static void dsa_tag_driver_unregister(struct dsa_tag_driver *dsa_tag_driver) +{ + mutex_lock(&dsa_tag_drivers_lock); + list_del(&dsa_tag_driver->list); + mutex_unlock(&dsa_tag_drivers_lock); +} +EXPORT_SYMBOL_GPL(dsa_tag_drivers_register); + +void dsa_tag_drivers_unregister(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count) +{ + unsigned int i; + + for (i = 0; i < count; i++) + dsa_tag_driver_unregister(dsa_tag_driver_array[i]); +} +EXPORT_SYMBOL_GPL(dsa_tag_drivers_unregister); + +const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops) +{ + return ops->name; +}; + +/* Function takes a reference on the module owning the tagger, + * so dsa_tag_driver_put must be called afterwards. + */ +const struct dsa_device_ops *dsa_tag_driver_get_by_name(const char *name) +{ + const struct dsa_device_ops *ops = ERR_PTR(-ENOPROTOOPT); + struct dsa_tag_driver *dsa_tag_driver; + + request_module("%s%s", DSA_TAG_DRIVER_ALIAS, name); + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + const struct dsa_device_ops *tmp = dsa_tag_driver->ops; + + if (strcmp(name, tmp->name)) + continue; + + if (!try_module_get(dsa_tag_driver->owner)) + break; + + ops = tmp; + break; + } + mutex_unlock(&dsa_tag_drivers_lock); + + return ops; +} + +const struct dsa_device_ops *dsa_tag_driver_get_by_id(int tag_protocol) +{ + struct dsa_tag_driver *dsa_tag_driver; + const struct dsa_device_ops *ops; + bool found = false; + + request_module("%sid-%d", DSA_TAG_DRIVER_ALIAS, tag_protocol); + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + ops = dsa_tag_driver->ops; + if (ops->proto == tag_protocol) { + found = true; + break; + } + } + + if (found) { + if (!try_module_get(dsa_tag_driver->owner)) + ops = ERR_PTR(-ENOPROTOOPT); + } else { + ops = ERR_PTR(-ENOPROTOOPT); + } + + mutex_unlock(&dsa_tag_drivers_lock); + + return ops; +} + +void dsa_tag_driver_put(const struct dsa_device_ops *ops) +{ + struct dsa_tag_driver *dsa_tag_driver; + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + if (dsa_tag_driver->ops == ops) { + module_put(dsa_tag_driver->owner); + break; + } + } + mutex_unlock(&dsa_tag_drivers_lock); +} diff --git a/net/dsa/tag.h b/net/dsa/tag.h new file mode 100644 index 000000000000..7cfbca824f1c --- /dev/null +++ b/net/dsa/tag.h @@ -0,0 +1,310 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_TAG_H +#define __DSA_TAG_H + +#include <linux/if_vlan.h> +#include <linux/list.h> +#include <linux/types.h> +#include <net/dsa.h> + +#include "port.h" +#include "slave.h" + +struct dsa_tag_driver { + const struct dsa_device_ops *ops; + struct list_head list; + struct module *owner; +}; + +extern struct packet_type dsa_pack_type; + +const struct dsa_device_ops *dsa_tag_driver_get_by_id(int tag_protocol); +const struct dsa_device_ops *dsa_tag_driver_get_by_name(const char *name); +void dsa_tag_driver_put(const struct dsa_device_ops *ops); +const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops); + +static inline int dsa_tag_protocol_overhead(const struct dsa_device_ops *ops) +{ + return ops->needed_headroom + ops->needed_tailroom; +} + +static inline struct net_device *dsa_master_find_slave(struct net_device *dev, + int device, int port) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dp->ds->index == device && dp->index == port && + dp->type == DSA_PORT_TYPE_USER) + return dp->slave; + + return NULL; +} + +/* If under a bridge with vlan_filtering=0, make sure to send pvid-tagged + * frames as untagged, since the bridge will not untag them. + */ +static inline struct sk_buff *dsa_untag_bridge_pvid(struct sk_buff *skb) +{ + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct net_device *dev = skb->dev; + struct net_device *upper_dev; + u16 vid, pvid, proto; + int err; + + if (!br || br_vlan_enabled(br)) + return skb; + + err = br_vlan_get_proto(br, &proto); + if (err) + return skb; + + /* Move VLAN tag from data to hwaccel */ + if (!skb_vlan_tag_present(skb) && skb->protocol == htons(proto)) { + skb = skb_vlan_untag(skb); + if (!skb) + return NULL; + } + + if (!skb_vlan_tag_present(skb)) + return skb; + + vid = skb_vlan_tag_get_id(skb); + + /* We already run under an RCU read-side critical section since + * we are called from netif_receive_skb_list_internal(). + */ + err = br_vlan_get_pvid_rcu(dev, &pvid); + if (err) + return skb; + + if (vid != pvid) + return skb; + + /* The sad part about attempting to untag from DSA is that we + * don't know, unless we check, if the skb will end up in + * the bridge's data path - br_allowed_ingress() - or not. + * For example, there might be an 8021q upper for the + * default_pvid of the bridge, which will steal VLAN-tagged traffic + * from the bridge's data path. This is a configuration that DSA + * supports because vlan_filtering is 0. In that case, we should + * definitely keep the tag, to make sure it keeps working. + */ + upper_dev = __vlan_find_dev_deep_rcu(br, htons(proto), vid); + if (upper_dev) + return skb; + + __vlan_hwaccel_clear_tag(skb); + + return skb; +} + +/* For switches without hardware support for DSA tagging to be able + * to support termination through the bridge. + */ +static inline struct net_device * +dsa_find_designated_bridge_port_by_vid(struct net_device *master, u16 vid) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct bridge_vlan_info vinfo; + struct net_device *slave; + struct dsa_port *dp; + int err; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->type != DSA_PORT_TYPE_USER) + continue; + + if (!dp->bridge) + continue; + + if (dp->stp_state != BR_STATE_LEARNING && + dp->stp_state != BR_STATE_FORWARDING) + continue; + + /* Since the bridge might learn this packet, keep the CPU port + * affinity with the port that will be used for the reply on + * xmit. + */ + if (dp->cpu_dp != cpu_dp) + continue; + + slave = dp->slave; + + err = br_vlan_get_info_rcu(slave, vid, &vinfo); + if (err) + continue; + + return slave; + } + + return NULL; +} + +/* If the ingress port offloads the bridge, we mark the frame as autonomously + * forwarded by hardware, so the software bridge doesn't forward in twice, back + * to us, because we already did. However, if we're in fallback mode and we do + * software bridging, we are not offloading it, therefore the dp->bridge + * pointer is not populated, and flooding needs to be done by software (we are + * effectively operating in standalone ports mode). + */ +static inline void dsa_default_offload_fwd_mark(struct sk_buff *skb) +{ + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + + skb->offload_fwd_mark = !!(dp->bridge); +} + +/* Helper for removing DSA header tags from packets in the RX path. + * Must not be called before skb_pull(len). + * skb->data + * | + * v + * | | | | | | | | | | | | | | | | | | | + * +-----------------------+-----------------------+---------------+-------+ + * | Destination MAC | Source MAC | DSA header | EType | + * +-----------------------+-----------------------+---------------+-------+ + * | | + * <----- len -----> <----- len -----> + * | + * >>>>>>> v + * >>>>>>> | | | | | | | | | | | | | | | + * >>>>>>> +-----------------------+-----------------------+-------+ + * >>>>>>> | Destination MAC | Source MAC | EType | + * +-----------------------+-----------------------+-------+ + * ^ + * | + * skb->data + */ +static inline void dsa_strip_etype_header(struct sk_buff *skb, int len) +{ + memmove(skb->data - ETH_HLEN, skb->data - ETH_HLEN - len, 2 * ETH_ALEN); +} + +/* Helper for creating space for DSA header tags in TX path packets. + * Must not be called before skb_push(len). + * + * Before: + * + * <<<<<<< | | | | | | | | | | | | | | | + * ^ <<<<<<< +-----------------------+-----------------------+-------+ + * | <<<<<<< | Destination MAC | Source MAC | EType | + * | +-----------------------+-----------------------+-------+ + * <----- len -----> + * | + * | + * skb->data + * + * After: + * + * | | | | | | | | | | | | | | | | | | | + * +-----------------------+-----------------------+---------------+-------+ + * | Destination MAC | Source MAC | DSA header | EType | + * +-----------------------+-----------------------+---------------+-------+ + * ^ | | + * | <----- len -----> + * skb->data + */ +static inline void dsa_alloc_etype_header(struct sk_buff *skb, int len) +{ + memmove(skb->data, skb->data + len, 2 * ETH_ALEN); +} + +/* On RX, eth_type_trans() on the DSA master pulls ETH_HLEN bytes starting from + * skb_mac_header(skb), which leaves skb->data pointing at the first byte after + * what the DSA master perceives as the EtherType (the beginning of the L3 + * protocol). Since DSA EtherType header taggers treat the EtherType as part of + * the DSA tag itself, and the EtherType is 2 bytes in length, the DSA header + * is located 2 bytes behind skb->data. Note that EtherType in this context + * means the first 2 bytes of the DSA header, not the encapsulated EtherType + * that will become visible after the DSA header is stripped. + */ +static inline void *dsa_etype_header_pos_rx(struct sk_buff *skb) +{ + return skb->data - 2; +} + +/* On TX, skb->data points to skb_mac_header(skb), which means that EtherType + * header taggers start exactly where the EtherType is (the EtherType is + * treated as part of the DSA header). + */ +static inline void *dsa_etype_header_pos_tx(struct sk_buff *skb) +{ + return skb->data + 2 * ETH_ALEN; +} + +/* Create 2 modaliases per tagging protocol, one to auto-load the module + * given the ID reported by get_tag_protocol(), and the other by name. + */ +#define DSA_TAG_DRIVER_ALIAS "dsa_tag:" +#define MODULE_ALIAS_DSA_TAG_DRIVER(__proto, __name) \ + MODULE_ALIAS(DSA_TAG_DRIVER_ALIAS __name); \ + MODULE_ALIAS(DSA_TAG_DRIVER_ALIAS "id-" \ + __stringify(__proto##_VALUE)) + +void dsa_tag_drivers_register(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count, + struct module *owner); +void dsa_tag_drivers_unregister(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count); + +#define dsa_tag_driver_module_drivers(__dsa_tag_drivers_array, __count) \ +static int __init dsa_tag_driver_module_init(void) \ +{ \ + dsa_tag_drivers_register(__dsa_tag_drivers_array, __count, \ + THIS_MODULE); \ + return 0; \ +} \ +module_init(dsa_tag_driver_module_init); \ + \ +static void __exit dsa_tag_driver_module_exit(void) \ +{ \ + dsa_tag_drivers_unregister(__dsa_tag_drivers_array, __count); \ +} \ +module_exit(dsa_tag_driver_module_exit) + +/** + * module_dsa_tag_drivers() - Helper macro for registering DSA tag + * drivers + * @__ops_array: Array of tag driver structures + * + * Helper macro for DSA tag drivers which do not do anything special + * in module init/exit. Each module may only use this macro once, and + * calling it replaces module_init() and module_exit(). + */ +#define module_dsa_tag_drivers(__ops_array) \ +dsa_tag_driver_module_drivers(__ops_array, ARRAY_SIZE(__ops_array)) + +#define DSA_TAG_DRIVER_NAME(__ops) dsa_tag_driver ## _ ## __ops + +/* Create a static structure we can build a linked list of dsa_tag + * drivers + */ +#define DSA_TAG_DRIVER(__ops) \ +static struct dsa_tag_driver DSA_TAG_DRIVER_NAME(__ops) = { \ + .ops = &__ops, \ +} + +/** + * module_dsa_tag_driver() - Helper macro for registering a single DSA tag + * driver + * @__ops: Single tag driver structures + * + * Helper macro for DSA tag drivers which do not do anything special + * in module init/exit. Each module may only use this macro once, and + * calling it replaces module_init() and module_exit(). + */ +#define module_dsa_tag_driver(__ops) \ +DSA_TAG_DRIVER(__ops); \ + \ +static struct dsa_tag_driver *dsa_tag_driver_array[] = { \ + &DSA_TAG_DRIVER_NAME(__ops) \ +}; \ +module_dsa_tag_drivers(dsa_tag_driver_array) + +#endif diff --git a/net/dsa/tag_8021q.c b/net/dsa/tag_8021q.c index 34e5ec5d3e23..5ee9ef00954e 100644 --- a/net/dsa/tag_8021q.c +++ b/net/dsa/tag_8021q.c @@ -7,7 +7,10 @@ #include <linux/if_vlan.h> #include <linux/dsa/8021q.h> -#include "dsa_priv.h" +#include "port.h" +#include "switch.h" +#include "tag.h" +#include "tag_8021q.h" /* Binary structure of the fake 12-bit VID field (when the TPID is * ETH_P_DSA_8021Q): @@ -60,6 +63,20 @@ #define DSA_8021Q_PORT(x) (((x) << DSA_8021Q_PORT_SHIFT) & \ DSA_8021Q_PORT_MASK) +struct dsa_tag_8021q_vlan { + struct list_head list; + int port; + u16 vid; + refcount_t refcount; +}; + +struct dsa_8021q_context { + struct dsa_switch *ds; + struct list_head vlans; + /* EtherType of RX VID, used for filtering on master interface */ + __be16 proto; +}; + u16 dsa_tag_8021q_bridge_vid(unsigned int bridge_num) { /* The VBID value of 0 is reserved for precise TX, but it is also @@ -398,6 +415,7 @@ static void dsa_tag_8021q_teardown(struct dsa_switch *ds) int dsa_tag_8021q_register(struct dsa_switch *ds, __be16 proto) { struct dsa_8021q_context *ctx; + int err; ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); if (!ctx) @@ -410,7 +428,15 @@ int dsa_tag_8021q_register(struct dsa_switch *ds, __be16 proto) ds->tag_8021q_ctx = ctx; - return dsa_tag_8021q_setup(ds); + err = dsa_tag_8021q_setup(ds); + if (err) + goto err_free; + + return 0; + +err_free: + kfree(ctx); + return err; } EXPORT_SYMBOL_GPL(dsa_tag_8021q_register); diff --git a/net/dsa/tag_8021q.h b/net/dsa/tag_8021q.h new file mode 100644 index 000000000000..b75cbaa028ef --- /dev/null +++ b/net/dsa/tag_8021q.h @@ -0,0 +1,27 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef __DSA_TAG_8021Q_H +#define __DSA_TAG_8021Q_H + +#include <net/dsa.h> + +#include "switch.h" + +struct sk_buff; +struct net_device; + +struct sk_buff *dsa_8021q_xmit(struct sk_buff *skb, struct net_device *netdev, + u16 tpid, u16 tci); + +void dsa_8021q_rcv(struct sk_buff *skb, int *source_port, int *switch_id, + int *vbid); + +struct net_device *dsa_tag_8021q_find_port_by_vbid(struct net_device *master, + int vbid); + +int dsa_switch_tag_8021q_vlan_add(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info); +int dsa_switch_tag_8021q_vlan_del(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info); + +#endif diff --git a/net/dsa/tag_ar9331.c b/net/dsa/tag_ar9331.c index 8a02ac44282f..7f3b7d730b85 100644 --- a/net/dsa/tag_ar9331.c +++ b/net/dsa/tag_ar9331.c @@ -7,7 +7,9 @@ #include <linux/bitfield.h> #include <linux/etherdevice.h> -#include "dsa_priv.h" +#include "tag.h" + +#define AR9331_NAME "ar9331" #define AR9331_HDR_LEN 2 #define AR9331_HDR_VERSION 1 @@ -80,7 +82,7 @@ static struct sk_buff *ar9331_tag_rcv(struct sk_buff *skb, } static const struct dsa_device_ops ar9331_netdev_ops = { - .name = "ar9331", + .name = AR9331_NAME, .proto = DSA_TAG_PROTO_AR9331, .xmit = ar9331_tag_xmit, .rcv = ar9331_tag_rcv, @@ -88,5 +90,5 @@ static const struct dsa_device_ops ar9331_netdev_ops = { }; MODULE_LICENSE("GPL v2"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_AR9331); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_AR9331, AR9331_NAME); module_dsa_tag_driver(ar9331_netdev_ops); diff --git a/net/dsa/tag_brcm.c b/net/dsa/tag_brcm.c index 16889ea3e0a7..10239daa5745 100644 --- a/net/dsa/tag_brcm.c +++ b/net/dsa/tag_brcm.c @@ -10,7 +10,11 @@ #include <linux/list.h> #include <linux/slab.h> -#include "dsa_priv.h" +#include "tag.h" + +#define BRCM_NAME "brcm" +#define BRCM_LEGACY_NAME "brcm-legacy" +#define BRCM_PREPEND_NAME "brcm-prepend" /* Legacy Broadcom tag (6 bytes) */ #define BRCM_LEG_TAG_LEN 6 @@ -196,7 +200,7 @@ static struct sk_buff *brcm_tag_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops brcm_netdev_ops = { - .name = "brcm", + .name = BRCM_NAME, .proto = DSA_TAG_PROTO_BRCM, .xmit = brcm_tag_xmit, .rcv = brcm_tag_rcv, @@ -204,7 +208,7 @@ static const struct dsa_device_ops brcm_netdev_ops = { }; DSA_TAG_DRIVER(brcm_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM, BRCM_NAME); #endif #if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_LEGACY) @@ -273,7 +277,7 @@ static struct sk_buff *brcm_leg_tag_rcv(struct sk_buff *skb, } static const struct dsa_device_ops brcm_legacy_netdev_ops = { - .name = "brcm-legacy", + .name = BRCM_LEGACY_NAME, .proto = DSA_TAG_PROTO_BRCM_LEGACY, .xmit = brcm_leg_tag_xmit, .rcv = brcm_leg_tag_rcv, @@ -281,7 +285,7 @@ static const struct dsa_device_ops brcm_legacy_netdev_ops = { }; DSA_TAG_DRIVER(brcm_legacy_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_LEGACY); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_LEGACY, BRCM_LEGACY_NAME); #endif /* CONFIG_NET_DSA_TAG_BRCM_LEGACY */ #if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_PREPEND) @@ -300,7 +304,7 @@ static struct sk_buff *brcm_tag_rcv_prepend(struct sk_buff *skb, } static const struct dsa_device_ops brcm_prepend_netdev_ops = { - .name = "brcm-prepend", + .name = BRCM_PREPEND_NAME, .proto = DSA_TAG_PROTO_BRCM_PREPEND, .xmit = brcm_tag_xmit_prepend, .rcv = brcm_tag_rcv_prepend, @@ -308,7 +312,7 @@ static const struct dsa_device_ops brcm_prepend_netdev_ops = { }; DSA_TAG_DRIVER(brcm_prepend_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_PREPEND); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_PREPEND, BRCM_PREPEND_NAME); #endif static struct dsa_tag_driver *dsa_tag_driver_array[] = { diff --git a/net/dsa/tag_dsa.c b/net/dsa/tag_dsa.c index e4b6e3f2a3db..1fd7fa26db64 100644 --- a/net/dsa/tag_dsa.c +++ b/net/dsa/tag_dsa.c @@ -50,7 +50,10 @@ #include <linux/list.h> #include <linux/slab.h> -#include "dsa_priv.h" +#include "tag.h" + +#define DSA_NAME "dsa" +#define EDSA_NAME "edsa" #define DSA_HLEN 4 @@ -339,7 +342,7 @@ static struct sk_buff *dsa_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops dsa_netdev_ops = { - .name = "dsa", + .name = DSA_NAME, .proto = DSA_TAG_PROTO_DSA, .xmit = dsa_xmit, .rcv = dsa_rcv, @@ -347,7 +350,7 @@ static const struct dsa_device_ops dsa_netdev_ops = { }; DSA_TAG_DRIVER(dsa_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_DSA); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_DSA, DSA_NAME); #endif /* CONFIG_NET_DSA_TAG_DSA */ #if IS_ENABLED(CONFIG_NET_DSA_TAG_EDSA) @@ -381,7 +384,7 @@ static struct sk_buff *edsa_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops edsa_netdev_ops = { - .name = "edsa", + .name = EDSA_NAME, .proto = DSA_TAG_PROTO_EDSA, .xmit = edsa_xmit, .rcv = edsa_rcv, @@ -389,7 +392,7 @@ static const struct dsa_device_ops edsa_netdev_ops = { }; DSA_TAG_DRIVER(edsa_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_EDSA); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_EDSA, EDSA_NAME); #endif /* CONFIG_NET_DSA_TAG_EDSA */ static struct dsa_tag_driver *dsa_tag_drivers[] = { diff --git a/net/dsa/tag_gswip.c b/net/dsa/tag_gswip.c index df7140984da3..e279cd9057b0 100644 --- a/net/dsa/tag_gswip.c +++ b/net/dsa/tag_gswip.c @@ -10,7 +10,9 @@ #include <linux/skbuff.h> #include <net/dsa.h> -#include "dsa_priv.h" +#include "tag.h" + +#define GSWIP_NAME "gswip" #define GSWIP_TX_HEADER_LEN 4 @@ -98,7 +100,7 @@ static struct sk_buff *gswip_tag_rcv(struct sk_buff *skb, } static const struct dsa_device_ops gswip_netdev_ops = { - .name = "gswip", + .name = GSWIP_NAME, .proto = DSA_TAG_PROTO_GSWIP, .xmit = gswip_tag_xmit, .rcv = gswip_tag_rcv, @@ -106,6 +108,6 @@ static const struct dsa_device_ops gswip_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_GSWIP); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_GSWIP, GSWIP_NAME); module_dsa_tag_driver(gswip_netdev_ops); diff --git a/net/dsa/tag_hellcreek.c b/net/dsa/tag_hellcreek.c index 846588c0070a..03a1fb9c87a9 100644 --- a/net/dsa/tag_hellcreek.c +++ b/net/dsa/tag_hellcreek.c @@ -11,7 +11,9 @@ #include <linux/skbuff.h> #include <net/dsa.h> -#include "dsa_priv.h" +#include "tag.h" + +#define HELLCREEK_NAME "hellcreek" #define HELLCREEK_TAG_LEN 1 @@ -49,7 +51,8 @@ static struct sk_buff *hellcreek_rcv(struct sk_buff *skb, return NULL; } - pskb_trim_rcsum(skb, skb->len - HELLCREEK_TAG_LEN); + if (pskb_trim_rcsum(skb, skb->len - HELLCREEK_TAG_LEN)) + return NULL; dsa_default_offload_fwd_mark(skb); @@ -57,7 +60,7 @@ static struct sk_buff *hellcreek_rcv(struct sk_buff *skb, } static const struct dsa_device_ops hellcreek_netdev_ops = { - .name = "hellcreek", + .name = HELLCREEK_NAME, .proto = DSA_TAG_PROTO_HELLCREEK, .xmit = hellcreek_xmit, .rcv = hellcreek_rcv, @@ -65,6 +68,6 @@ static const struct dsa_device_ops hellcreek_netdev_ops = { }; MODULE_LICENSE("Dual MIT/GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_HELLCREEK); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_HELLCREEK, HELLCREEK_NAME); module_dsa_tag_driver(hellcreek_netdev_ops); diff --git a/net/dsa/tag_ksz.c b/net/dsa/tag_ksz.c index 38fa19c1e2d5..080e5c369f5b 100644 --- a/net/dsa/tag_ksz.c +++ b/net/dsa/tag_ksz.c @@ -7,7 +7,13 @@ #include <linux/etherdevice.h> #include <linux/list.h> #include <net/dsa.h> -#include "dsa_priv.h" + +#include "tag.h" + +#define KSZ8795_NAME "ksz8795" +#define KSZ9477_NAME "ksz9477" +#define KSZ9893_NAME "ksz9893" +#define LAN937X_NAME "lan937x" /* Typically only one byte is used for tail tag. */ #define KSZ_EGRESS_TAG_LEN 1 @@ -21,7 +27,8 @@ static struct sk_buff *ksz_common_rcv(struct sk_buff *skb, if (!skb->dev) return NULL; - pskb_trim_rcsum(skb, skb->len - len); + if (pskb_trim_rcsum(skb, skb->len - len)) + return NULL; dsa_default_offload_fwd_mark(skb); @@ -74,7 +81,7 @@ static struct sk_buff *ksz8795_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops ksz8795_netdev_ops = { - .name = "ksz8795", + .name = KSZ8795_NAME, .proto = DSA_TAG_PROTO_KSZ8795, .xmit = ksz8795_xmit, .rcv = ksz8795_rcv, @@ -82,7 +89,7 @@ static const struct dsa_device_ops ksz8795_netdev_ops = { }; DSA_TAG_DRIVER(ksz8795_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ8795); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ8795, KSZ8795_NAME); /* * For Ingress (Host -> KSZ9477), 2 bytes are added before FCS. @@ -147,7 +154,7 @@ static struct sk_buff *ksz9477_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops ksz9477_netdev_ops = { - .name = "ksz9477", + .name = KSZ9477_NAME, .proto = DSA_TAG_PROTO_KSZ9477, .xmit = ksz9477_xmit, .rcv = ksz9477_rcv, @@ -155,7 +162,7 @@ static const struct dsa_device_ops ksz9477_netdev_ops = { }; DSA_TAG_DRIVER(ksz9477_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9477); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9477, KSZ9477_NAME); #define KSZ9893_TAIL_TAG_OVERRIDE BIT(5) #define KSZ9893_TAIL_TAG_LOOKUP BIT(6) @@ -183,7 +190,7 @@ static struct sk_buff *ksz9893_xmit(struct sk_buff *skb, } static const struct dsa_device_ops ksz9893_netdev_ops = { - .name = "ksz9893", + .name = KSZ9893_NAME, .proto = DSA_TAG_PROTO_KSZ9893, .xmit = ksz9893_xmit, .rcv = ksz9477_rcv, @@ -191,7 +198,7 @@ static const struct dsa_device_ops ksz9893_netdev_ops = { }; DSA_TAG_DRIVER(ksz9893_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9893); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9893, KSZ9893_NAME); /* For xmit, 2 bytes are added before FCS. * --------------------------------------------------------------------------- @@ -241,7 +248,7 @@ static struct sk_buff *lan937x_xmit(struct sk_buff *skb, } static const struct dsa_device_ops lan937x_netdev_ops = { - .name = "lan937x", + .name = LAN937X_NAME, .proto = DSA_TAG_PROTO_LAN937X, .xmit = lan937x_xmit, .rcv = ksz9477_rcv, @@ -249,7 +256,7 @@ static const struct dsa_device_ops lan937x_netdev_ops = { }; DSA_TAG_DRIVER(lan937x_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN937X); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN937X, LAN937X_NAME); static struct dsa_tag_driver *dsa_tag_driver_array[] = { &DSA_TAG_DRIVER_NAME(ksz8795_netdev_ops), diff --git a/net/dsa/tag_lan9303.c b/net/dsa/tag_lan9303.c index 98d7d7120bab..c25f5536706b 100644 --- a/net/dsa/tag_lan9303.c +++ b/net/dsa/tag_lan9303.c @@ -7,7 +7,7 @@ #include <linux/list.h> #include <linux/slab.h> -#include "dsa_priv.h" +#include "tag.h" /* To define the outgoing port and to discover the incoming port a regular * VLAN tag is used by the LAN9303. But its VID meaning is 'special': @@ -30,6 +30,8 @@ * Required when no forwarding between the external ports should happen. */ +#define LAN9303_NAME "lan9303" + #define LAN9303_TAG_LEN 4 # define LAN9303_TAG_TX_USE_ALR BIT(3) # define LAN9303_TAG_TX_STP_OVERRIDE BIT(4) @@ -110,7 +112,7 @@ static struct sk_buff *lan9303_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops lan9303_netdev_ops = { - .name = "lan9303", + .name = LAN9303_NAME, .proto = DSA_TAG_PROTO_LAN9303, .xmit = lan9303_xmit, .rcv = lan9303_rcv, @@ -118,6 +120,6 @@ static const struct dsa_device_ops lan9303_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN9303); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN9303, LAN9303_NAME); module_dsa_tag_driver(lan9303_netdev_ops); diff --git a/net/dsa/tag_mtk.c b/net/dsa/tag_mtk.c index 415d8ece242a..40af80452747 100644 --- a/net/dsa/tag_mtk.c +++ b/net/dsa/tag_mtk.c @@ -8,7 +8,9 @@ #include <linux/etherdevice.h> #include <linux/if_vlan.h> -#include "dsa_priv.h" +#include "tag.h" + +#define MTK_NAME "mtk" #define MTK_HDR_LEN 4 #define MTK_HDR_XMIT_UNTAGGED 0 @@ -25,6 +27,8 @@ static struct sk_buff *mtk_tag_xmit(struct sk_buff *skb, u8 xmit_tpid; u8 *mtk_tag; + skb_set_queue_mapping(skb, dp->index); + /* Build the special tag after the MAC Source Address. If VLAN header * is present, it's required that VLAN header and special tag is * being combined. Only in this way we can allow the switch can parse @@ -91,7 +95,7 @@ static struct sk_buff *mtk_tag_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops mtk_netdev_ops = { - .name = "mtk", + .name = MTK_NAME, .proto = DSA_TAG_PROTO_MTK, .xmit = mtk_tag_xmit, .rcv = mtk_tag_rcv, @@ -99,6 +103,6 @@ static const struct dsa_device_ops mtk_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_MTK); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_MTK, MTK_NAME); module_dsa_tag_driver(mtk_netdev_ops); diff --git a/net/dsa/tag_none.c b/net/dsa/tag_none.c new file mode 100644 index 000000000000..d2fd179c4227 --- /dev/null +++ b/net/dsa/tag_none.c @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dsa/tag_none.c - Traffic handling for switches with no tag + * Copyright (c) 2008-2009 Marvell Semiconductor + * Copyright (c) 2013 Florian Fainelli <florian@openwrt.org> + * + * WARNING: do not use this for new switches. In case of no hardware + * tagging support, look at tag_8021q.c instead. + */ + +#include "tag.h" + +#define NONE_NAME "none" + +static struct sk_buff *dsa_slave_notag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + /* Just return the original SKB */ + return skb; +} + +static const struct dsa_device_ops none_ops = { + .name = NONE_NAME, + .proto = DSA_TAG_PROTO_NONE, + .xmit = dsa_slave_notag_xmit, +}; + +module_dsa_tag_driver(none_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_NONE, NONE_NAME); +MODULE_LICENSE("GPL"); diff --git a/net/dsa/tag_ocelot.c b/net/dsa/tag_ocelot.c index 0d81f172b7a6..28ebecafdd24 100644 --- a/net/dsa/tag_ocelot.c +++ b/net/dsa/tag_ocelot.c @@ -2,7 +2,11 @@ /* Copyright 2019 NXP */ #include <linux/dsa/ocelot.h> -#include "dsa_priv.h" + +#include "tag.h" + +#define OCELOT_NAME "ocelot" +#define SEVILLE_NAME "seville" /* If the port is under a VLAN-aware bridge, remove the VLAN header from the * payload and move it into the DSA tag, which will make the switch classify @@ -183,7 +187,7 @@ static struct sk_buff *ocelot_rcv(struct sk_buff *skb, } static const struct dsa_device_ops ocelot_netdev_ops = { - .name = "ocelot", + .name = OCELOT_NAME, .proto = DSA_TAG_PROTO_OCELOT, .xmit = ocelot_xmit, .rcv = ocelot_rcv, @@ -192,10 +196,10 @@ static const struct dsa_device_ops ocelot_netdev_ops = { }; DSA_TAG_DRIVER(ocelot_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT, OCELOT_NAME); static const struct dsa_device_ops seville_netdev_ops = { - .name = "seville", + .name = SEVILLE_NAME, .proto = DSA_TAG_PROTO_SEVILLE, .xmit = seville_xmit, .rcv = ocelot_rcv, @@ -204,7 +208,7 @@ static const struct dsa_device_ops seville_netdev_ops = { }; DSA_TAG_DRIVER(seville_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SEVILLE); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SEVILLE, SEVILLE_NAME); static struct dsa_tag_driver *ocelot_tag_driver_array[] = { &DSA_TAG_DRIVER_NAME(ocelot_netdev_ops), diff --git a/net/dsa/tag_ocelot_8021q.c b/net/dsa/tag_ocelot_8021q.c index 37ccf00404ea..1f0b8c20eba5 100644 --- a/net/dsa/tag_ocelot_8021q.c +++ b/net/dsa/tag_ocelot_8021q.c @@ -10,7 +10,11 @@ */ #include <linux/dsa/8021q.h> #include <linux/dsa/ocelot.h> -#include "dsa_priv.h" + +#include "tag.h" +#include "tag_8021q.h" + +#define OCELOT_8021Q_NAME "ocelot-8021q" struct ocelot_8021q_tagger_private { struct ocelot_8021q_tagger_data data; /* Must be first */ @@ -119,7 +123,7 @@ static int ocelot_connect(struct dsa_switch *ds) } static const struct dsa_device_ops ocelot_8021q_netdev_ops = { - .name = "ocelot-8021q", + .name = OCELOT_8021Q_NAME, .proto = DSA_TAG_PROTO_OCELOT_8021Q, .xmit = ocelot_xmit, .rcv = ocelot_rcv, @@ -130,6 +134,6 @@ static const struct dsa_device_ops ocelot_8021q_netdev_ops = { }; MODULE_LICENSE("GPL v2"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT_8021Q); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT_8021Q, OCELOT_8021Q_NAME); module_dsa_tag_driver(ocelot_8021q_netdev_ops); diff --git a/net/dsa/tag_qca.c b/net/dsa/tag_qca.c index 57d2e00f1e5d..e757c8de06f1 100644 --- a/net/dsa/tag_qca.c +++ b/net/dsa/tag_qca.c @@ -8,7 +8,9 @@ #include <net/dsa.h> #include <linux/dsa/tag_qca.h> -#include "dsa_priv.h" +#include "tag.h" + +#define QCA_NAME "qca" static struct sk_buff *qca_tag_xmit(struct sk_buff *skb, struct net_device *dev) { @@ -107,7 +109,7 @@ static void qca_tag_disconnect(struct dsa_switch *ds) } static const struct dsa_device_ops qca_netdev_ops = { - .name = "qca", + .name = QCA_NAME, .proto = DSA_TAG_PROTO_QCA, .connect = qca_tag_connect, .disconnect = qca_tag_disconnect, @@ -118,6 +120,6 @@ static const struct dsa_device_ops qca_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_QCA); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_QCA, QCA_NAME); module_dsa_tag_driver(qca_netdev_ops); diff --git a/net/dsa/tag_rtl4_a.c b/net/dsa/tag_rtl4_a.c index 6d928ee3ef7a..c327314b95e3 100644 --- a/net/dsa/tag_rtl4_a.c +++ b/net/dsa/tag_rtl4_a.c @@ -18,7 +18,9 @@ #include <linux/etherdevice.h> #include <linux/bits.h> -#include "dsa_priv.h" +#include "tag.h" + +#define RTL4_A_NAME "rtl4a" #define RTL4_A_HDR_LEN 4 #define RTL4_A_ETHERTYPE 0x8899 @@ -112,7 +114,7 @@ static struct sk_buff *rtl4a_tag_rcv(struct sk_buff *skb, } static const struct dsa_device_ops rtl4a_netdev_ops = { - .name = "rtl4a", + .name = RTL4_A_NAME, .proto = DSA_TAG_PROTO_RTL4_A, .xmit = rtl4a_tag_xmit, .rcv = rtl4a_tag_rcv, @@ -121,4 +123,4 @@ static const struct dsa_device_ops rtl4a_netdev_ops = { module_dsa_tag_driver(rtl4a_netdev_ops); MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL4_A); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL4_A, RTL4_A_NAME); diff --git a/net/dsa/tag_rtl8_4.c b/net/dsa/tag_rtl8_4.c index a593ead7ff26..4f67834fd121 100644 --- a/net/dsa/tag_rtl8_4.c +++ b/net/dsa/tag_rtl8_4.c @@ -77,13 +77,16 @@ #include <linux/bits.h> #include <linux/etherdevice.h> -#include "dsa_priv.h" +#include "tag.h" /* Protocols supported: * * 0x04 = RTL8365MB DSA protocol */ +#define RTL8_4_NAME "rtl8_4" +#define RTL8_4T_NAME "rtl8_4t" + #define RTL8_4_TAG_LEN 8 #define RTL8_4_PROTOCOL GENMASK(15, 8) @@ -234,7 +237,7 @@ static const struct dsa_device_ops rtl8_4_netdev_ops = { DSA_TAG_DRIVER(rtl8_4_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4, RTL8_4_NAME); /* Tail version */ static const struct dsa_device_ops rtl8_4t_netdev_ops = { @@ -247,7 +250,7 @@ static const struct dsa_device_ops rtl8_4t_netdev_ops = { DSA_TAG_DRIVER(rtl8_4t_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4T); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4T, RTL8_4T_NAME); static struct dsa_tag_driver *dsa_tag_drivers[] = { &DSA_TAG_DRIVER_NAME(rtl8_4_netdev_ops), diff --git a/net/dsa/tag_rzn1_a5psw.c b/net/dsa/tag_rzn1_a5psw.c index e2a5ee6ae688..437a6820ac42 100644 --- a/net/dsa/tag_rzn1_a5psw.c +++ b/net/dsa/tag_rzn1_a5psw.c @@ -10,7 +10,7 @@ #include <linux/if_ether.h> #include <net/dsa.h> -#include "dsa_priv.h" +#include "tag.h" /* To define the outgoing port and to discover the incoming port a TAG is * inserted after Src MAC : @@ -22,6 +22,8 @@ * See struct a5psw_tag for layout */ +#define A5PSW_NAME "a5psw" + #define ETH_P_DSA_A5PSW 0xE001 #define A5PSW_TAG_LEN 8 #define A5PSW_CTRL_DATA_FORCE_FORWARD BIT(0) @@ -101,7 +103,7 @@ static struct sk_buff *a5psw_tag_rcv(struct sk_buff *skb, } static const struct dsa_device_ops a5psw_netdev_ops = { - .name = "a5psw", + .name = A5PSW_NAME, .proto = DSA_TAG_PROTO_RZN1_A5PSW, .xmit = a5psw_tag_xmit, .rcv = a5psw_tag_rcv, @@ -109,5 +111,5 @@ static const struct dsa_device_ops a5psw_netdev_ops = { }; MODULE_LICENSE("GPL v2"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_A5PSW); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_A5PSW, A5PSW_NAME); module_dsa_tag_driver(a5psw_netdev_ops); diff --git a/net/dsa/tag_sja1105.c b/net/dsa/tag_sja1105.c index 83e4136516b0..1c2ceba4771b 100644 --- a/net/dsa/tag_sja1105.c +++ b/net/dsa/tag_sja1105.c @@ -5,7 +5,12 @@ #include <linux/dsa/sja1105.h> #include <linux/dsa/8021q.h> #include <linux/packing.h> -#include "dsa_priv.h" + +#include "tag.h" +#include "tag_8021q.h" + +#define SJA1105_NAME "sja1105" +#define SJA1110_NAME "sja1110" /* Is this a TX or an RX header? */ #define SJA1110_HEADER_HOST_TO_SWITCH BIT(15) @@ -665,7 +670,8 @@ static struct sk_buff *sja1110_rcv_inband_control_extension(struct sk_buff *skb, * padding and trailer we need to account for the fact that * skb->data points to skb_mac_header(skb) + ETH_HLEN. */ - pskb_trim_rcsum(skb, start_of_padding - ETH_HLEN); + if (pskb_trim_rcsum(skb, start_of_padding - ETH_HLEN)) + return NULL; /* Trap-to-host frame, no timestamp trailer */ } else { *source_port = SJA1110_RX_HEADER_SRC_PORT(rx_header); @@ -786,7 +792,7 @@ static int sja1105_connect(struct dsa_switch *ds) } static const struct dsa_device_ops sja1105_netdev_ops = { - .name = "sja1105", + .name = SJA1105_NAME, .proto = DSA_TAG_PROTO_SJA1105, .xmit = sja1105_xmit, .rcv = sja1105_rcv, @@ -798,10 +804,10 @@ static const struct dsa_device_ops sja1105_netdev_ops = { }; DSA_TAG_DRIVER(sja1105_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1105); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1105, SJA1105_NAME); static const struct dsa_device_ops sja1110_netdev_ops = { - .name = "sja1110", + .name = SJA1110_NAME, .proto = DSA_TAG_PROTO_SJA1110, .xmit = sja1110_xmit, .rcv = sja1110_rcv, @@ -813,7 +819,7 @@ static const struct dsa_device_ops sja1110_netdev_ops = { }; DSA_TAG_DRIVER(sja1110_netdev_ops); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1110); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1110, SJA1110_NAME); static struct dsa_tag_driver *sja1105_tag_driver_array[] = { &DSA_TAG_DRIVER_NAME(sja1105_netdev_ops), diff --git a/net/dsa/tag_trailer.c b/net/dsa/tag_trailer.c index 5749ba85c2b8..7361b9106382 100644 --- a/net/dsa/tag_trailer.c +++ b/net/dsa/tag_trailer.c @@ -8,7 +8,9 @@ #include <linux/list.h> #include <linux/slab.h> -#include "dsa_priv.h" +#include "tag.h" + +#define TRAILER_NAME "trailer" static struct sk_buff *trailer_xmit(struct sk_buff *skb, struct net_device *dev) { @@ -50,7 +52,7 @@ static struct sk_buff *trailer_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops trailer_netdev_ops = { - .name = "trailer", + .name = TRAILER_NAME, .proto = DSA_TAG_PROTO_TRAILER, .xmit = trailer_xmit, .rcv = trailer_rcv, @@ -58,6 +60,6 @@ static const struct dsa_device_ops trailer_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_TRAILER); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_TRAILER, TRAILER_NAME); module_dsa_tag_driver(trailer_netdev_ops); diff --git a/net/dsa/tag_xrs700x.c b/net/dsa/tag_xrs700x.c index ff442b8af636..af19969f9bc4 100644 --- a/net/dsa/tag_xrs700x.c +++ b/net/dsa/tag_xrs700x.c @@ -7,7 +7,9 @@ #include <linux/bitops.h> -#include "dsa_priv.h" +#include "tag.h" + +#define XRS700X_NAME "xrs700x" static struct sk_buff *xrs700x_xmit(struct sk_buff *skb, struct net_device *dev) { @@ -51,7 +53,7 @@ static struct sk_buff *xrs700x_rcv(struct sk_buff *skb, struct net_device *dev) } static const struct dsa_device_ops xrs700x_netdev_ops = { - .name = "xrs700x", + .name = XRS700X_NAME, .proto = DSA_TAG_PROTO_XRS700X, .xmit = xrs700x_xmit, .rcv = xrs700x_rcv, @@ -59,6 +61,6 @@ static const struct dsa_device_ops xrs700x_netdev_ops = { }; MODULE_LICENSE("GPL"); -MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_XRS700X); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_XRS700X, XRS700X_NAME); module_dsa_tag_driver(xrs700x_netdev_ops); diff --git a/net/ethernet/eth.c b/net/ethernet/eth.c index e02daa74e833..2edc8b796a4e 100644 --- a/net/ethernet/eth.c +++ b/net/ethernet/eth.c @@ -398,7 +398,7 @@ EXPORT_SYMBOL(alloc_etherdev_mqs); ssize_t sysfs_format_mac(char *buf, const unsigned char *addr, int len) { - return scnprintf(buf, PAGE_SIZE, "%*phC\n", len, addr); + return sysfs_emit(buf, "%*phC\n", len, addr); } EXPORT_SYMBOL(sysfs_format_mac); diff --git a/net/ethtool/Makefile b/net/ethtool/Makefile index 72ab0944262a..228f13df2e18 100644 --- a/net/ethtool/Makefile +++ b/net/ethtool/Makefile @@ -4,7 +4,7 @@ obj-y += ioctl.o common.o obj-$(CONFIG_ETHTOOL_NETLINK) += ethtool_nl.o -ethtool_nl-y := netlink.o bitset.o strset.o linkinfo.o linkmodes.o \ +ethtool_nl-y := netlink.o bitset.o strset.o linkinfo.o linkmodes.o rss.o \ linkstate.o debug.o wol.o features.o privflags.o rings.o \ channels.o coalesce.o pause.o eee.o tsinfo.o cabletest.o \ tunnels.o fec.o eeprom.o stats.o phc_vclocks.o module.o \ diff --git a/net/ethtool/channels.c b/net/ethtool/channels.c index 403158862011..c7e37130647e 100644 --- a/net/ethtool/channels.c +++ b/net/ethtool/channels.c @@ -116,9 +116,10 @@ int ethnl_set_channels(struct sk_buff *skb, struct genl_info *info) struct ethtool_channels channels = {}; struct ethnl_req_info req_info = {}; struct nlattr **tb = info->attrs; - u32 err_attr, max_rx_in_use = 0; + u32 err_attr, max_rxfh_in_use; const struct ethtool_ops *ops; struct net_device *dev; + u64 max_rxnfc_in_use; int ret; ret = ethnl_parse_header_dev_get(&req_info, @@ -189,15 +190,23 @@ int ethnl_set_channels(struct sk_buff *skb, struct genl_info *info) } /* ensure the new Rx count fits within the configured Rx flow - * indirection table settings + * indirection table/rxnfc settings */ - if (netif_is_rxfh_configured(dev) && - !ethtool_get_max_rxfh_channel(dev, &max_rx_in_use) && - (channels.combined_count + channels.rx_count) <= max_rx_in_use) { + if (ethtool_get_max_rxnfc_channel(dev, &max_rxnfc_in_use)) + max_rxnfc_in_use = 0; + if (!netif_is_rxfh_configured(dev) || + ethtool_get_max_rxfh_channel(dev, &max_rxfh_in_use)) + max_rxfh_in_use = 0; + if (channels.combined_count + channels.rx_count <= max_rxfh_in_use) { ret = -EINVAL; GENL_SET_ERR_MSG(info, "requested channel counts are too low for existing indirection table settings"); goto out_ops; } + if (channels.combined_count + channels.rx_count <= max_rxnfc_in_use) { + ret = -EINVAL; + GENL_SET_ERR_MSG(info, "requested channel counts are too low for existing ntuple filter settings"); + goto out_ops; + } /* Disabling channels, query zero-copy AF_XDP sockets */ from_channel = channels.combined_count + diff --git a/net/ethtool/common.c b/net/ethtool/common.c index 566adf85e658..6f399afc2ff2 100644 --- a/net/ethtool/common.c +++ b/net/ethtool/common.c @@ -202,6 +202,12 @@ const char link_mode_names[][ETH_GSTRING_LEN] = { __DEFINE_LINK_MODE_NAME(100, FX, Half), __DEFINE_LINK_MODE_NAME(100, FX, Full), __DEFINE_LINK_MODE_NAME(10, T1L, Full), + __DEFINE_LINK_MODE_NAME(800000, CR8, Full), + __DEFINE_LINK_MODE_NAME(800000, KR8, Full), + __DEFINE_LINK_MODE_NAME(800000, DR8, Full), + __DEFINE_LINK_MODE_NAME(800000, DR8_2, Full), + __DEFINE_LINK_MODE_NAME(800000, SR8, Full), + __DEFINE_LINK_MODE_NAME(800000, VR8, Full), }; static_assert(ARRAY_SIZE(link_mode_names) == __ETHTOOL_LINK_MODE_MASK_NBITS); @@ -238,6 +244,8 @@ static_assert(ARRAY_SIZE(link_mode_names) == __ETHTOOL_LINK_MODE_MASK_NBITS); #define __LINK_MODE_LANES_X 1 #define __LINK_MODE_LANES_FX 1 #define __LINK_MODE_LANES_T1L 1 +#define __LINK_MODE_LANES_VR8 8 +#define __LINK_MODE_LANES_DR8_2 8 #define __DEFINE_LINK_MODE_PARAMS(_speed, _type, _duplex) \ [ETHTOOL_LINK_MODE(_speed, _type, _duplex)] = { \ @@ -352,6 +360,12 @@ const struct link_mode_info link_mode_params[] = { __DEFINE_LINK_MODE_PARAMS(100, FX, Half), __DEFINE_LINK_MODE_PARAMS(100, FX, Full), __DEFINE_LINK_MODE_PARAMS(10, T1L, Full), + __DEFINE_LINK_MODE_PARAMS(800000, CR8, Full), + __DEFINE_LINK_MODE_PARAMS(800000, KR8, Full), + __DEFINE_LINK_MODE_PARAMS(800000, DR8, Full), + __DEFINE_LINK_MODE_PARAMS(800000, DR8_2, Full), + __DEFINE_LINK_MODE_PARAMS(800000, SR8, Full), + __DEFINE_LINK_MODE_PARAMS(800000, VR8, Full), }; static_assert(ARRAY_SIZE(link_mode_params) == __ETHTOOL_LINK_MODE_MASK_NBITS); @@ -403,6 +417,7 @@ const char sof_timestamping_names[][ETH_GSTRING_LEN] = { [const_ilog2(SOF_TIMESTAMPING_OPT_PKTINFO)] = "option-pktinfo", [const_ilog2(SOF_TIMESTAMPING_OPT_TX_SWHW)] = "option-tx-swhw", [const_ilog2(SOF_TIMESTAMPING_BIND_PHC)] = "bind-phc", + [const_ilog2(SOF_TIMESTAMPING_OPT_ID_TCP)] = "option-id-tcp", }; static_assert(ARRAY_SIZE(sof_timestamping_names) == __SOF_TIMESTAMPING_CNT); @@ -498,6 +513,72 @@ int __ethtool_get_link(struct net_device *dev) return netif_running(dev) && dev->ethtool_ops->get_link(dev); } +static int ethtool_get_rxnfc_rule_count(struct net_device *dev) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + struct ethtool_rxnfc info = { + .cmd = ETHTOOL_GRXCLSRLCNT, + }; + int err; + + err = ops->get_rxnfc(dev, &info, NULL); + if (err) + return err; + + return info.rule_cnt; +} + +int ethtool_get_max_rxnfc_channel(struct net_device *dev, u64 *max) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + struct ethtool_rxnfc *info; + int err, i, rule_cnt; + u64 max_ring = 0; + + if (!ops->get_rxnfc) + return -EOPNOTSUPP; + + rule_cnt = ethtool_get_rxnfc_rule_count(dev); + if (rule_cnt <= 0) + return -EINVAL; + + info = kvzalloc(struct_size(info, rule_locs, rule_cnt), GFP_KERNEL); + if (!info) + return -ENOMEM; + + info->cmd = ETHTOOL_GRXCLSRLALL; + info->rule_cnt = rule_cnt; + err = ops->get_rxnfc(dev, info, info->rule_locs); + if (err) + goto err_free_info; + + for (i = 0; i < rule_cnt; i++) { + struct ethtool_rxnfc rule_info = { + .cmd = ETHTOOL_GRXCLSRULE, + .fs.location = info->rule_locs[i], + }; + + err = ops->get_rxnfc(dev, &rule_info, NULL); + if (err) + goto err_free_info; + + if (rule_info.fs.ring_cookie != RX_CLS_FLOW_DISC && + rule_info.fs.ring_cookie != RX_CLS_FLOW_WAKE && + !(rule_info.flow_type & FLOW_RSS) && + !ethtool_get_flow_spec_ring_vf(rule_info.fs.ring_cookie)) + max_ring = + max_t(u64, max_ring, rule_info.fs.ring_cookie); + } + + kvfree(info); + *max = max_ring; + return 0; + +err_free_info: + kvfree(info); + return err; +} + int ethtool_get_max_rxfh_channel(struct net_device *dev, u32 *max) { u32 dev_size, current_max = 0; diff --git a/net/ethtool/common.h b/net/ethtool/common.h index c1779657e074..b1b9db810eca 100644 --- a/net/ethtool/common.h +++ b/net/ethtool/common.h @@ -43,6 +43,7 @@ bool convert_legacy_settings_to_link_ksettings( struct ethtool_link_ksettings *link_ksettings, const struct ethtool_cmd *legacy_settings); int ethtool_get_max_rxfh_channel(struct net_device *dev, u32 *max); +int ethtool_get_max_rxnfc_channel(struct net_device *dev, u64 *max); int __ethtool_get_ts_info(struct net_device *dev, struct ethtool_ts_info *info); extern const struct ethtool_phy_ops *ethtool_phy_ops; diff --git a/net/ethtool/ioctl.c b/net/ethtool/ioctl.c index 57e7238a4136..646b3e490c71 100644 --- a/net/ethtool/ioctl.c +++ b/net/ethtool/ioctl.c @@ -44,16 +44,9 @@ struct ethtool_devlink_compat { static struct devlink *netdev_to_devlink_get(struct net_device *dev) { - struct devlink_port *devlink_port; - - if (!dev->netdev_ops->ndo_get_devlink_port) - return NULL; - - devlink_port = dev->netdev_ops->ndo_get_devlink_port(dev); - if (!devlink_port) + if (!dev->devlink_port) return NULL; - - return devlink_try_get(devlink_port->devlink); + return devlink_try_get(dev->devlink_port->devlink); } /* @@ -713,15 +706,22 @@ static int ethtool_get_drvinfo(struct net_device *dev, struct ethtool_devlink_compat *rsp) { const struct ethtool_ops *ops = dev->ethtool_ops; + struct device *parent = dev->dev.parent; rsp->info.cmd = ETHTOOL_GDRVINFO; strscpy(rsp->info.version, UTS_RELEASE, sizeof(rsp->info.version)); if (ops->get_drvinfo) { ops->get_drvinfo(dev, &rsp->info); - } else if (dev->dev.parent && dev->dev.parent->driver) { - strscpy(rsp->info.bus_info, dev_name(dev->dev.parent), + if (!rsp->info.bus_info[0] && parent) + strscpy(rsp->info.bus_info, dev_name(parent), + sizeof(rsp->info.bus_info)); + if (!rsp->info.driver[0] && parent && parent->driver) + strscpy(rsp->info.driver, parent->driver->name, + sizeof(rsp->info.driver)); + } else if (parent && parent->driver) { + strscpy(rsp->info.bus_info, dev_name(parent), sizeof(rsp->info.bus_info)); - strscpy(rsp->info.driver, dev->dev.parent->driver->name, + strscpy(rsp->info.driver, parent->driver->name, sizeof(rsp->info.driver)); } else if (dev->rtnl_link_ops) { strscpy(rsp->info.driver, dev->rtnl_link_ops->kind, @@ -1796,7 +1796,8 @@ static noinline_for_stack int ethtool_set_channels(struct net_device *dev, { struct ethtool_channels channels, curr = { .cmd = ETHTOOL_GCHANNELS }; u16 from_channel, to_channel; - u32 max_rx_in_use = 0; + u64 max_rxnfc_in_use; + u32 max_rxfh_in_use; unsigned int i; int ret; @@ -1827,11 +1828,15 @@ static noinline_for_stack int ethtool_set_channels(struct net_device *dev, return -EINVAL; /* ensure the new Rx count fits within the configured Rx flow - * indirection table settings */ - if (netif_is_rxfh_configured(dev) && - !ethtool_get_max_rxfh_channel(dev, &max_rx_in_use) && - (channels.combined_count + channels.rx_count) <= max_rx_in_use) - return -EINVAL; + * indirection table/rxnfc settings */ + if (ethtool_get_max_rxnfc_channel(dev, &max_rxnfc_in_use)) + max_rxnfc_in_use = 0; + if (!netif_is_rxfh_configured(dev) || + ethtool_get_max_rxfh_channel(dev, &max_rxfh_in_use)) + max_rxfh_in_use = 0; + if (channels.combined_count + channels.rx_count <= + max_t(u64, max_rxnfc_in_use, max_rxfh_in_use)) + return -EINVAL; /* Disabling channels, query zero-copy AF_XDP sockets */ from_channel = channels.combined_count + @@ -2008,7 +2013,8 @@ static int ethtool_phys_id(struct net_device *dev, void __user *useraddr) } else { /* Driver expects to be called at twice the frequency in rc */ int n = rc * 2, interval = HZ / n; - u64 count = n * id.data, i = 0; + u64 count = mul_u32_u32(n, id.data); + u64 i = 0; do { rtnl_lock(); @@ -2072,58 +2078,91 @@ static int ethtool_get_stats(struct net_device *dev, void __user *useraddr) return ret; } -static int ethtool_get_phy_stats(struct net_device *dev, void __user *useraddr) +static int ethtool_vzalloc_stats_array(int n_stats, u64 **data) { + if (n_stats < 0) + return n_stats; + if (n_stats > S32_MAX / sizeof(u64)) + return -ENOMEM; + if (WARN_ON_ONCE(!n_stats)) + return -EOPNOTSUPP; + + *data = vzalloc(array_size(n_stats, sizeof(u64))); + if (!*data) + return -ENOMEM; + + return 0; +} + +static int ethtool_get_phy_stats_phydev(struct phy_device *phydev, + struct ethtool_stats *stats, + u64 **data) + { const struct ethtool_phy_ops *phy_ops = ethtool_phy_ops; + int n_stats, ret; + + if (!phy_ops || !phy_ops->get_sset_count || !phy_ops->get_stats) + return -EOPNOTSUPP; + + n_stats = phy_ops->get_sset_count(phydev); + + ret = ethtool_vzalloc_stats_array(n_stats, data); + if (ret) + return ret; + + stats->n_stats = n_stats; + return phy_ops->get_stats(phydev, stats, *data); +} + +static int ethtool_get_phy_stats_ethtool(struct net_device *dev, + struct ethtool_stats *stats, + u64 **data) +{ const struct ethtool_ops *ops = dev->ethtool_ops; - struct phy_device *phydev = dev->phydev; - struct ethtool_stats stats; - u64 *data; - int ret, n_stats; + int n_stats, ret; - if (!phydev && (!ops->get_ethtool_phy_stats || !ops->get_sset_count)) + if (!ops || !ops->get_sset_count || ops->get_ethtool_phy_stats) return -EOPNOTSUPP; - if (phydev && !ops->get_ethtool_phy_stats && - phy_ops && phy_ops->get_sset_count) - n_stats = phy_ops->get_sset_count(phydev); - else - n_stats = ops->get_sset_count(dev, ETH_SS_PHY_STATS); - if (n_stats < 0) - return n_stats; - if (n_stats > S32_MAX / sizeof(u64)) - return -ENOMEM; - WARN_ON_ONCE(!n_stats); + n_stats = ops->get_sset_count(dev, ETH_SS_PHY_STATS); + + ret = ethtool_vzalloc_stats_array(n_stats, data); + if (ret) + return ret; + + stats->n_stats = n_stats; + ops->get_ethtool_phy_stats(dev, stats, *data); + + return 0; +} + +static int ethtool_get_phy_stats(struct net_device *dev, void __user *useraddr) +{ + struct phy_device *phydev = dev->phydev; + struct ethtool_stats stats; + u64 *data = NULL; + int ret = -EOPNOTSUPP; if (copy_from_user(&stats, useraddr, sizeof(stats))) return -EFAULT; - stats.n_stats = n_stats; + if (phydev) + ret = ethtool_get_phy_stats_phydev(phydev, &stats, &data); - if (n_stats) { - data = vzalloc(array_size(n_stats, sizeof(u64))); - if (!data) - return -ENOMEM; + if (ret == -EOPNOTSUPP) + ret = ethtool_get_phy_stats_ethtool(dev, &stats, &data); - if (phydev && !ops->get_ethtool_phy_stats && - phy_ops && phy_ops->get_stats) { - ret = phy_ops->get_stats(phydev, &stats, data); - if (ret < 0) - goto out; - } else { - ops->get_ethtool_phy_stats(dev, &stats, data); - } - } else { - data = NULL; - } + if (ret) + goto out; - ret = -EFAULT; - if (copy_to_user(useraddr, &stats, sizeof(stats))) + if (copy_to_user(useraddr, &stats, sizeof(stats))) { + ret = -EFAULT; goto out; + } + useraddr += sizeof(stats); - if (n_stats && copy_to_user(useraddr, data, array_size(n_stats, sizeof(u64)))) - goto out; - ret = 0; + if (copy_to_user(useraddr, data, array_size(stats.n_stats, sizeof(u64)))) + ret = -EFAULT; out: vfree(data); diff --git a/net/ethtool/linkstate.c b/net/ethtool/linkstate.c index fb676f349455..2158c17a0b32 100644 --- a/net/ethtool/linkstate.c +++ b/net/ethtool/linkstate.c @@ -13,6 +13,7 @@ struct linkstate_reply_data { int link; int sqi; int sqi_max; + struct ethtool_link_ext_stats link_stats; bool link_ext_state_provided; struct ethtool_link_ext_state_info ethtool_link_ext_state_info; }; @@ -22,7 +23,7 @@ struct linkstate_reply_data { const struct nla_policy ethnl_linkstate_get_policy[] = { [ETHTOOL_A_LINKSTATE_HEADER] = - NLA_POLICY_NESTED(ethnl_header_policy), + NLA_POLICY_NESTED(ethnl_header_policy_stats), }; static int linkstate_get_sqi(struct net_device *dev) @@ -107,6 +108,19 @@ static int linkstate_prepare_data(const struct ethnl_req_info *req_base, goto out; } + ethtool_stats_init((u64 *)&data->link_stats, + sizeof(data->link_stats) / 8); + + if (req_base->flags & ETHTOOL_FLAG_STATS) { + if (dev->phydev) + data->link_stats.link_down_events = + READ_ONCE(dev->phydev->link_down_events); + + if (dev->ethtool_ops->get_link_ext_stats) + dev->ethtool_ops->get_link_ext_stats(dev, + &data->link_stats); + } + ret = 0; out: ethnl_ops_complete(dev); @@ -134,6 +148,9 @@ static int linkstate_reply_size(const struct ethnl_req_info *req_base, if (data->ethtool_link_ext_state_info.__link_ext_substate) len += nla_total_size(sizeof(u8)); /* LINKSTATE_EXT_SUBSTATE */ + if (data->link_stats.link_down_events != ETHTOOL_STAT_NOT_SET) + len += nla_total_size(sizeof(u32)); + return len; } @@ -166,6 +183,11 @@ static int linkstate_fill_reply(struct sk_buff *skb, return -EMSGSIZE; } + if (data->link_stats.link_down_events != ETHTOOL_STAT_NOT_SET) + if (nla_put_u32(skb, ETHTOOL_A_LINKSTATE_EXT_DOWN_CNT, + data->link_stats.link_down_events)) + return -EMSGSIZE; + return 0; } diff --git a/net/ethtool/netlink.c b/net/ethtool/netlink.c index 1a4c11356c96..aee98be6237f 100644 --- a/net/ethtool/netlink.c +++ b/net/ethtool/netlink.c @@ -287,6 +287,7 @@ ethnl_default_requests[__ETHTOOL_MSG_USER_CNT] = { [ETHTOOL_MSG_PHC_VCLOCKS_GET] = ðnl_phc_vclocks_request_ops, [ETHTOOL_MSG_MODULE_GET] = ðnl_module_request_ops, [ETHTOOL_MSG_PSE_GET] = ðnl_pse_request_ops, + [ETHTOOL_MSG_RSS_GET] = ðnl_rss_request_ops, }; static struct ethnl_dump_ctx *ethnl_dump_context(struct netlink_callback *cb) @@ -1040,6 +1041,12 @@ static const struct genl_ops ethtool_genl_ops[] = { .policy = ethnl_pse_set_policy, .maxattr = ARRAY_SIZE(ethnl_pse_set_policy) - 1, }, + { + .cmd = ETHTOOL_MSG_RSS_GET, + .doit = ethnl_default_doit, + .policy = ethnl_rss_get_policy, + .maxattr = ARRAY_SIZE(ethnl_rss_get_policy) - 1, + }, }; static const struct genl_multicast_group ethtool_nl_mcgrps[] = { diff --git a/net/ethtool/netlink.h b/net/ethtool/netlink.h index 1bfd374f9718..3753787ba233 100644 --- a/net/ethtool/netlink.h +++ b/net/ethtool/netlink.h @@ -346,6 +346,7 @@ extern const struct ethnl_request_ops ethnl_stats_request_ops; extern const struct ethnl_request_ops ethnl_phc_vclocks_request_ops; extern const struct ethnl_request_ops ethnl_module_request_ops; extern const struct ethnl_request_ops ethnl_pse_request_ops; +extern const struct ethnl_request_ops ethnl_rss_request_ops; extern const struct nla_policy ethnl_header_policy[ETHTOOL_A_HEADER_FLAGS + 1]; extern const struct nla_policy ethnl_header_policy_stats[ETHTOOL_A_HEADER_FLAGS + 1]; @@ -386,6 +387,7 @@ extern const struct nla_policy ethnl_module_get_policy[ETHTOOL_A_MODULE_HEADER + extern const struct nla_policy ethnl_module_set_policy[ETHTOOL_A_MODULE_POWER_MODE_POLICY + 1]; extern const struct nla_policy ethnl_pse_get_policy[ETHTOOL_A_PSE_HEADER + 1]; extern const struct nla_policy ethnl_pse_set_policy[ETHTOOL_A_PSE_MAX + 1]; +extern const struct nla_policy ethnl_rss_get_policy[ETHTOOL_A_RSS_CONTEXT + 1]; int ethnl_set_linkinfo(struct sk_buff *skb, struct genl_info *info); int ethnl_set_linkmodes(struct sk_buff *skb, struct genl_info *info); diff --git a/net/ethtool/rss.c b/net/ethtool/rss.c new file mode 100644 index 000000000000..be260ab34e58 --- /dev/null +++ b/net/ethtool/rss.c @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" + +struct rss_req_info { + struct ethnl_req_info base; + u32 rss_context; +}; + +struct rss_reply_data { + struct ethnl_reply_data base; + u32 indir_size; + u32 hkey_size; + u32 hfunc; + u32 *indir_table; + u8 *hkey; +}; + +#define RSS_REQINFO(__req_base) \ + container_of(__req_base, struct rss_req_info, base) + +#define RSS_REPDATA(__reply_base) \ + container_of(__reply_base, struct rss_reply_data, base) + +const struct nla_policy ethnl_rss_get_policy[] = { + [ETHTOOL_A_RSS_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_RSS_CONTEXT] = { .type = NLA_U32 }, +}; + +static int +rss_parse_request(struct ethnl_req_info *req_info, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rss_req_info *request = RSS_REQINFO(req_info); + + if (tb[ETHTOOL_A_RSS_CONTEXT]) + request->rss_context = nla_get_u32(tb[ETHTOOL_A_RSS_CONTEXT]); + + return 0; +} + +static int +rss_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, struct genl_info *info) +{ + struct rss_reply_data *data = RSS_REPDATA(reply_base); + struct rss_req_info *request = RSS_REQINFO(req_base); + struct net_device *dev = reply_base->dev; + const struct ethtool_ops *ops; + u32 total_size, indir_bytes; + u8 dev_hfunc = 0; + u8 *rss_config; + int ret; + + ops = dev->ethtool_ops; + if (!ops->get_rxfh) + return -EOPNOTSUPP; + + /* Some drivers don't handle rss_context */ + if (request->rss_context && !ops->get_rxfh_context) + return -EOPNOTSUPP; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + data->indir_size = 0; + data->hkey_size = 0; + if (ops->get_rxfh_indir_size) + data->indir_size = ops->get_rxfh_indir_size(dev); + if (ops->get_rxfh_key_size) + data->hkey_size = ops->get_rxfh_key_size(dev); + + indir_bytes = data->indir_size * sizeof(u32); + total_size = indir_bytes + data->hkey_size; + rss_config = kzalloc(total_size, GFP_KERNEL); + if (!rss_config) { + ret = -ENOMEM; + goto out_ops; + } + + if (data->indir_size) + data->indir_table = (u32 *)rss_config; + + if (data->hkey_size) + data->hkey = rss_config + indir_bytes; + + if (request->rss_context) + ret = ops->get_rxfh_context(dev, data->indir_table, data->hkey, + &dev_hfunc, request->rss_context); + else + ret = ops->get_rxfh(dev, data->indir_table, data->hkey, + &dev_hfunc); + + if (ret) + goto out_ops; + + data->hfunc = dev_hfunc; +out_ops: + ethnl_ops_complete(dev); + return ret; +} + +static int +rss_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct rss_reply_data *data = RSS_REPDATA(reply_base); + int len; + + len = nla_total_size(sizeof(u32)) + /* _RSS_HFUNC */ + nla_total_size(sizeof(u32) * data->indir_size) + /* _RSS_INDIR */ + nla_total_size(data->hkey_size); /* _RSS_HKEY */ + + return len; +} + +static int +rss_fill_reply(struct sk_buff *skb, const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct rss_reply_data *data = RSS_REPDATA(reply_base); + + if ((data->hfunc && + nla_put_u32(skb, ETHTOOL_A_RSS_HFUNC, data->hfunc)) || + (data->indir_size && + nla_put(skb, ETHTOOL_A_RSS_INDIR, + sizeof(u32) * data->indir_size, data->indir_table)) || + (data->hkey_size && + nla_put(skb, ETHTOOL_A_RSS_HKEY, data->hkey_size, data->hkey))) + return -EMSGSIZE; + + return 0; +} + +static void rss_cleanup_data(struct ethnl_reply_data *reply_base) +{ + const struct rss_reply_data *data = RSS_REPDATA(reply_base); + + kfree(data->indir_table); +} + +const struct ethnl_request_ops ethnl_rss_request_ops = { + .request_cmd = ETHTOOL_MSG_RSS_GET, + .reply_cmd = ETHTOOL_MSG_RSS_GET_REPLY, + .hdr_attr = ETHTOOL_A_RSS_HEADER, + .req_info_size = sizeof(struct rss_req_info), + .reply_data_size = sizeof(struct rss_reply_data), + + .parse_request = rss_parse_request, + .prepare_data = rss_prepare_data, + .reply_size = rss_reply_size, + .fill_reply = rss_fill_reply, + .cleanup_data = rss_cleanup_data, +}; diff --git a/net/hsr/hsr_debugfs.c b/net/hsr/hsr_debugfs.c index de476a417631..1a195efc79cd 100644 --- a/net/hsr/hsr_debugfs.c +++ b/net/hsr/hsr_debugfs.c @@ -9,7 +9,6 @@ #include <linux/module.h> #include <linux/errno.h> #include <linux/debugfs.h> -#include <linux/jhash.h> #include "hsr_main.h" #include "hsr_framereg.h" @@ -21,7 +20,6 @@ hsr_node_table_show(struct seq_file *sfp, void *data) { struct hsr_priv *priv = (struct hsr_priv *)sfp->private; struct hsr_node *node; - int i; seq_printf(sfp, "Node Table entries for (%s) device\n", (priv->prot_version == PRP_V1 ? "PRP" : "HSR")); @@ -33,28 +31,22 @@ hsr_node_table_show(struct seq_file *sfp, void *data) seq_puts(sfp, "DAN-H\n"); rcu_read_lock(); - - for (i = 0 ; i < priv->hash_buckets; i++) { - hlist_for_each_entry_rcu(node, &priv->node_db[i], mac_list) { - /* skip self node */ - if (hsr_addr_is_self(priv, node->macaddress_A)) - continue; - seq_printf(sfp, "%pM ", &node->macaddress_A[0]); - seq_printf(sfp, "%pM ", &node->macaddress_B[0]); - seq_printf(sfp, "%10lx, ", - node->time_in[HSR_PT_SLAVE_A]); - seq_printf(sfp, "%10lx, ", - node->time_in[HSR_PT_SLAVE_B]); - seq_printf(sfp, "%14x, ", node->addr_B_port); - - if (priv->prot_version == PRP_V1) - seq_printf(sfp, "%5x, %5x, %5x\n", - node->san_a, node->san_b, - (node->san_a == 0 && - node->san_b == 0)); - else - seq_printf(sfp, "%5x\n", 1); - } + list_for_each_entry_rcu(node, &priv->node_db, mac_list) { + /* skip self node */ + if (hsr_addr_is_self(priv, node->macaddress_A)) + continue; + seq_printf(sfp, "%pM ", &node->macaddress_A[0]); + seq_printf(sfp, "%pM ", &node->macaddress_B[0]); + seq_printf(sfp, "%10lx, ", node->time_in[HSR_PT_SLAVE_A]); + seq_printf(sfp, "%10lx, ", node->time_in[HSR_PT_SLAVE_B]); + seq_printf(sfp, "%14x, ", node->addr_B_port); + + if (priv->prot_version == PRP_V1) + seq_printf(sfp, "%5x, %5x, %5x\n", + node->san_a, node->san_b, + (node->san_a == 0 && node->san_b == 0)); + else + seq_printf(sfp, "%5x\n", 1); } rcu_read_unlock(); return 0; diff --git a/net/hsr/hsr_device.c b/net/hsr/hsr_device.c index 6ffef47e9be5..5a236aae2366 100644 --- a/net/hsr/hsr_device.c +++ b/net/hsr/hsr_device.c @@ -219,7 +219,9 @@ static netdev_tx_t hsr_dev_xmit(struct sk_buff *skb, struct net_device *dev) skb->dev = master->dev; skb_reset_mac_header(skb); skb_reset_mac_len(skb); + spin_lock_bh(&hsr->seqnr_lock); hsr_forward_skb(skb, master); + spin_unlock_bh(&hsr->seqnr_lock); } else { dev_core_stats_tx_dropped_inc(dev); dev_kfree_skb_any(skb); @@ -278,7 +280,6 @@ static void send_hsr_supervision_frame(struct hsr_port *master, __u8 type = HSR_TLV_LIFE_CHECK; struct hsr_sup_payload *hsr_sp; struct hsr_sup_tag *hsr_stag; - unsigned long irqflags; struct sk_buff *skb; *interval = msecs_to_jiffies(HSR_LIFE_CHECK_INTERVAL); @@ -299,7 +300,7 @@ static void send_hsr_supervision_frame(struct hsr_port *master, set_hsr_stag_HSR_ver(hsr_stag, hsr->prot_version); /* From HSRv1 on we have separate supervision sequence numbers. */ - spin_lock_irqsave(&master->hsr->seqnr_lock, irqflags); + spin_lock_bh(&hsr->seqnr_lock); if (hsr->prot_version > 0) { hsr_stag->sequence_nr = htons(hsr->sup_sequence_nr); hsr->sup_sequence_nr++; @@ -307,7 +308,6 @@ static void send_hsr_supervision_frame(struct hsr_port *master, hsr_stag->sequence_nr = htons(hsr->sequence_nr); hsr->sequence_nr++; } - spin_unlock_irqrestore(&master->hsr->seqnr_lock, irqflags); hsr_stag->tlv.HSR_TLV_type = type; /* TODO: Why 12 in HSRv0? */ @@ -318,11 +318,13 @@ static void send_hsr_supervision_frame(struct hsr_port *master, hsr_sp = skb_put(skb, sizeof(struct hsr_sup_payload)); ether_addr_copy(hsr_sp->macaddress_A, master->dev->dev_addr); - if (skb_put_padto(skb, ETH_ZLEN)) + if (skb_put_padto(skb, ETH_ZLEN)) { + spin_unlock_bh(&hsr->seqnr_lock); return; + } hsr_forward_skb(skb, master); - + spin_unlock_bh(&hsr->seqnr_lock); return; } @@ -332,7 +334,6 @@ static void send_prp_supervision_frame(struct hsr_port *master, struct hsr_priv *hsr = master->hsr; struct hsr_sup_payload *hsr_sp; struct hsr_sup_tag *hsr_stag; - unsigned long irqflags; struct sk_buff *skb; skb = hsr_init_skb(master); @@ -347,7 +348,7 @@ static void send_prp_supervision_frame(struct hsr_port *master, set_hsr_stag_HSR_ver(hsr_stag, (hsr->prot_version ? 1 : 0)); /* From HSRv1 on we have separate supervision sequence numbers. */ - spin_lock_irqsave(&master->hsr->seqnr_lock, irqflags); + spin_lock_bh(&hsr->seqnr_lock); hsr_stag->sequence_nr = htons(hsr->sup_sequence_nr); hsr->sup_sequence_nr++; hsr_stag->tlv.HSR_TLV_type = PRP_TLV_LIFE_CHECK_DD; @@ -358,13 +359,12 @@ static void send_prp_supervision_frame(struct hsr_port *master, ether_addr_copy(hsr_sp->macaddress_A, master->dev->dev_addr); if (skb_put_padto(skb, ETH_ZLEN)) { - spin_unlock_irqrestore(&master->hsr->seqnr_lock, irqflags); + spin_unlock_bh(&hsr->seqnr_lock); return; } - spin_unlock_irqrestore(&master->hsr->seqnr_lock, irqflags); - hsr_forward_skb(skb, master); + spin_unlock_bh(&hsr->seqnr_lock); } /* Announce (supervision frame) timer function @@ -444,7 +444,7 @@ void hsr_dev_setup(struct net_device *dev) dev->header_ops = &hsr_header_ops; dev->netdev_ops = &hsr_device_ops; SET_NETDEV_DEVTYPE(dev, &hsr_type); - dev->priv_flags |= IFF_NO_QUEUE; + dev->priv_flags |= IFF_NO_QUEUE | IFF_DISABLE_NETPOLL; dev->needs_free_netdev = true; @@ -485,16 +485,11 @@ int hsr_dev_finalize(struct net_device *hsr_dev, struct net_device *slave[2], { bool unregister = false; struct hsr_priv *hsr; - int res, i; + int res; hsr = netdev_priv(hsr_dev); INIT_LIST_HEAD(&hsr->ports); - INIT_HLIST_HEAD(&hsr->self_node_db); - hsr->hash_buckets = HSR_HSIZE; - get_random_bytes(&hsr->hash_seed, sizeof(hsr->hash_seed)); - for (i = 0; i < hsr->hash_buckets; i++) - INIT_HLIST_HEAD(&hsr->node_db[i]); - + INIT_LIST_HEAD(&hsr->node_db); spin_lock_init(&hsr->list_lock); eth_hw_addr_set(hsr_dev, slave[0]->dev_addr); diff --git a/net/hsr/hsr_forward.c b/net/hsr/hsr_forward.c index a50429a62f74..629daacc9607 100644 --- a/net/hsr/hsr_forward.c +++ b/net/hsr/hsr_forward.c @@ -351,17 +351,18 @@ static void hsr_deliver_master(struct sk_buff *skb, struct net_device *dev, struct hsr_node *node_src) { bool was_multicast_frame; - int res; + int res, recv_len; was_multicast_frame = (skb->pkt_type == PACKET_MULTICAST); hsr_addr_subst_source(node_src, skb); skb_pull(skb, ETH_HLEN); + recv_len = skb->len; res = netif_rx(skb); if (res == NET_RX_DROP) { dev->stats.rx_dropped++; } else { dev->stats.rx_packets++; - dev->stats.rx_bytes += skb->len; + dev->stats.rx_bytes += recv_len; if (was_multicast_frame) dev->stats.multicast++; } @@ -499,7 +500,6 @@ static void handle_std_frame(struct sk_buff *skb, { struct hsr_port *port = frame->port_rcv; struct hsr_priv *hsr = port->hsr; - unsigned long irqflags; frame->skb_hsr = NULL; frame->skb_prp = NULL; @@ -509,10 +509,9 @@ static void handle_std_frame(struct sk_buff *skb, frame->is_from_san = true; } else { /* Sequence nr for the master node */ - spin_lock_irqsave(&hsr->seqnr_lock, irqflags); + lockdep_assert_held(&hsr->seqnr_lock); frame->sequence_nr = hsr->sequence_nr; hsr->sequence_nr++; - spin_unlock_irqrestore(&hsr->seqnr_lock, irqflags); } } @@ -570,23 +569,20 @@ static int fill_frame_info(struct hsr_frame_info *frame, struct ethhdr *ethhdr; __be16 proto; int ret; - u32 hash; /* Check if skb contains ethhdr */ if (skb->mac_len < sizeof(struct ethhdr)) return -EINVAL; memset(frame, 0, sizeof(*frame)); - - ethhdr = (struct ethhdr *)skb_mac_header(skb); - hash = hsr_mac_hash(port->hsr, ethhdr->h_source); frame->is_supervision = is_supervision_frame(port->hsr, skb); - frame->node_src = hsr_get_node(port, &hsr->node_db[hash], skb, + frame->node_src = hsr_get_node(port, &hsr->node_db, skb, frame->is_supervision, port->type); if (!frame->node_src) return -1; /* Unknown node and !is_supervision, or no mem */ + ethhdr = (struct ethhdr *)skb_mac_header(skb); frame->is_vlan = false; proto = ethhdr->h_proto; @@ -616,11 +612,13 @@ void hsr_forward_skb(struct sk_buff *skb, struct hsr_port *port) { struct hsr_frame_info frame; + rcu_read_lock(); if (fill_frame_info(&frame, skb, port) < 0) goto out_drop; hsr_register_frame_in(frame.node_src, port, frame.sequence_nr); hsr_forward_do(&frame); + rcu_read_unlock(); /* Gets called for ingress frames as well as egress from master port. * So check and increment stats for master port only here. */ @@ -635,6 +633,7 @@ void hsr_forward_skb(struct sk_buff *skb, struct hsr_port *port) return; out_drop: + rcu_read_unlock(); port->dev->stats.tx_dropped++; kfree_skb(skb); } diff --git a/net/hsr/hsr_framereg.c b/net/hsr/hsr_framereg.c index 584e21788799..00db74d96583 100644 --- a/net/hsr/hsr_framereg.c +++ b/net/hsr/hsr_framereg.c @@ -15,37 +15,10 @@ #include <linux/etherdevice.h> #include <linux/slab.h> #include <linux/rculist.h> -#include <linux/jhash.h> #include "hsr_main.h" #include "hsr_framereg.h" #include "hsr_netlink.h" -#ifdef CONFIG_LOCKDEP -int lockdep_hsr_is_held(spinlock_t *lock) -{ - return lockdep_is_held(lock); -} -#endif - -u32 hsr_mac_hash(struct hsr_priv *hsr, const unsigned char *addr) -{ - u32 hash = jhash(addr, ETH_ALEN, hsr->hash_seed); - - return reciprocal_scale(hash, hsr->hash_buckets); -} - -struct hsr_node *hsr_node_get_first(struct hlist_head *head, spinlock_t *lock) -{ - struct hlist_node *first; - - first = rcu_dereference_bh_check(hlist_first_rcu(head), - lockdep_hsr_is_held(lock)); - if (first) - return hlist_entry(first, struct hsr_node, mac_list); - - return NULL; -} - /* seq_nr_after(a, b) - return true if a is after (higher in sequence than) b, * false otherwise. */ @@ -65,30 +38,32 @@ static bool seq_nr_after(u16 a, u16 b) bool hsr_addr_is_self(struct hsr_priv *hsr, unsigned char *addr) { - struct hsr_node *node; + struct hsr_self_node *sn; + bool ret = false; - node = hsr_node_get_first(&hsr->self_node_db, &hsr->list_lock); - if (!node) { + rcu_read_lock(); + sn = rcu_dereference(hsr->self_node); + if (!sn) { WARN_ONCE(1, "HSR: No self node\n"); - return false; + goto out; } - if (ether_addr_equal(addr, node->macaddress_A)) - return true; - if (ether_addr_equal(addr, node->macaddress_B)) - return true; - - return false; + if (ether_addr_equal(addr, sn->macaddress_A) || + ether_addr_equal(addr, sn->macaddress_B)) + ret = true; +out: + rcu_read_unlock(); + return ret; } /* Search for mac entry. Caller must hold rcu read lock. */ -static struct hsr_node *find_node_by_addr_A(struct hlist_head *node_db, +static struct hsr_node *find_node_by_addr_A(struct list_head *node_db, const unsigned char addr[ETH_ALEN]) { struct hsr_node *node; - hlist_for_each_entry_rcu(node, node_db, mac_list) { + list_for_each_entry_rcu(node, node_db, mac_list) { if (ether_addr_equal(node->macaddress_A, addr)) return node; } @@ -96,58 +71,51 @@ static struct hsr_node *find_node_by_addr_A(struct hlist_head *node_db, return NULL; } -/* Helper for device init; the self_node_db is used in hsr_rcv() to recognize +/* Helper for device init; the self_node is used in hsr_rcv() to recognize * frames from self that's been looped over the HSR ring. */ int hsr_create_self_node(struct hsr_priv *hsr, const unsigned char addr_a[ETH_ALEN], const unsigned char addr_b[ETH_ALEN]) { - struct hlist_head *self_node_db = &hsr->self_node_db; - struct hsr_node *node, *oldnode; + struct hsr_self_node *sn, *old; - node = kmalloc(sizeof(*node), GFP_KERNEL); - if (!node) + sn = kmalloc(sizeof(*sn), GFP_KERNEL); + if (!sn) return -ENOMEM; - ether_addr_copy(node->macaddress_A, addr_a); - ether_addr_copy(node->macaddress_B, addr_b); + ether_addr_copy(sn->macaddress_A, addr_a); + ether_addr_copy(sn->macaddress_B, addr_b); spin_lock_bh(&hsr->list_lock); - oldnode = hsr_node_get_first(self_node_db, &hsr->list_lock); - if (oldnode) { - hlist_replace_rcu(&oldnode->mac_list, &node->mac_list); - spin_unlock_bh(&hsr->list_lock); - kfree_rcu(oldnode, rcu_head); - } else { - hlist_add_tail_rcu(&node->mac_list, self_node_db); - spin_unlock_bh(&hsr->list_lock); - } + old = rcu_replace_pointer(hsr->self_node, sn, + lockdep_is_held(&hsr->list_lock)); + spin_unlock_bh(&hsr->list_lock); + if (old) + kfree_rcu(old, rcu_head); return 0; } void hsr_del_self_node(struct hsr_priv *hsr) { - struct hlist_head *self_node_db = &hsr->self_node_db; - struct hsr_node *node; + struct hsr_self_node *old; spin_lock_bh(&hsr->list_lock); - node = hsr_node_get_first(self_node_db, &hsr->list_lock); - if (node) { - hlist_del_rcu(&node->mac_list); - kfree_rcu(node, rcu_head); - } + old = rcu_replace_pointer(hsr->self_node, NULL, + lockdep_is_held(&hsr->list_lock)); spin_unlock_bh(&hsr->list_lock); + if (old) + kfree_rcu(old, rcu_head); } -void hsr_del_nodes(struct hlist_head *node_db) +void hsr_del_nodes(struct list_head *node_db) { struct hsr_node *node; - struct hlist_node *tmp; + struct hsr_node *tmp; - hlist_for_each_entry_safe(node, tmp, node_db, mac_list) - kfree_rcu(node, rcu_head); + list_for_each_entry_safe(node, tmp, node_db, mac_list) + kfree(node); } void prp_handle_san_frame(bool san, enum hsr_port_type port, @@ -168,7 +136,7 @@ void prp_handle_san_frame(bool san, enum hsr_port_type port, * originating from the newly added node. */ static struct hsr_node *hsr_add_node(struct hsr_priv *hsr, - struct hlist_head *node_db, + struct list_head *node_db, unsigned char addr[], u16 seq_out, bool san, enum hsr_port_type rx_port) @@ -182,6 +150,7 @@ static struct hsr_node *hsr_add_node(struct hsr_priv *hsr, return NULL; ether_addr_copy(new_node->macaddress_A, addr); + spin_lock_init(&new_node->seq_out_lock); /* We are only interested in time diffs here, so use current jiffies * as initialization. (0 could trigger an spurious ring error warning). @@ -198,14 +167,14 @@ static struct hsr_node *hsr_add_node(struct hsr_priv *hsr, hsr->proto_ops->handle_san_frame(san, rx_port, new_node); spin_lock_bh(&hsr->list_lock); - hlist_for_each_entry_rcu(node, node_db, mac_list, - lockdep_hsr_is_held(&hsr->list_lock)) { + list_for_each_entry_rcu(node, node_db, mac_list, + lockdep_is_held(&hsr->list_lock)) { if (ether_addr_equal(node->macaddress_A, addr)) goto out; if (ether_addr_equal(node->macaddress_B, addr)) goto out; } - hlist_add_tail_rcu(&new_node->mac_list, node_db); + list_add_tail_rcu(&new_node->mac_list, node_db); spin_unlock_bh(&hsr->list_lock); return new_node; out: @@ -225,7 +194,7 @@ void prp_update_san_info(struct hsr_node *node, bool is_sup) /* Get the hsr_node from which 'skb' was sent. */ -struct hsr_node *hsr_get_node(struct hsr_port *port, struct hlist_head *node_db, +struct hsr_node *hsr_get_node(struct hsr_port *port, struct list_head *node_db, struct sk_buff *skb, bool is_sup, enum hsr_port_type rx_port) { @@ -241,7 +210,7 @@ struct hsr_node *hsr_get_node(struct hsr_port *port, struct hlist_head *node_db, ethhdr = (struct ethhdr *)skb_mac_header(skb); - hlist_for_each_entry_rcu(node, node_db, mac_list) { + list_for_each_entry_rcu(node, node_db, mac_list) { if (ether_addr_equal(node->macaddress_A, ethhdr->h_source)) { if (hsr->proto_ops->update_san_info) hsr->proto_ops->update_san_info(node, is_sup); @@ -291,12 +260,11 @@ void hsr_handle_sup_frame(struct hsr_frame_info *frame) struct hsr_sup_tlv *hsr_sup_tlv; struct hsr_node *node_real; struct sk_buff *skb = NULL; - struct hlist_head *node_db; + struct list_head *node_db; struct ethhdr *ethhdr; int i; unsigned int pull_size = 0; unsigned int total_pull_size = 0; - u32 hash; /* Here either frame->skb_hsr or frame->skb_prp should be * valid as supervision frame always will have protocol @@ -334,13 +302,11 @@ void hsr_handle_sup_frame(struct hsr_frame_info *frame) hsr_sp = (struct hsr_sup_payload *)skb->data; /* Merge node_curr (registered on macaddress_B) into node_real */ - node_db = port_rcv->hsr->node_db; - hash = hsr_mac_hash(hsr, hsr_sp->macaddress_A); - node_real = find_node_by_addr_A(&node_db[hash], hsr_sp->macaddress_A); + node_db = &port_rcv->hsr->node_db; + node_real = find_node_by_addr_A(node_db, hsr_sp->macaddress_A); if (!node_real) /* No frame received from AddrA of this node yet */ - node_real = hsr_add_node(hsr, &node_db[hash], - hsr_sp->macaddress_A, + node_real = hsr_add_node(hsr, node_db, hsr_sp->macaddress_A, HSR_SEQNR_START - 1, true, port_rcv->type); if (!node_real) @@ -374,14 +340,14 @@ void hsr_handle_sup_frame(struct hsr_frame_info *frame) hsr_sp = (struct hsr_sup_payload *)skb->data; /* Check if redbox mac and node mac are equal. */ - if (!ether_addr_equal(node_real->macaddress_A, - hsr_sp->macaddress_A)) { + if (!ether_addr_equal(node_real->macaddress_A, hsr_sp->macaddress_A)) { /* This is a redbox supervision frame for a VDAN! */ goto done; } } ether_addr_copy(node_real->macaddress_B, ethhdr->h_source); + spin_lock_bh(&node_real->seq_out_lock); for (i = 0; i < HSR_PT_PORTS; i++) { if (!node_curr->time_in_stale[i] && time_after(node_curr->time_in[i], node_real->time_in[i])) { @@ -392,12 +358,16 @@ void hsr_handle_sup_frame(struct hsr_frame_info *frame) if (seq_nr_after(node_curr->seq_out[i], node_real->seq_out[i])) node_real->seq_out[i] = node_curr->seq_out[i]; } + spin_unlock_bh(&node_real->seq_out_lock); node_real->addr_B_port = port_rcv->type; spin_lock_bh(&hsr->list_lock); - hlist_del_rcu(&node_curr->mac_list); + if (!node_curr->removed) { + list_del_rcu(&node_curr->mac_list); + node_curr->removed = true; + kfree_rcu(node_curr, rcu_head); + } spin_unlock_bh(&hsr->list_lock); - kfree_rcu(node_curr, rcu_head); done: /* Push back here */ @@ -433,7 +403,6 @@ void hsr_addr_subst_dest(struct hsr_node *node_src, struct sk_buff *skb, struct hsr_port *port) { struct hsr_node *node_dst; - u32 hash; if (!skb_mac_header_was_set(skb)) { WARN_ONCE(1, "%s: Mac header not set\n", __func__); @@ -443,8 +412,7 @@ void hsr_addr_subst_dest(struct hsr_node *node_src, struct sk_buff *skb, if (!is_unicast_ether_addr(eth_hdr(skb)->h_dest)) return; - hash = hsr_mac_hash(port->hsr, eth_hdr(skb)->h_dest); - node_dst = find_node_by_addr_A(&port->hsr->node_db[hash], + node_dst = find_node_by_addr_A(&port->hsr->node_db, eth_hdr(skb)->h_dest); if (!node_dst) { if (net_ratelimit()) @@ -484,13 +452,17 @@ void hsr_register_frame_in(struct hsr_node *node, struct hsr_port *port, int hsr_register_frame_out(struct hsr_port *port, struct hsr_node *node, u16 sequence_nr) { + spin_lock_bh(&node->seq_out_lock); if (seq_nr_before_or_eq(sequence_nr, node->seq_out[port->type]) && time_is_after_jiffies(node->time_out[port->type] + - msecs_to_jiffies(HSR_ENTRY_FORGET_TIME))) + msecs_to_jiffies(HSR_ENTRY_FORGET_TIME))) { + spin_unlock_bh(&node->seq_out_lock); return 1; + } node->time_out[port->type] = jiffies; node->seq_out[port->type] = sequence_nr; + spin_unlock_bh(&node->seq_out_lock); return 0; } @@ -520,71 +492,60 @@ static struct hsr_port *get_late_port(struct hsr_priv *hsr, void hsr_prune_nodes(struct timer_list *t) { struct hsr_priv *hsr = from_timer(hsr, t, prune_timer); - struct hlist_node *tmp; struct hsr_node *node; + struct hsr_node *tmp; struct hsr_port *port; unsigned long timestamp; unsigned long time_a, time_b; - int i; spin_lock_bh(&hsr->list_lock); + list_for_each_entry_safe(node, tmp, &hsr->node_db, mac_list) { + /* Don't prune own node. Neither time_in[HSR_PT_SLAVE_A] + * nor time_in[HSR_PT_SLAVE_B], will ever be updated for + * the master port. Thus the master node will be repeatedly + * pruned leading to packet loss. + */ + if (hsr_addr_is_self(hsr, node->macaddress_A)) + continue; + + /* Shorthand */ + time_a = node->time_in[HSR_PT_SLAVE_A]; + time_b = node->time_in[HSR_PT_SLAVE_B]; + + /* Check for timestamps old enough to risk wrap-around */ + if (time_after(jiffies, time_a + MAX_JIFFY_OFFSET / 2)) + node->time_in_stale[HSR_PT_SLAVE_A] = true; + if (time_after(jiffies, time_b + MAX_JIFFY_OFFSET / 2)) + node->time_in_stale[HSR_PT_SLAVE_B] = true; + + /* Get age of newest frame from node. + * At least one time_in is OK here; nodes get pruned long + * before both time_ins can get stale + */ + timestamp = time_a; + if (node->time_in_stale[HSR_PT_SLAVE_A] || + (!node->time_in_stale[HSR_PT_SLAVE_B] && + time_after(time_b, time_a))) + timestamp = time_b; + + /* Warn of ring error only as long as we get frames at all */ + if (time_is_after_jiffies(timestamp + + msecs_to_jiffies(1.5 * MAX_SLAVE_DIFF))) { + rcu_read_lock(); + port = get_late_port(hsr, node); + if (port) + hsr_nl_ringerror(hsr, node->macaddress_A, port); + rcu_read_unlock(); + } - for (i = 0; i < hsr->hash_buckets; i++) { - hlist_for_each_entry_safe(node, tmp, &hsr->node_db[i], - mac_list) { - /* Don't prune own node. - * Neither time_in[HSR_PT_SLAVE_A] - * nor time_in[HSR_PT_SLAVE_B], will ever be updated - * for the master port. Thus the master node will be - * repeatedly pruned leading to packet loss. - */ - if (hsr_addr_is_self(hsr, node->macaddress_A)) - continue; - - /* Shorthand */ - time_a = node->time_in[HSR_PT_SLAVE_A]; - time_b = node->time_in[HSR_PT_SLAVE_B]; - - /* Check for timestamps old enough to - * risk wrap-around - */ - if (time_after(jiffies, time_a + MAX_JIFFY_OFFSET / 2)) - node->time_in_stale[HSR_PT_SLAVE_A] = true; - if (time_after(jiffies, time_b + MAX_JIFFY_OFFSET / 2)) - node->time_in_stale[HSR_PT_SLAVE_B] = true; - - /* Get age of newest frame from node. - * At least one time_in is OK here; nodes get pruned - * long before both time_ins can get stale - */ - timestamp = time_a; - if (node->time_in_stale[HSR_PT_SLAVE_A] || - (!node->time_in_stale[HSR_PT_SLAVE_B] && - time_after(time_b, time_a))) - timestamp = time_b; - - /* Warn of ring error only as long as we get - * frames at all - */ - if (time_is_after_jiffies(timestamp + - msecs_to_jiffies(1.5 * MAX_SLAVE_DIFF))) { - rcu_read_lock(); - port = get_late_port(hsr, node); - if (port) - hsr_nl_ringerror(hsr, - node->macaddress_A, - port); - rcu_read_unlock(); - } - - /* Prune old entries */ - if (time_is_before_jiffies(timestamp + - msecs_to_jiffies(HSR_NODE_FORGET_TIME))) { - hsr_nl_nodedown(hsr, node->macaddress_A); - hlist_del_rcu(&node->mac_list); - /* Note that we need to free this - * entry later: - */ + /* Prune old entries */ + if (time_is_before_jiffies(timestamp + + msecs_to_jiffies(HSR_NODE_FORGET_TIME))) { + hsr_nl_nodedown(hsr, node->macaddress_A); + if (!node->removed) { + list_del_rcu(&node->mac_list); + node->removed = true; + /* Note that we need to free this entry later: */ kfree_rcu(node, rcu_head); } } @@ -600,20 +561,17 @@ void *hsr_get_next_node(struct hsr_priv *hsr, void *_pos, unsigned char addr[ETH_ALEN]) { struct hsr_node *node; - u32 hash; - - hash = hsr_mac_hash(hsr, addr); if (!_pos) { - node = hsr_node_get_first(&hsr->node_db[hash], - &hsr->list_lock); + node = list_first_or_null_rcu(&hsr->node_db, + struct hsr_node, mac_list); if (node) ether_addr_copy(addr, node->macaddress_A); return node; } node = _pos; - hlist_for_each_entry_continue_rcu(node, mac_list) { + list_for_each_entry_continue_rcu(node, &hsr->node_db, mac_list) { ether_addr_copy(addr, node->macaddress_A); return node; } @@ -633,11 +591,8 @@ int hsr_get_node_data(struct hsr_priv *hsr, struct hsr_node *node; struct hsr_port *port; unsigned long tdiff; - u32 hash; - - hash = hsr_mac_hash(hsr, addr); - node = find_node_by_addr_A(&hsr->node_db[hash], addr); + node = find_node_by_addr_A(&hsr->node_db, addr); if (!node) return -ENOENT; diff --git a/net/hsr/hsr_framereg.h b/net/hsr/hsr_framereg.h index f3762e9e42b5..b23556251d62 100644 --- a/net/hsr/hsr_framereg.h +++ b/net/hsr/hsr_framereg.h @@ -28,17 +28,9 @@ struct hsr_frame_info { bool is_from_san; }; -#ifdef CONFIG_LOCKDEP -int lockdep_hsr_is_held(spinlock_t *lock); -#else -#define lockdep_hsr_is_held(lock) 1 -#endif - -u32 hsr_mac_hash(struct hsr_priv *hsr, const unsigned char *addr); -struct hsr_node *hsr_node_get_first(struct hlist_head *head, spinlock_t *lock); void hsr_del_self_node(struct hsr_priv *hsr); -void hsr_del_nodes(struct hlist_head *node_db); -struct hsr_node *hsr_get_node(struct hsr_port *port, struct hlist_head *node_db, +void hsr_del_nodes(struct list_head *node_db); +struct hsr_node *hsr_get_node(struct hsr_port *port, struct list_head *node_db, struct sk_buff *skb, bool is_sup, enum hsr_port_type rx_port); void hsr_handle_sup_frame(struct hsr_frame_info *frame); @@ -76,7 +68,9 @@ void prp_handle_san_frame(bool san, enum hsr_port_type port, void prp_update_san_info(struct hsr_node *node, bool is_sup); struct hsr_node { - struct hlist_node mac_list; + struct list_head mac_list; + /* Protect R/W access to seq_out */ + spinlock_t seq_out_lock; unsigned char macaddress_A[ETH_ALEN]; unsigned char macaddress_B[ETH_ALEN]; /* Local slave through which AddrB frames are received from this node */ @@ -88,6 +82,7 @@ struct hsr_node { bool san_a; bool san_b; u16 seq_out[HSR_PT_PORTS]; + bool removed; struct rcu_head rcu_head; }; diff --git a/net/hsr/hsr_main.h b/net/hsr/hsr_main.h index b158ba409f9a..5584c80a5c79 100644 --- a/net/hsr/hsr_main.h +++ b/net/hsr/hsr_main.h @@ -47,9 +47,6 @@ #define HSR_V1_SUP_LSDUSIZE 52 -#define HSR_HSIZE_SHIFT 8 -#define HSR_HSIZE BIT(HSR_HSIZE_SHIFT) - /* The helper functions below assumes that 'path' occupies the 4 most * significant bits of the 16-bit field shared by 'path' and 'LSDU_size' (or * equivalently, the 4 most significant bits of HSR tag byte 14). @@ -185,11 +182,17 @@ struct hsr_proto_ops { void (*update_san_info)(struct hsr_node *node, bool is_sup); }; +struct hsr_self_node { + unsigned char macaddress_A[ETH_ALEN]; + unsigned char macaddress_B[ETH_ALEN]; + struct rcu_head rcu_head; +}; + struct hsr_priv { struct rcu_head rcu_head; struct list_head ports; - struct hlist_head node_db[HSR_HSIZE]; /* Known HSR nodes */ - struct hlist_head self_node_db; /* MACs of slaves */ + struct list_head node_db; /* Known HSR nodes */ + struct hsr_self_node __rcu *self_node; /* MACs of slaves */ struct timer_list announce_timer; /* Supervision frame dispatch */ struct timer_list prune_timer; int announce_count; @@ -199,8 +202,6 @@ struct hsr_priv { spinlock_t seqnr_lock; /* locking for sequence_nr */ spinlock_t list_lock; /* locking for node list */ struct hsr_proto_ops *proto_ops; - u32 hash_buckets; - u32 hash_seed; #define PRP_LAN_ID 0x5 /* 0x1010 for A and 0x1011 for B. Bit 0 is set * based on SLAVE_A or SLAVE_B */ diff --git a/net/hsr/hsr_netlink.c b/net/hsr/hsr_netlink.c index 7174a9092900..78fe40eb9f01 100644 --- a/net/hsr/hsr_netlink.c +++ b/net/hsr/hsr_netlink.c @@ -105,7 +105,6 @@ static int hsr_newlink(struct net *src_net, struct net_device *dev, static void hsr_dellink(struct net_device *dev, struct list_head *head) { struct hsr_priv *hsr = netdev_priv(dev); - int i; del_timer_sync(&hsr->prune_timer); del_timer_sync(&hsr->announce_timer); @@ -114,8 +113,7 @@ static void hsr_dellink(struct net_device *dev, struct list_head *head) hsr_del_ports(hsr); hsr_del_self_node(hsr); - for (i = 0; i < hsr->hash_buckets; i++) - hsr_del_nodes(&hsr->node_db[i]); + hsr_del_nodes(&hsr->node_db); unregister_netdevice_queue(dev, head); } diff --git a/net/ieee802154/core.c b/net/ieee802154/core.c index de259b5170ab..57546e07e06a 100644 --- a/net/ieee802154/core.c +++ b/net/ieee802154/core.c @@ -129,6 +129,9 @@ wpan_phy_new(const struct cfg802154_ops *ops, size_t priv_size) wpan_phy_net_set(&rdev->wpan_phy, &init_net); init_waitqueue_head(&rdev->dev_wait); + init_waitqueue_head(&rdev->wpan_phy.sync_txq); + + spin_lock_init(&rdev->wpan_phy.queue_lock); return &rdev->wpan_phy; } diff --git a/net/ieee802154/nl802154.c b/net/ieee802154/nl802154.c index 38c4f3cb010e..248ad5e46969 100644 --- a/net/ieee802154/nl802154.c +++ b/net/ieee802154/nl802154.c @@ -26,10 +26,12 @@ static struct genl_family nl802154_fam; /* multicast groups */ enum nl802154_multicast_groups { NL802154_MCGRP_CONFIG, + NL802154_MCGRP_SCAN, }; static const struct genl_multicast_group nl802154_mcgrps[] = { [NL802154_MCGRP_CONFIG] = { .name = "config", }, + [NL802154_MCGRP_SCAN] = { .name = "scan", }, }; /* returns ERR_PTR values */ @@ -216,6 +218,9 @@ static const struct nla_policy nl802154_policy[NL802154_ATTR_MAX+1] = { [NL802154_ATTR_PID] = { .type = NLA_U32 }, [NL802154_ATTR_NETNS_FD] = { .type = NLA_U32 }, + + [NL802154_ATTR_COORDINATOR] = { .type = NLA_NESTED }, + #ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL [NL802154_ATTR_SEC_ENABLED] = { .type = NLA_U8, }, [NL802154_ATTR_SEC_OUT_LEVEL] = { .type = NLA_U32, }, @@ -1281,6 +1286,104 @@ static int nl802154_wpan_phy_netns(struct sk_buff *skb, struct genl_info *info) return err; } +static int nl802154_prep_scan_event_msg(struct sk_buff *msg, + struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + u32 portid, u32 seq, int flags, u8 cmd, + struct ieee802154_coord_desc *desc) +{ + struct nlattr *nla; + void *hdr; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_WPAN_PHY, rdev->wpan_phy_idx)) + goto nla_put_failure; + + if (wpan_dev->netdev && + nla_put_u32(msg, NL802154_ATTR_IFINDEX, wpan_dev->netdev->ifindex)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NL802154_ATTR_WPAN_DEV, + wpan_dev_id(wpan_dev), NL802154_ATTR_PAD)) + goto nla_put_failure; + + nla = nla_nest_start_noflag(msg, NL802154_ATTR_COORDINATOR); + if (!nla) + goto nla_put_failure; + + if (nla_put(msg, NL802154_COORD_PANID, IEEE802154_PAN_ID_LEN, + &desc->addr.pan_id)) + goto nla_put_failure; + + if (desc->addr.mode == IEEE802154_ADDR_SHORT) { + if (nla_put(msg, NL802154_COORD_ADDR, + IEEE802154_SHORT_ADDR_LEN, + &desc->addr.short_addr)) + goto nla_put_failure; + } else { + if (nla_put(msg, NL802154_COORD_ADDR, + IEEE802154_EXTENDED_ADDR_LEN, + &desc->addr.extended_addr)) + goto nla_put_failure; + } + + if (nla_put_u8(msg, NL802154_COORD_CHANNEL, desc->channel)) + goto nla_put_failure; + + if (nla_put_u8(msg, NL802154_COORD_PAGE, desc->page)) + goto nla_put_failure; + + if (nla_put_u16(msg, NL802154_COORD_SUPERFRAME_SPEC, + desc->superframe_spec)) + goto nla_put_failure; + + if (nla_put_u8(msg, NL802154_COORD_LINK_QUALITY, desc->link_quality)) + goto nla_put_failure; + + if (desc->gts_permit && nla_put_flag(msg, NL802154_COORD_GTS_PERMIT)) + goto nla_put_failure; + + /* TODO: NL802154_COORD_PAYLOAD_DATA if any */ + + nla_nest_end(msg, nla); + + genlmsg_end(msg, hdr); + + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + + return -EMSGSIZE; +} + +int nl802154_scan_event(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + struct ieee802154_coord_desc *desc) +{ + struct cfg802154_registered_device *rdev = wpan_phy_to_rdev(wpan_phy); + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + ret = nl802154_prep_scan_event_msg(msg, rdev, wpan_dev, 0, 0, 0, + NL802154_CMD_SCAN_EVENT, + desc); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + return genlmsg_multicast_netns(&nl802154_fam, wpan_phy_net(wpan_phy), + msg, 0, NL802154_MCGRP_SCAN, GFP_ATOMIC); +} +EXPORT_SYMBOL_GPL(nl802154_scan_event); + #ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL static const struct nla_policy nl802154_dev_addr_policy[NL802154_DEV_ADDR_ATTR_MAX + 1] = { [NL802154_DEV_ADDR_ATTR_PAN_ID] = { .type = NLA_U16 }, @@ -2157,7 +2260,8 @@ static int nl802154_del_llsec_seclevel(struct sk_buff *skb, #define NL802154_FLAG_CHECK_NETDEV_UP 0x08 #define NL802154_FLAG_NEED_WPAN_DEV 0x10 -static int nl802154_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, +static int nl802154_pre_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { struct cfg802154_registered_device *rdev; @@ -2219,7 +2323,8 @@ static int nl802154_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, return 0; } -static void nl802154_post_doit(const struct genl_ops *ops, struct sk_buff *skb, +static void nl802154_post_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { if (info->user_ptr[1]) { diff --git a/net/ieee802154/nl802154.h b/net/ieee802154/nl802154.h index 8c4b6d08954c..89b805500032 100644 --- a/net/ieee802154/nl802154.h +++ b/net/ieee802154/nl802154.h @@ -4,5 +4,7 @@ int nl802154_init(void); void nl802154_exit(void); +int nl802154_scan_event(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + struct ieee802154_coord_desc *desc); #endif /* __IEEE802154_NL802154_H */ diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index e983bb0c5012..2dfb12230f08 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -402,6 +402,16 @@ config INET_IPCOMP If unsure, say Y. +config INET_TABLE_PERTURB_ORDER + int "INET: Source port perturbation table size (as power of 2)" if EXPERT + default 16 + help + Source port perturbation table size (as power of 2) for + RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm. + + The default is almost always what you want. + Only change this if you know what you are doing. + config INET_XFRM_TUNNEL tristate select INET_TUNNEL diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile index bbdd9c44f14e..af7d2cf490fb 100644 --- a/net/ipv4/Makefile +++ b/net/ipv4/Makefile @@ -10,7 +10,7 @@ obj-y := route.o inetpeer.o protocol.o \ tcp.o tcp_input.o tcp_output.o tcp_timer.o tcp_ipv4.o \ tcp_minisocks.o tcp_cong.o tcp_metrics.o tcp_fastopen.o \ tcp_rate.o tcp_recovery.o tcp_ulp.o \ - tcp_offload.o datagram.o raw.o udp.o udplite.o \ + tcp_offload.o tcp_plb.o datagram.o raw.o udp.o udplite.o \ udp_offload.o arp.o icmp.o devinet.o af_inet.o igmp.o \ fib_frontend.o fib_semantics.o fib_trie.o fib_notifier.o \ inet_fragment.o ping.o ip_tunnel_core.o gre_offload.o \ diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 4728087c42a5..6c0ec2789943 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -522,9 +522,9 @@ int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, /* Make sure we are allowed to bind here. */ if (snum || !(inet->bind_address_no_port || (flags & BIND_FORCE_ADDRESS_NO_PORT))) { - if (sk->sk_prot->get_port(sk, snum)) { + err = sk->sk_prot->get_port(sk, snum); + if (err) { inet->inet_saddr = inet->inet_rcv_saddr = 0; - err = -EADDRINUSE; goto out_release_sock; } if (!(flags & BIND_FROM_BPF)) { @@ -1230,7 +1230,6 @@ EXPORT_SYMBOL(inet_unregister_protosw); static int inet_sk_reselect_saddr(struct sock *sk) { - struct inet_bind_hashbucket *prev_addr_hashbucket; struct inet_sock *inet = inet_sk(sk); __be32 old_saddr = inet->inet_saddr; __be32 daddr = inet->inet_daddr; @@ -1260,16 +1259,8 @@ static int inet_sk_reselect_saddr(struct sock *sk) return 0; } - prev_addr_hashbucket = - inet_bhashfn_portaddr(tcp_or_dccp_get_hashinfo(sk), sk, - sock_net(sk), inet->inet_num); - - inet->inet_saddr = inet->inet_rcv_saddr = new_saddr; - - err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); + err = inet_bhash2_update_saddr(sk, &new_saddr, AF_INET); if (err) { - inet->inet_saddr = old_saddr; - inet->inet_rcv_saddr = old_saddr; ip_rt_put(rt); return err; } @@ -1674,6 +1665,7 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family, if (rc == 0) { *sk = sock->sk; (*sk)->sk_allocation = GFP_ATOMIC; + (*sk)->sk_use_task_frag = false; /* * Unhash it so that IP input processing does not even see it, * we do not wish this socket to see incoming packets. @@ -1708,9 +1700,9 @@ u64 snmp_get_cpu_field64(void __percpu *mib, int cpu, int offt, bhptr = per_cpu_ptr(mib, cpu); syncp = (struct u64_stats_sync *)(bhptr + syncp_offset); do { - start = u64_stats_fetch_begin_irq(syncp); + start = u64_stats_fetch_begin(syncp); v = *(((u64 *)bhptr) + offt); - } while (u64_stats_fetch_retry_irq(syncp, start)); + } while (u64_stats_fetch_retry(syncp, start)); return v; } diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 6da16ae6a962..4517d2bd186a 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -61,7 +61,9 @@ static bool bpf_tcp_ca_is_valid_access(int off, int size, if (!bpf_tracing_btf_ctx_access(off, size, type, prog, info)) return false; - if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id) + if (base_type(info->reg_type) == PTR_TO_BTF_ID && + !bpf_type_has_unsafe_modifiers(info->reg_type) && + info->btf_id == sock_id) /* promote it to tcp_sock */ info->btf_id = tcp_sock_id; @@ -69,18 +71,17 @@ static bool bpf_tcp_ca_is_valid_access(int off, int size, } static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, - u32 *next_btf_id, - enum bpf_type_flag *flag) + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, + u32 *next_btf_id, enum bpf_type_flag *flag) { + const struct btf_type *t; size_t end; if (atype == BPF_READ) - return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, - flag); + return btf_struct_access(log, reg, off, size, atype, next_btf_id, flag); + t = btf_type_by_id(reg->btf, reg->btf_id); if (t != tcp_sock_type) { bpf_log(log, "only read is supported\n"); return -EACCES; diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c index e8b9a9202fec..b0acf6e19aed 100644 --- a/net/ipv4/devinet.c +++ b/net/ipv4/devinet.c @@ -234,13 +234,20 @@ static void inet_free_ifa(struct in_ifaddr *ifa) call_rcu(&ifa->rcu_head, inet_rcu_free_ifa); } +static void in_dev_free_rcu(struct rcu_head *head) +{ + struct in_device *idev = container_of(head, struct in_device, rcu_head); + + kfree(rcu_dereference_protected(idev->mc_hash, 1)); + kfree(idev); +} + void in_dev_finish_destroy(struct in_device *idev) { struct net_device *dev = idev->dev; WARN_ON(idev->ifa_list); WARN_ON(idev->mc_list); - kfree(rcu_dereference_protected(idev->mc_hash, 1)); #ifdef NET_REFCNT_DEBUG pr_debug("%s: %p=%s\n", __func__, idev, dev ? dev->name : "NIL"); #endif @@ -248,7 +255,7 @@ void in_dev_finish_destroy(struct in_device *idev) if (!idev->dead) pr_err("Freeing alive in_device %p\n", idev); else - kfree(idev); + call_rcu(&idev->rcu_head, in_dev_free_rcu); } EXPORT_SYMBOL(in_dev_finish_destroy); @@ -298,12 +305,6 @@ out_kfree: goto out; } -static void in_dev_rcu_put(struct rcu_head *head) -{ - struct in_device *idev = container_of(head, struct in_device, rcu_head); - in_dev_put(idev); -} - static void inetdev_destroy(struct in_device *in_dev) { struct net_device *dev; @@ -328,7 +329,7 @@ static void inetdev_destroy(struct in_device *in_dev) neigh_parms_release(&arp_tbl, in_dev->arp_parms); arp_ifdown(dev); - call_rcu(&in_dev->rcu_head, in_dev_rcu_put); + in_dev_put(in_dev); } int inet_addr_onlink(struct in_device *in_dev, __be32 a, __be32 b) diff --git a/net/ipv4/esp4_offload.c b/net/ipv4/esp4_offload.c index 170152772d33..3969fa805679 100644 --- a/net/ipv4/esp4_offload.c +++ b/net/ipv4/esp4_offload.c @@ -314,6 +314,9 @@ static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_ xo->seq.low += skb_shinfo(skb)->gso_segs; } + if (xo->seq.low < seq) + xo->seq.hi++; + esp.seqno = cpu_to_be64(seq + ((u64)xo->seq.hi << 32)); ip_hdr(skb)->tot_len = htons(skb->len); diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c index f361d3d56be2..b5736ef16ed2 100644 --- a/net/ipv4/fib_frontend.c +++ b/net/ipv4/fib_frontend.c @@ -841,6 +841,9 @@ static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, return -EINVAL; } + if (!cfg->fc_table) + cfg->fc_table = RT_TABLE_MAIN; + return 0; errout: return err; diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c index f721c308248b..3bb890a40ed7 100644 --- a/net/ipv4/fib_semantics.c +++ b/net/ipv4/fib_semantics.c @@ -30,6 +30,7 @@ #include <linux/slab.h> #include <linux/netlink.h> #include <linux/hash.h> +#include <linux/nospec.h> #include <net/arp.h> #include <net/inet_dscp.h> @@ -423,6 +424,7 @@ static struct fib_info *fib_find_info(struct fib_info *nfi) nfi->fib_prefsrc == fi->fib_prefsrc && nfi->fib_priority == fi->fib_priority && nfi->fib_type == fi->fib_type && + nfi->fib_tb_id == fi->fib_tb_id && memcmp(nfi->fib_metrics, fi->fib_metrics, sizeof(u32) * RTAX_MAX) == 0 && !((nfi->fib_flags ^ fi->fib_flags) & ~RTNH_COMPARE_MASK) && @@ -888,9 +890,11 @@ int fib_nh_match(struct net *net, struct fib_config *cfg, struct fib_info *fi, return 1; } - /* cannot match on nexthop object attributes */ - if (fi->nh) - return 1; + if (fi->nh) { + if (cfg->fc_oif || cfg->fc_gw_family || cfg->fc_mp) + return 1; + return 0; + } if (cfg->fc_oif || cfg->fc_gw_family) { struct fib_nh *nh; @@ -1019,6 +1023,7 @@ bool fib_metrics_match(struct fib_config *cfg, struct fib_info *fi) if (type > RTAX_MAX) return false; + type = array_index_nospec(type, RTAX_MAX + 1); if (type == RTAX_CC_ALGO) { char tmp[TCP_CA_NAME_MAX]; bool ecn_ca = false; diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c index 452ff177e4da..74d403dbd2b4 100644 --- a/net/ipv4/fib_trie.c +++ b/net/ipv4/fib_trie.c @@ -126,7 +126,7 @@ struct key_vector { /* This list pointer if valid if (pos | bits) == 0 (LEAF) */ struct hlist_head leaf; /* This array is valid if (pos | bits) > 0 (TNODE) */ - struct key_vector __rcu *tnode[0]; + DECLARE_FLEX_ARRAY(struct key_vector __rcu *, tnode); }; }; @@ -1381,8 +1381,10 @@ int fib_table_insert(struct net *net, struct fib_table *tb, /* The alias was already inserted, so the node must exist. */ l = l ? l : fib_find_node(t, &tp, key); - if (WARN_ON_ONCE(!l)) + if (WARN_ON_ONCE(!l)) { + err = -ENOENT; goto out_free_new_fa; + } if (fib_find_alias(&l->leaf, new_fa->fa_slen, 0, 0, tb->tb_id, true) == new_fa) { diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c index d5d745c3e345..46aa2d65e40a 100644 --- a/net/ipv4/icmp.c +++ b/net/ipv4/icmp.c @@ -263,7 +263,7 @@ bool icmp_global_allow(void) /* We want to use a credit of one in average, but need to randomize * it for security reasons. */ - credit = max_t(int, credit - prandom_u32_max(3), 0); + credit = max_t(int, credit - get_random_u32_below(3), 0); rc = true; } WRITE_ONCE(icmp_global.credit, credit); diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c index 81be3e0f0e70..c920aa9a62a9 100644 --- a/net/ipv4/igmp.c +++ b/net/ipv4/igmp.c @@ -213,7 +213,7 @@ static void igmp_stop_timer(struct ip_mc_list *im) /* It must be called with locked im->lock */ static void igmp_start_timer(struct ip_mc_list *im, int max_delay) { - int tv = prandom_u32_max(max_delay); + int tv = get_random_u32_below(max_delay); im->tm_running = 1; if (!mod_timer(&im->timer, jiffies+tv+2)) @@ -222,7 +222,7 @@ static void igmp_start_timer(struct ip_mc_list *im, int max_delay) static void igmp_gq_start_timer(struct in_device *in_dev) { - int tv = prandom_u32_max(in_dev->mr_maxdelay); + int tv = get_random_u32_below(in_dev->mr_maxdelay); unsigned long exp = jiffies + tv + 2; if (in_dev->mr_gq_running && @@ -236,7 +236,7 @@ static void igmp_gq_start_timer(struct in_device *in_dev) static void igmp_ifc_start_timer(struct in_device *in_dev, int delay) { - int tv = prandom_u32_max(delay); + int tv = get_random_u32_below(delay); if (!mod_timer(&in_dev->mr_ifc_timer, jiffies+tv+2)) in_dev_hold(in_dev); diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 4e84ed21d16f..d1f837579398 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -173,22 +173,40 @@ static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2, return false; } +static bool __inet_bhash2_conflict(const struct sock *sk, struct sock *sk2, + kuid_t sk_uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) +{ + if (sk->sk_family == AF_INET && ipv6_only_sock(sk2)) + return false; + + return inet_bind_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok); +} + static bool inet_bhash2_conflict(const struct sock *sk, const struct inet_bind2_bucket *tb2, kuid_t sk_uid, bool relax, bool reuseport_cb_ok, bool reuseport_ok) { + struct inet_timewait_sock *tw2; struct sock *sk2; sk_for_each_bound_bhash2(sk2, &tb2->owners) { - if (sk->sk_family == AF_INET && ipv6_only_sock(sk2)) - continue; + if (__inet_bhash2_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok)) + return true; + } + + twsk_for_each_bound_bhash2(tw2, &tb2->deathrow) { + sk2 = (struct sock *)tw2; - if (inet_bind_conflict(sk, sk2, sk_uid, relax, - reuseport_cb_ok, reuseport_ok)) + if (__inet_bhash2_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok)) return true; } + return false; } @@ -314,7 +332,7 @@ other_half_scan: if (likely(remaining > 1)) remaining &= ~1U; - offset = prandom_u32_max(remaining); + offset = get_random_u32_below(remaining); /* __inet_hash_connect() favors ports having @low parity * We do the opposite to not pollute connect() users. */ @@ -471,11 +489,11 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum) bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; bool found_port = false, check_bind_conflict = true; bool bhash_created = false, bhash2_created = false; + int ret = -EADDRINUSE, port = snum, l3mdev; struct inet_bind_hashbucket *head, *head2; struct inet_bind2_bucket *tb2 = NULL; struct inet_bind_bucket *tb = NULL; bool head2_lock_acquired = false; - int ret = 1, port = snum, l3mdev; struct net *net = sock_net(sk); l3mdev = inet_sk_bound_l3mdev(sk); @@ -1182,11 +1200,25 @@ void inet_csk_prepare_forced_close(struct sock *sk) } EXPORT_SYMBOL(inet_csk_prepare_forced_close); +static int inet_ulp_can_listen(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ulp_ops && !icsk->icsk_ulp_ops->clone) + return -EINVAL; + + return 0; +} + int inet_csk_listen_start(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); struct inet_sock *inet = inet_sk(sk); - int err = -EADDRINUSE; + int err; + + err = inet_ulp_can_listen(sk); + if (unlikely(err)) + return err; reqsk_queue_alloc(&icsk->icsk_accept_queue); @@ -1202,7 +1234,8 @@ int inet_csk_listen_start(struct sock *sk) * after validation is complete. */ inet_sk_state_store(sk, TCP_LISTEN); - if (!sk->sk_prot->get_port(sk, inet->inet_num)) { + err = sk->sk_prot->get_port(sk, inet->inet_num); + if (!err) { inet->inet_sport = htons(inet->inet_num); sk_dst_reset(sk); diff --git a/net/ipv4/inet_fragment.c b/net/ipv4/inet_fragment.c index c9f9ac5013a7..7072fc0783ef 100644 --- a/net/ipv4/inet_fragment.c +++ b/net/ipv4/inet_fragment.c @@ -133,6 +133,7 @@ static void inet_frags_free_cb(void *ptr, void *arg) count = del_timer_sync(&fq->timer) ? 1 : 0; spin_lock_bh(&fq->lock); + fq->flags |= INET_FRAG_DROP; if (!(fq->flags & INET_FRAG_COMPLETE)) { fq->flags |= INET_FRAG_COMPLETE; count++; @@ -260,7 +261,8 @@ static void inet_frag_destroy_rcu(struct rcu_head *head) kmem_cache_free(f->frags_cachep, q); } -unsigned int inet_frag_rbtree_purge(struct rb_root *root) +unsigned int inet_frag_rbtree_purge(struct rb_root *root, + enum skb_drop_reason reason) { struct rb_node *p = rb_first(root); unsigned int sum = 0; @@ -274,7 +276,7 @@ unsigned int inet_frag_rbtree_purge(struct rb_root *root) struct sk_buff *next = FRAG_CB(skb)->next_frag; sum += skb->truesize; - kfree_skb(skb); + kfree_skb_reason(skb, reason); skb = next; } } @@ -284,17 +286,21 @@ EXPORT_SYMBOL(inet_frag_rbtree_purge); void inet_frag_destroy(struct inet_frag_queue *q) { - struct fqdir *fqdir; unsigned int sum, sum_truesize = 0; + enum skb_drop_reason reason; struct inet_frags *f; + struct fqdir *fqdir; WARN_ON(!(q->flags & INET_FRAG_COMPLETE)); + reason = (q->flags & INET_FRAG_DROP) ? + SKB_DROP_REASON_FRAG_REASM_TIMEOUT : + SKB_CONSUMED; WARN_ON(del_timer(&q->timer) != 0); /* Release all fragment data. */ fqdir = q->fqdir; f = fqdir->f; - sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments); + sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments, reason); sum = sum_truesize + f->qsize; call_rcu(&q->rcu, inet_frag_destroy_rcu); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index d3dc28156622..f58d73888638 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -116,6 +116,7 @@ static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb, #endif tb->rcv_saddr = sk->sk_rcv_saddr; INIT_HLIST_HEAD(&tb->owners); + INIT_HLIST_HEAD(&tb->deathrow); hlist_add_head(&tb->node, &head->chain); } @@ -137,7 +138,7 @@ struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, /* Caller must hold hashbucket lock for this tb with local BH disabled */ void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) { - if (hlist_empty(&tb->owners)) { + if (hlist_empty(&tb->owners) && hlist_empty(&tb->deathrow)) { __hlist_del(&tb->node); kmem_cache_free(cachep, tb); } @@ -649,8 +650,20 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) spin_lock(lock); if (osk) { WARN_ON_ONCE(sk->sk_hash != osk->sk_hash); - ret = sk_nulls_del_node_init_rcu(osk); - } else if (found_dup_sk) { + ret = sk_hashed(osk); + if (ret) { + /* Before deleting the node, we insert a new one to make + * sure that the look-up-sk process would not miss either + * of them and that at least one node would exist in ehash + * table all the time. Otherwise there's a tiny chance + * that lookup process could find nothing in ehash table. + */ + __sk_nulls_add_node_tail_rcu(sk, list); + sk_nulls_del_node_init_rcu(osk); + } + goto unlock; + } + if (found_dup_sk) { *found_dup_sk = inet_ehash_lookup_by_sk(sk, list); if (*found_dup_sk) ret = false; @@ -659,6 +672,7 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) if (ret) __sk_nulls_add_node_rcu(sk, list); +unlock: spin_unlock(lock); return ret; @@ -858,34 +872,80 @@ inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, in return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)]; } -int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk) +static void inet_update_saddr(struct sock *sk, void *saddr, int family) +{ + if (family == AF_INET) { + inet_sk(sk)->inet_saddr = *(__be32 *)saddr; + sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr); + } +#if IS_ENABLED(CONFIG_IPV6) + else { + sk->sk_v6_rcv_saddr = *(struct in6_addr *)saddr; + } +#endif +} + +static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset) { struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; struct inet_bind2_bucket *tb2, *new_tb2; int l3mdev = inet_sk_bound_l3mdev(sk); - struct inet_bind_hashbucket *head2; int port = inet_sk(sk)->inet_num; struct net *net = sock_net(sk); + int bhash; + + if (!inet_csk(sk)->icsk_bind2_hash) { + /* Not bind()ed before. */ + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); + + return 0; + } /* Allocate a bind2 bucket ahead of time to avoid permanently putting * the bhash2 table in an inconsistent state if a new tb2 bucket * allocation fails. */ new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC); - if (!new_tb2) + if (!new_tb2) { + if (reset) { + /* The (INADDR_ANY, port) bucket might have already + * been freed, then we cannot fixup icsk_bind2_hash, + * so we give up and unlink sk from bhash/bhash2 not + * to leave inconsistency in bhash2. + */ + inet_put_port(sk); + inet_reset_saddr(sk); + } + return -ENOMEM; + } + bhash = inet_bhashfn(net, port, hinfo->bhash_size); + head = &hinfo->bhash[bhash]; head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); - if (prev_saddr) { - spin_lock_bh(&prev_saddr->lock); - __sk_del_bind2_node(sk); - inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, - inet_csk(sk)->icsk_bind2_hash); - spin_unlock_bh(&prev_saddr->lock); - } + /* If we change saddr locklessly, another thread + * iterating over bhash might see corrupted address. + */ + spin_lock_bh(&head->lock); - spin_lock_bh(&head2->lock); + spin_lock(&head2->lock); + __sk_del_bind2_node(sk); + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash); + spin_unlock(&head2->lock); + + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + spin_lock(&head2->lock); tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); if (!tb2) { tb2 = new_tb2; @@ -893,26 +953,40 @@ int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct soc } sk_add_bind2_node(sk, &tb2->owners); inet_csk(sk)->icsk_bind2_hash = tb2; - spin_unlock_bh(&head2->lock); + spin_unlock(&head2->lock); + + spin_unlock_bh(&head->lock); if (tb2 != new_tb2) kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2); return 0; } + +int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family) +{ + return __inet_bhash2_update_saddr(sk, saddr, family, false); +} EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr); +void inet_bhash2_reset_saddr(struct sock *sk) +{ + if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) + __inet_bhash2_update_saddr(sk, NULL, 0, true); +} +EXPORT_SYMBOL_GPL(inet_bhash2_reset_saddr); + /* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm * Note that we use 32bit integers (vs RFC 'short integers') * because 2^16 is not a multiple of num_ephemeral and this * property might be used by clever attacker. + * * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, though - * attacks were since demonstrated, thus we use 65536 instead to really - * give more isolation and privacy, at the expense of 256kB of kernel - * memory. + * attacks were since demonstrated, thus we use 65536 by default instead + * to really give more isolation and privacy, at the expense of 256kB + * of kernel memory. */ -#define INET_TABLE_PERTURB_SHIFT 16 -#define INET_TABLE_PERTURB_SIZE (1 << INET_TABLE_PERTURB_SHIFT) +#define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER) static u32 *table_perturb; int __inet_hash_connect(struct inet_timewait_death_row *death_row, @@ -1037,21 +1111,22 @@ ok: * on low contention the randomness is maximal and on high contention * it may be inexistent. */ - i = max_t(int, i, prandom_u32_max(8) * 2); + i = max_t(int, i, get_random_u32_below(8) * 2); WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); /* Head lock still held and bh's disabled */ inet_bind_hash(sk, tb, tb2, port); - spin_unlock(&head2->lock); - if (sk_unhashed(sk)) { inet_sk(sk)->inet_sport = htons(port); inet_ehash_nolisten(sk, (struct sock *)tw, NULL); } if (tw) inet_twsk_bind_unhash(tw, hinfo); + + spin_unlock(&head2->lock); spin_unlock(&head->lock); + if (tw) inet_twsk_deschedule_put(tw); local_bh_enable(); diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c index 66fc940f9521..beed32fff484 100644 --- a/net/ipv4/inet_timewait_sock.c +++ b/net/ipv4/inet_timewait_sock.c @@ -29,6 +29,7 @@ void inet_twsk_bind_unhash(struct inet_timewait_sock *tw, struct inet_hashinfo *hashinfo) { + struct inet_bind2_bucket *tb2 = tw->tw_tb2; struct inet_bind_bucket *tb = tw->tw_tb; if (!tb) @@ -37,6 +38,11 @@ void inet_twsk_bind_unhash(struct inet_timewait_sock *tw, __hlist_del(&tw->tw_bind_node); tw->tw_tb = NULL; inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + + __hlist_del(&tw->tw_bind2_node); + tw->tw_tb2 = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + __sock_put((struct sock *)tw); } @@ -45,7 +51,7 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw) { struct inet_hashinfo *hashinfo = tw->tw_dr->hashinfo; spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash); - struct inet_bind_hashbucket *bhead; + struct inet_bind_hashbucket *bhead, *bhead2; spin_lock(lock); sk_nulls_del_node_init_rcu((struct sock *)tw); @@ -54,9 +60,13 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw) /* Disassociate with bind bucket. */ bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), tw->tw_num, hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, (struct sock *)tw, + twsk_net(tw), tw->tw_num); spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); inet_twsk_bind_unhash(tw, hashinfo); + spin_unlock(&bhead2->lock); spin_unlock(&bhead->lock); refcount_dec(&tw->tw_dr->tw_refcount); @@ -81,10 +91,10 @@ void inet_twsk_put(struct inet_timewait_sock *tw) } EXPORT_SYMBOL_GPL(inet_twsk_put); -static void inet_twsk_add_node_rcu(struct inet_timewait_sock *tw, - struct hlist_nulls_head *list) +static void inet_twsk_add_node_tail_rcu(struct inet_timewait_sock *tw, + struct hlist_nulls_head *list) { - hlist_nulls_add_head_rcu(&tw->tw_node, list); + hlist_nulls_add_tail_rcu(&tw->tw_node, list); } static void inet_twsk_add_bind_node(struct inet_timewait_sock *tw, @@ -93,6 +103,12 @@ static void inet_twsk_add_bind_node(struct inet_timewait_sock *tw, hlist_add_head(&tw->tw_bind_node, list); } +static void inet_twsk_add_bind2_node(struct inet_timewait_sock *tw, + struct hlist_head *list) +{ + hlist_add_head(&tw->tw_bind2_node, list); +} + /* * Enter the time wait state. This is called with locally disabled BH. * Essentially we whip up a timewait bucket, copy the relevant info into it @@ -105,22 +121,33 @@ void inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk, const struct inet_connection_sock *icsk = inet_csk(sk); struct inet_ehash_bucket *ehead = inet_ehash_bucket(hashinfo, sk->sk_hash); spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); - struct inet_bind_hashbucket *bhead; + struct inet_bind_hashbucket *bhead, *bhead2; + /* Step 1: Put TW into bind hash. Original socket stays there too. Note, that any socket with inet->num != 0 MUST be bound in binding cache, even if it is closed. */ bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), inet->inet_num, hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, sk, twsk_net(tw), inet->inet_num); + spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); + tw->tw_tb = icsk->icsk_bind_hash; WARN_ON(!icsk->icsk_bind_hash); inet_twsk_add_bind_node(tw, &tw->tw_tb->owners); + + tw->tw_tb2 = icsk->icsk_bind2_hash; + WARN_ON(!icsk->icsk_bind2_hash); + inet_twsk_add_bind2_node(tw, &tw->tw_tb2->deathrow); + + spin_unlock(&bhead2->lock); spin_unlock(&bhead->lock); spin_lock(lock); - inet_twsk_add_node_rcu(tw, &ehead->chain); + inet_twsk_add_node_tail_rcu(tw, &ehead->chain); /* Step 3: Remove SK from hash chain */ if (__sk_nulls_del_node_init_rcu(sk)) diff --git a/net/ipv4/ip_fragment.c b/net/ipv4/ip_fragment.c index fb153569889e..69c00ffdcf3e 100644 --- a/net/ipv4/ip_fragment.c +++ b/net/ipv4/ip_fragment.c @@ -153,6 +153,7 @@ static void ip_expire(struct timer_list *t) if (qp->q.flags & INET_FRAG_COMPLETE) goto out; + qp->q.flags |= INET_FRAG_DROP; ipq_kill(qp); __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); __IP_INC_STATS(net, IPSTATS_MIB_REASMTIMEOUT); @@ -194,7 +195,7 @@ out: spin_unlock(&qp->q.lock); out_rcu_unlock: rcu_read_unlock(); - kfree_skb(head); + kfree_skb_reason(head, SKB_DROP_REASON_FRAG_REASM_TIMEOUT); ipq_put(qp); } @@ -254,7 +255,8 @@ static int ip_frag_reinit(struct ipq *qp) return -ETIMEDOUT; } - sum_truesize = inet_frag_rbtree_purge(&qp->q.rb_fragments); + sum_truesize = inet_frag_rbtree_purge(&qp->q.rb_fragments, + SKB_DROP_REASON_FRAG_TOO_FAR); sub_frag_mem_limit(qp->q.fqdir, sum_truesize); qp->q.flags = 0; @@ -278,10 +280,14 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) struct net_device *dev; unsigned int fragsize; int err = -ENOENT; + SKB_DR(reason); u8 ecn; - if (qp->q.flags & INET_FRAG_COMPLETE) + /* If reassembly is already done, @skb must be a duplicate frag. */ + if (qp->q.flags & INET_FRAG_COMPLETE) { + SKB_DR_SET(reason, DUP_FRAG); goto err; + } if (!(IPCB(skb)->flags & IPSKB_FRAG_COMPLETE) && unlikely(ip_frag_too_far(qp)) && @@ -382,8 +388,9 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) insert_error: if (err == IPFRAG_DUP) { - kfree_skb(skb); - return -EINVAL; + SKB_DR_SET(reason, DUP_FRAG); + err = -EINVAL; + goto err; } err = -EINVAL; __IP_INC_STATS(net, IPSTATS_MIB_REASM_OVERLAPS); @@ -391,7 +398,7 @@ discard_qp: inet_frag_kill(&qp->q); __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); err: - kfree_skb(skb); + kfree_skb_reason(skb, reason); return err; } diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c index f866d6282b2b..ffff46cdcb58 100644 --- a/net/ipv4/ip_gre.c +++ b/net/ipv4/ip_gre.c @@ -510,7 +510,7 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, err_free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); } static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) @@ -592,7 +592,7 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) err_free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); } static int gre_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) @@ -663,7 +663,7 @@ static netdev_tx_t ipgre_xmit(struct sk_buff *skb, free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -717,7 +717,7 @@ static netdev_tx_t erspan_xmit(struct sk_buff *skb, free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -745,7 +745,7 @@ static netdev_tx_t gre_tap_xmit(struct sk_buff *skb, free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -1492,24 +1492,6 @@ static int ipgre_fill_info(struct sk_buff *skb, const struct net_device *dev) struct ip_tunnel_parm *p = &t->parms; __be16 o_flags = p->o_flags; - if (t->erspan_ver <= 2) { - if (t->erspan_ver != 0 && !t->collect_md) - o_flags |= TUNNEL_KEY; - - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) - goto nla_put_failure; - - if (t->erspan_ver == 1) { - if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) - goto nla_put_failure; - } else if (t->erspan_ver == 2) { - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) - goto nla_put_failure; - if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) - goto nla_put_failure; - } - } - if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || nla_put_be16(skb, IFLA_GRE_IFLAGS, gre_tnl_flags_to_gre_flags(p->i_flags)) || @@ -1550,6 +1532,34 @@ nla_put_failure: return -EMSGSIZE; } +static int erspan_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (t->erspan_ver <= 2) { + if (t->erspan_ver != 0 && !t->collect_md) + t->parms.o_flags |= TUNNEL_KEY; + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) + goto nla_put_failure; + + if (t->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) + goto nla_put_failure; + } else if (t->erspan_ver == 2) { + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) + goto nla_put_failure; + } + } + + return ipgre_fill_info(skb, dev); + +nla_put_failure: + return -EMSGSIZE; +} + static void erspan_setup(struct net_device *dev) { struct ip_tunnel *t = netdev_priv(dev); @@ -1628,7 +1638,7 @@ static struct rtnl_link_ops erspan_link_ops __read_mostly = { .changelink = erspan_changelink, .dellink = ip_tunnel_dellink, .get_size = ipgre_get_size, - .fill_info = ipgre_fill_info, + .fill_info = erspan_fill_info, .get_link_net = ip_tunnel_get_link_net, }; @@ -1665,7 +1675,7 @@ struct net_device *gretap_fb_dev_create(struct net *net, const char *name, if (err) goto out; - err = rtnl_configure_link(dev, NULL); + err = rtnl_configure_link(dev, NULL, 0, NULL); if (err < 0) goto out; diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c index 1b512390b3cf..e880ce77322a 100644 --- a/net/ipv4/ip_input.c +++ b/net/ipv4/ip_input.c @@ -366,6 +366,11 @@ static int ip_rcv_finish_core(struct net *net, struct sock *sk, iph->tos, dev); if (unlikely(err)) goto drop_error; + } else { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + if (in_dev && IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; } #ifdef CONFIG_IP_ROUTE_CLASSID diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index 6e19cad154f5..9f92ae35bb01 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -267,7 +267,7 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc, } #endif if (cmsg->cmsg_level == SOL_SOCKET) { - err = __sock_cmsg_send(sk, msg, cmsg, &ipc->sockc); + err = __sock_cmsg_send(sk, cmsg, &ipc->sockc); if (err) return err; continue; @@ -433,6 +433,7 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, } kfree_skb(skb); } +EXPORT_SYMBOL_GPL(ip_icmp_error); void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 port, u32 info) { diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c index 019f3b0839c5..de90b09dfe78 100644 --- a/net/ipv4/ip_tunnel.c +++ b/net/ipv4/ip_tunnel.c @@ -368,23 +368,23 @@ int ip_tunnel_rcv(struct ip_tunnel *tunnel, struct sk_buff *skb, #ifdef CONFIG_NET_IPGRE_BROADCAST if (ipv4_is_multicast(iph->daddr)) { - tunnel->dev->stats.multicast++; + DEV_STATS_INC(tunnel->dev, multicast); skb->pkt_type = PACKET_BROADCAST; } #endif if ((!(tpi->flags&TUNNEL_CSUM) && (tunnel->parms.i_flags&TUNNEL_CSUM)) || ((tpi->flags&TUNNEL_CSUM) && !(tunnel->parms.i_flags&TUNNEL_CSUM))) { - tunnel->dev->stats.rx_crc_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_crc_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } if (tunnel->parms.i_flags&TUNNEL_SEQ) { if (!(tpi->flags&TUNNEL_SEQ) || (tunnel->i_seqno && (s32)(ntohl(tpi->seq) - tunnel->i_seqno) < 0)) { - tunnel->dev->stats.rx_fifo_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_fifo_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } tunnel->i_seqno = ntohl(tpi->seq) + 1; @@ -398,8 +398,8 @@ int ip_tunnel_rcv(struct ip_tunnel *tunnel, struct sk_buff *skb, net_info_ratelimited("non-ECT from %pI4 with TOS=%#x\n", &iph->saddr, iph->tos); if (err > 1) { - ++tunnel->dev->stats.rx_frame_errors; - ++tunnel->dev->stats.rx_errors; + DEV_STATS_INC(tunnel->dev, rx_frame_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } } @@ -581,7 +581,7 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, if (!rt) { rt = ip_route_output_key(tunnel->net, &fl4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error; } if (use_cache) @@ -590,7 +590,7 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, } if (rt->dst.dev == dev) { ip_rt_put(rt); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } @@ -625,10 +625,10 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, df, !net_eq(tunnel->net, dev_net(dev))); return; tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); goto kfree; tx_dropped: - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); kfree: kfree_skb(skb); } @@ -662,7 +662,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, /* NBMA tunnel */ if (!skb_dst(skb)) { - dev->stats.tx_fifo_errors++; + DEV_STATS_INC(dev, tx_fifo_errors); goto tx_error; } @@ -749,7 +749,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, rt = ip_route_output_key(tunnel->net, &fl4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error; } if (use_cache) @@ -762,7 +762,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, if (rt->dst.dev == dev) { ip_rt_put(rt); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } @@ -805,7 +805,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, if (skb_cow_head(skb, dev->needed_headroom)) { ip_rt_put(rt); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return; } @@ -819,7 +819,7 @@ tx_error_icmp: dst_link_failure(skb); #endif tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); } EXPORT_SYMBOL_GPL(ip_tunnel_xmit); diff --git a/net/ipv4/ip_vti.c b/net/ipv4/ip_vti.c index 8c2bd1d9ddce..53bfd8af6920 100644 --- a/net/ipv4/ip_vti.c +++ b/net/ipv4/ip_vti.c @@ -107,8 +107,8 @@ static int vti_rcv_cb(struct sk_buff *skb, int err) dev = tunnel->dev; if (err) { - dev->stats.rx_errors++; - dev->stats.rx_dropped++; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_dropped); return 0; } @@ -183,7 +183,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, fl->u.ip4.flowi4_flags |= FLOWI_FLAG_ANYSRC; rt = __ip_route_output_key(dev_net(dev), &fl->u.ip4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } dst = &rt->dst; @@ -198,14 +198,14 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, if (dst->error) { dst_release(dst); dst = NULL; - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } skb_dst_set(skb, dst); break; #endif default: - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } } @@ -213,7 +213,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, dst_hold(dst); dst = xfrm_lookup_route(tunnel->net, dst, fl, NULL, 0); if (IS_ERR(dst)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } @@ -221,7 +221,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, goto xmit; if (!vti_state_check(dst->xfrm, parms->iph.daddr, parms->iph.saddr)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); dst_release(dst); goto tx_error_icmp; } @@ -230,7 +230,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, if (tdev == dev) { dst_release(dst); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } @@ -267,7 +267,7 @@ xmit: tx_error_icmp: dst_link_failure(skb); tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; } @@ -304,7 +304,7 @@ static netdev_tx_t vti_tunnel_xmit(struct sk_buff *skb, struct net_device *dev) return vti_xmit(skb, dev, &fl); tx_err: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; } diff --git a/net/ipv4/ipip.c b/net/ipv4/ipip.c index 180f9daf5bec..abea77759b7e 100644 --- a/net/ipv4/ipip.c +++ b/net/ipv4/ipip.c @@ -310,7 +310,7 @@ static netdev_tx_t ipip_tunnel_xmit(struct sk_buff *skb, tx_error: kfree_skb(skb); - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); return NETDEV_TX_OK; } diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index e04544ac4b45..eec1f6df80d8 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -412,7 +412,7 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) static void ipmr_free_table(struct mr_table *mrt) { - del_timer_sync(&mrt->ipmr_expire_timer); + timer_shutdown_sync(&mrt->ipmr_expire_timer); mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC | MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC); rhltable_destroy(&mrt->mfc_hash); @@ -506,8 +506,8 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, struct net_device *dev) return err; } - dev->stats.tx_bytes += skb->len; - dev->stats.tx_packets++; + DEV_STATS_ADD(dev, tx_bytes, skb->len); + DEV_STATS_INC(dev, tx_packets); rcu_read_lock(); /* Pairs with WRITE_ONCE() in vif_add() and vif_delete() */ @@ -1839,8 +1839,8 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, if (vif->flags & VIFF_REGISTER) { WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); - vif_dev->stats.tx_bytes += skb->len; - vif_dev->stats.tx_packets++; + DEV_STATS_ADD(vif_dev, tx_bytes, skb->len); + DEV_STATS_INC(vif_dev, tx_packets); ipmr_cache_report(mrt, skb, vifi, IGMPMSG_WHOLEPKT); goto out_free; } @@ -1898,8 +1898,8 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, if (vif->flags & VIFF_TUNNEL) { ip_encap(net, skb, vif->local, vif->remote); /* FIXME: extra output firewall step used to be here. --RR */ - vif_dev->stats.tx_packets++; - vif_dev->stats.tx_bytes += skb->len; + DEV_STATS_INC(vif_dev, tx_packets); + DEV_STATS_ADD(vif_dev, tx_bytes, skb->len); } IPCB(skb)->flags |= IPSKB_FORWARDED; diff --git a/net/ipv4/metrics.c b/net/ipv4/metrics.c index 25ea6ac44db9..0e3ee1532848 100644 --- a/net/ipv4/metrics.c +++ b/net/ipv4/metrics.c @@ -1,5 +1,6 @@ // SPDX-License-Identifier: GPL-2.0-only #include <linux/netlink.h> +#include <linux/nospec.h> #include <linux/rtnetlink.h> #include <linux/types.h> #include <net/ip.h> @@ -14,9 +15,6 @@ static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, struct nlattr *nla; int remaining; - if (!fc_mx) - return 0; - nla_for_each_attr(nla, fc_mx, fc_mx_len, remaining) { int type = nla_type(nla); u32 val; @@ -28,6 +26,7 @@ static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, return -EINVAL; } + type = array_index_nospec(type, RTAX_MAX + 1); if (type == RTAX_CC_ALGO) { char tmp[TCP_CA_NAME_MAX]; diff --git a/net/ipv4/netfilter/ipt_CLUSTERIP.c b/net/ipv4/netfilter/ipt_CLUSTERIP.c index f8e176c77d1c..b3cc416ed292 100644 --- a/net/ipv4/netfilter/ipt_CLUSTERIP.c +++ b/net/ipv4/netfilter/ipt_CLUSTERIP.c @@ -435,7 +435,7 @@ clusterip_tg(struct sk_buff *skb, const struct xt_action_param *par) switch (ctinfo) { case IP_CT_NEW: - ct->mark = hash; + WRITE_ONCE(ct->mark, hash); break; case IP_CT_RELATED: case IP_CT_RELATED_REPLY: @@ -452,7 +452,7 @@ clusterip_tg(struct sk_buff *skb, const struct xt_action_param *par) #ifdef DEBUG nf_ct_dump_tuple_ip(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); #endif - pr_debug("hash=%u ct_hash=%u ", hash, ct->mark); + pr_debug("hash=%u ct_hash=%u ", hash, READ_ONCE(ct->mark)); if (!clusterip_responsible(cipinfo->config, hash)) { pr_debug("not responsible\n"); return NF_DROP; diff --git a/net/ipv4/netfilter/nft_dup_ipv4.c b/net/ipv4/netfilter/nft_dup_ipv4.c index 0bcd6aee6000..a522c3a3be52 100644 --- a/net/ipv4/netfilter/nft_dup_ipv4.c +++ b/net/ipv4/netfilter/nft_dup_ipv4.c @@ -52,7 +52,8 @@ static int nft_dup_ipv4_init(const struct nft_ctx *ctx, return err; } -static int nft_dup_ipv4_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_dup_ipv4_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_dup_ipv4 *priv = nft_expr_priv(expr); diff --git a/net/ipv4/netfilter/nft_fib_ipv4.c b/net/ipv4/netfilter/nft_fib_ipv4.c index fc65d69f23e1..9eee535c64dd 100644 --- a/net/ipv4/netfilter/nft_fib_ipv4.c +++ b/net/ipv4/netfilter/nft_fib_ipv4.c @@ -138,12 +138,11 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, break; } - if (!oif) { - found = FIB_RES_DEV(res); + if (!oif) { + found = FIB_RES_DEV(res); } else { if (!fib_info_nh_uses_dev(res.fi, oif)) return; - found = oif; } diff --git a/net/ipv4/ping.c b/net/ipv4/ping.c index bde333b24837..409ec2a1f95b 100644 --- a/net/ipv4/ping.c +++ b/net/ipv4/ping.c @@ -49,6 +49,11 @@ #include <net/transp_v6.h> #endif +#define ping_portaddr_for_each_entry(__sk, node, list) \ + hlist_nulls_for_each_entry(__sk, node, list, sk_nulls_node) +#define ping_portaddr_for_each_entry_rcu(__sk, node, list) \ + hlist_nulls_for_each_entry_rcu(__sk, node, list, sk_nulls_node) + struct ping_table { struct hlist_nulls_head hash[PING_HTABLE_SIZE]; spinlock_t lock; @@ -138,7 +143,7 @@ next_port: fail: spin_unlock(&ping_table.lock); - return 1; + return -EADDRINUSE; } EXPORT_SYMBOL_GPL(ping_get_port); @@ -192,7 +197,7 @@ static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident) return NULL; } - ping_portaddr_for_each_entry(sk, hnode, hslot) { + ping_portaddr_for_each_entry_rcu(sk, hnode, hslot) { isk = inet_sk(sk); pr_debug("iterate\n"); diff --git a/net/ipv4/proc.c b/net/ipv4/proc.c index 5386f460bd20..f88daace9de3 100644 --- a/net/ipv4/proc.c +++ b/net/ipv4/proc.c @@ -297,6 +297,7 @@ static const struct snmp_mib snmp4_net_list[] = { SNMP_MIB_ITEM("TCPDSACKIgnoredDubious", LINUX_MIB_TCPDSACKIGNOREDDUBIOUS), SNMP_MIB_ITEM("TCPMigrateReqSuccess", LINUX_MIB_TCPMIGRATEREQSUCCESS), SNMP_MIB_ITEM("TCPMigrateReqFailure", LINUX_MIB_TCPMIGRATEREQFAILURE), + SNMP_MIB_ITEM("TCPPLBRehash", LINUX_MIB_TCPPLBREHASH), SNMP_MIB_SENTINEL }; diff --git a/net/ipv4/route.c b/net/ipv4/route.c index cd1fa9f70f1a..de6e3515ab4f 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -471,7 +471,7 @@ static u32 ip_idents_reserve(u32 hash, int segs) old = READ_ONCE(*p_tstamp); if (old != now && cmpxchg(p_tstamp, old, now) == old) - delta = prandom_u32_max(now - old); + delta = get_random_u32_below(now - old); /* If UBSAN reports an error there, please make sure your compiler * supports -fno-strict-overflow before reporting it that was a bug @@ -689,7 +689,7 @@ static void update_or_create_fnhe(struct fib_nh_common *nhc, __be32 daddr, } else { /* Randomize max depth to avoid some side channels attacks. */ int max_depth = FNHE_RECLAIM_DEPTH + - prandom_u32_max(FNHE_RECLAIM_DEPTH); + get_random_u32_below(FNHE_RECLAIM_DEPTH); while (depth > max_depth) { fnhe_remove_oldest(hash); diff --git a/net/ipv4/syncookies.c b/net/ipv4/syncookies.c index 942d2dfa1115..26fb97d1d4d9 100644 --- a/net/ipv4/syncookies.c +++ b/net/ipv4/syncookies.c @@ -288,12 +288,11 @@ struct request_sock *cookie_tcp_reqsk_alloc(const struct request_sock_ops *ops, struct tcp_request_sock *treq; struct request_sock *req; -#ifdef CONFIG_MPTCP if (sk_is_mptcp(sk)) - ops = &mptcp_subflow_request_sock_ops; -#endif + req = mptcp_subflow_reqsk_alloc(ops, sk, false); + else + req = inet_reqsk_alloc(ops, sk, false); - req = inet_reqsk_alloc(ops, sk, false); if (!req) return NULL; diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 9b8a6db7a66b..0d0cc4ef2b85 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -40,6 +40,9 @@ static int one_day_secs = 24 * 3600; static u32 fib_multipath_hash_fields_all_mask __maybe_unused = FIB_MULTIPATH_HASH_FIELD_ALL_MASK; static unsigned int tcp_child_ehash_entries_max = 16 * 1024 * 1024; +static unsigned int udp_child_hash_entries_max = UDP_HTABLE_SIZE_MAX; +static int tcp_plb_max_rounds = 31; +static int tcp_plb_max_cong_thresh = 256; /* obsolete */ static int sysctl_tcp_low_latency __read_mostly; @@ -400,12 +403,36 @@ static int proc_tcp_ehash_entries(struct ctl_table *table, int write, if (!net_eq(net, &init_net) && !hinfo->pernet) tcp_ehash_entries *= -1; + memset(&tbl, 0, sizeof(tbl)); tbl.data = &tcp_ehash_entries; tbl.maxlen = sizeof(int); return proc_dointvec(&tbl, write, buffer, lenp, ppos); } +static int proc_udp_hash_entries(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_udp_child_hash_entries); + int udp_hash_entries; + struct ctl_table tbl; + + udp_hash_entries = net->ipv4.udp_table->mask + 1; + + /* A negative number indicates that the child netns + * shares the global udp_table. + */ + if (!net_eq(net, &init_net) && net->ipv4.udp_table == &udp_table) + udp_hash_entries *= -1; + + memset(&tbl, 0, sizeof(tbl)); + tbl.data = &udp_hash_entries; + tbl.maxlen = sizeof(int); + + return proc_dointvec(&tbl, write, buffer, lenp, ppos); +} + #ifdef CONFIG_IP_ROUTE_MULTIPATH static int proc_fib_multipath_hash_policy(struct ctl_table *table, int write, void *buffer, size_t *lenp, @@ -1360,6 +1387,21 @@ static struct ctl_table ipv4_net_table[] = { .extra2 = &tcp_child_ehash_entries_max, }, { + .procname = "udp_hash_entries", + .data = &init_net.ipv4.sysctl_udp_child_hash_entries, + .mode = 0444, + .proc_handler = proc_udp_hash_entries, + }, + { + .procname = "udp_child_hash_entries", + .data = &init_net.ipv4.sysctl_udp_child_hash_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &udp_child_hash_entries_max, + }, + { .procname = "udp_rmem_min", .data = &init_net.ipv4.sysctl_udp_rmem_min, .maxlen = sizeof(init_net.ipv4.sysctl_udp_rmem_min), @@ -1384,6 +1426,47 @@ static struct ctl_table ipv4_net_table[] = { .extra1 = SYSCTL_ZERO, .extra2 = SYSCTL_TWO, }, + { + .procname = "tcp_plb_enabled", + .data = &init_net.ipv4.sysctl_tcp_plb_enabled, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_plb_idle_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_idle_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_suspend_rto_sec", + .data = &init_net.ipv4.sysctl_tcp_plb_suspend_rto_sec, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_plb_cong_thresh", + .data = &init_net.ipv4.sysctl_tcp_plb_cong_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_plb_max_cong_thresh, + }, { } }; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 54836a6b81d6..33f559f491c8 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -435,6 +435,7 @@ void tcp_init_sock(struct sock *sk) /* There's a bubble in the pipe until at least the first ACK. */ tp->app_limited = ~0U; + tp->rate_app_limited = 1; /* See draft-stevens-tcpca-spec-01 for discussion of the * initialization of these values. @@ -2000,7 +2001,7 @@ static int receive_fallback_to_copy(struct sock *sk, if (copy_address != zc->copybuf_address) return -EINVAL; - err = import_single_range(READ, (void __user *)copy_address, + err = import_single_range(ITER_DEST, (void __user *)copy_address, inq, &iov, &msg.msg_iter); if (err) return err; @@ -2034,7 +2035,7 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, if (copy_address != zc->copybuf_address) return -EINVAL; - err = import_single_range(READ, (void __user *)copy_address, + err = import_single_range(ITER_DEST, (void __user *)copy_address, copylen, &iov, &msg.msg_iter); if (err) return err; @@ -3114,8 +3115,7 @@ int tcp_disconnect(struct sock *sk, int flags) inet->inet_dport = 0; - if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) - inet_reset_saddr(sk); + inet_bhash2_reset_saddr(sk); sk->sk_shutdown = 0; sock_reset_flag(sk, SOCK_DONE); @@ -3176,8 +3176,10 @@ int tcp_disconnect(struct sock *sk, int flags) tp->sacked_out = 0; tp->tlp_high_seq = 0; tp->last_oow_ack_time = 0; + tp->plb_rehash = 0; /* There's a bubble in the pipe until at least the first ACK. */ tp->app_limited = ~0U; + tp->rate_app_limited = 1; tp->rack.mstamp = 0; tp->rack.advanced = 0; tp->rack.reo_wnd_steps = 1; @@ -3939,6 +3941,8 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_reord_seen = tp->reord_seen; info->tcpi_rcv_ooopack = tp->rcv_ooopack; info->tcpi_snd_wnd = tp->snd_wnd; + info->tcpi_rcv_wnd = tp->rcv_wnd; + info->tcpi_rehash = tp->plb_rehash + tp->timeout_rehash; info->tcpi_fastopen_client_fail = tp->fastopen_client_fail; unlock_sock_fast(sk, slow); } @@ -3973,6 +3977,7 @@ static size_t tcp_opt_stats_get_size(void) nla_total_size(sizeof(u32)) + /* TCP_NLA_BYTES_NOTSENT */ nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_EDT */ nla_total_size(sizeof(u8)) + /* TCP_NLA_TTL */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_REHASH */ 0; } @@ -4049,6 +4054,7 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk, nla_put_u8(stats, TCP_NLA_TTL, tcp_skb_ttl_or_hop_limit(ack_skb)); + nla_put_u32(stats, TCP_NLA_REHASH, tp->plb_rehash + tp->timeout_rehash); return stats; } @@ -4460,11 +4466,8 @@ bool tcp_alloc_md5sig_pool(void) if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) { mutex_lock(&tcp_md5sig_mutex); - if (!tcp_md5sig_pool_populated) { + if (!tcp_md5sig_pool_populated) __tcp_alloc_md5sig_pool(); - if (tcp_md5sig_pool_populated) - static_branch_inc(&tcp_md5_needed); - } mutex_unlock(&tcp_md5sig_mutex); } diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index 54eec33c6e1c..d2c470524e58 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -618,7 +618,7 @@ static void bbr_reset_probe_bw_mode(struct sock *sk) struct bbr *bbr = inet_csk_ca(sk); bbr->mode = BBR_PROBE_BW; - bbr->cycle_idx = CYCLE_LEN - 1 - prandom_u32_max(bbr_cycle_rand); + bbr->cycle_idx = CYCLE_LEN - 1 - get_random_u32_below(bbr_cycle_rand); bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ } diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index cf9c3e8f7ccb..cf26d65ca389 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -6,6 +6,7 @@ #include <linux/bpf.h> #include <linux/init.h> #include <linux/wait.h> +#include <linux/util_macros.h> #include <net/inet_common.h> #include <net/tls.h> @@ -45,8 +46,11 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, tmp->sg.end = i; if (apply) { apply_bytes -= size; - if (!apply_bytes) + if (!apply_bytes) { + if (sge->length) + sk_msg_iter_var_prev(i); break; + } } } while (i != msg->sg.end); @@ -131,10 +135,9 @@ static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg, return ret; } -int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, - u32 bytes, int flags) +int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress, + struct sk_msg *msg, u32 bytes, int flags) { - bool ingress = sk_msg_to_ingress(msg); struct sk_psock *psock = sk_psock_get(sk); int ret; @@ -276,10 +279,10 @@ msg_bytes_ready: static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, struct sk_msg *msg, int *copied, int flags) { - bool cork = false, enospc = sk_msg_full(msg); + bool cork = false, enospc = sk_msg_full(msg), redir_ingress; struct sock *sk_redir; u32 tosend, origsize, sent, delta = 0; - u32 eval = __SK_NONE; + u32 eval; int ret; more_data: @@ -310,6 +313,7 @@ more_data: tosend = msg->sg.size; if (psock->apply_bytes && psock->apply_bytes < tosend) tosend = psock->apply_bytes; + eval = __SK_NONE; switch (psock->eval) { case __SK_PASS: @@ -321,6 +325,7 @@ more_data: sk_msg_apply_bytes(psock, tosend); break; case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; sk_redir = psock->sk_redir; sk_msg_apply_bytes(psock, tosend); if (!psock->apply_bytes) { @@ -337,7 +342,8 @@ more_data: release_sock(sk); origsize = msg->sg.size; - ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags); + ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + msg, tosend, flags); sent = origsize - msg->sg.size; if (eval == __SK_REDIRECT) @@ -634,10 +640,9 @@ EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); */ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk) { - int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; struct proto *prot = newsk->sk_prot; - if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE]) + if (is_insidevar(prot, tcp_bpf_prots)) newsk->sk_prot = sk->sk_prot_creator; } #endif /* CONFIG_BPF_SYSCALL */ diff --git a/net/ipv4/tcp_dctcp.c b/net/ipv4/tcp_dctcp.c index 2a6c0dd665a4..e0a2ca7456ff 100644 --- a/net/ipv4/tcp_dctcp.c +++ b/net/ipv4/tcp_dctcp.c @@ -54,6 +54,7 @@ struct dctcp { u32 next_seq; u32 ce_state; u32 loss_cwnd; + struct tcp_plb_state plb; }; static unsigned int dctcp_shift_g __read_mostly = 4; /* g = 1/2^4 */ @@ -91,6 +92,8 @@ static void dctcp_init(struct sock *sk) ca->ce_state = 0; dctcp_reset(tp, ca); + tcp_plb_init(sk, &ca->plb); + return; } @@ -117,14 +120,28 @@ static void dctcp_update_alpha(struct sock *sk, u32 flags) /* Expired RTT */ if (!before(tp->snd_una, ca->next_seq)) { + u32 delivered = tp->delivered - ca->old_delivered; u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce; u32 alpha = ca->dctcp_alpha; + u32 ce_ratio = 0; + + if (delivered > 0) { + /* dctcp_alpha keeps EWMA of fraction of ECN marked + * packets. Because of EWMA smoothing, PLB reaction can + * be slow so we use ce_ratio which is an instantaneous + * measure of congestion. ce_ratio is the fraction of + * ECN marked packets in the previous RTT. + */ + if (delivered_ce > 0) + ce_ratio = (delivered_ce << TCP_PLB_SCALE) / delivered; + tcp_plb_update_state(sk, &ca->plb, (int)ce_ratio); + tcp_plb_check_rehash(sk, &ca->plb); + } /* alpha = (1 - g) * alpha + g * F */ alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g); if (delivered_ce) { - u32 delivered = tp->delivered - ca->old_delivered; /* If dctcp_shift_g == 1, a 32bit value would overflow * after 8 M packets. @@ -172,8 +189,12 @@ static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state); break; case CA_EVENT_LOSS: + tcp_plb_update_state_upon_rto(sk, &ca->plb); dctcp_react_to_loss(sk); break; + case CA_EVENT_TX_START: + tcp_plb_check_rehash(sk, &ca->plb); /* Maybe rehash when inflight is 0 */ + break; default: /* Don't care for the rest. */ break; diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 0640453fce54..cc072d2cfcd8 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -3646,7 +3646,8 @@ static void tcp_send_challenge_ack(struct sock *sk) u32 half = (ack_limit + 1) >> 1; WRITE_ONCE(net->ipv4.tcp_challenge_timestamp, now); - WRITE_ONCE(net->ipv4.tcp_challenge_count, half + prandom_u32_max(ack_limit)); + WRITE_ONCE(net->ipv4.tcp_challenge_count, + get_random_u32_inclusive(half, ack_limit + half - 1)); } count = READ_ONCE(net->ipv4.tcp_challenge_count); if (count > 0) { @@ -4764,8 +4765,8 @@ static void tcp_ofo_queue(struct sock *sk) } } -static bool tcp_prune_ofo_queue(struct sock *sk); -static int tcp_prune_queue(struct sock *sk); +static bool tcp_prune_ofo_queue(struct sock *sk, const struct sk_buff *in_skb); +static int tcp_prune_queue(struct sock *sk, const struct sk_buff *in_skb); static int tcp_try_rmem_schedule(struct sock *sk, struct sk_buff *skb, unsigned int size) @@ -4773,11 +4774,11 @@ static int tcp_try_rmem_schedule(struct sock *sk, struct sk_buff *skb, if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf || !sk_rmem_schedule(sk, skb, size)) { - if (tcp_prune_queue(sk) < 0) + if (tcp_prune_queue(sk, skb) < 0) return -1; while (!sk_rmem_schedule(sk, skb, size)) { - if (!tcp_prune_ofo_queue(sk)) + if (!tcp_prune_ofo_queue(sk, skb)) return -1; } } @@ -5329,6 +5330,8 @@ new_range: * Clean the out-of-order queue to make room. * We drop high sequences packets to : * 1) Let a chance for holes to be filled. + * This means we do not drop packets from ooo queue if their sequence + * is before incoming packet sequence. * 2) not add too big latencies if thousands of packets sit there. * (But if application shrinks SO_RCVBUF, we could still end up * freeing whole queue here) @@ -5336,24 +5339,31 @@ new_range: * * Return true if queue has shrunk. */ -static bool tcp_prune_ofo_queue(struct sock *sk) +static bool tcp_prune_ofo_queue(struct sock *sk, const struct sk_buff *in_skb) { struct tcp_sock *tp = tcp_sk(sk); struct rb_node *node, *prev; + bool pruned = false; int goal; if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) return false; - NET_INC_STATS(sock_net(sk), LINUX_MIB_OFOPRUNED); goal = sk->sk_rcvbuf >> 3; node = &tp->ooo_last_skb->rbnode; + do { + struct sk_buff *skb = rb_to_skb(node); + + /* If incoming skb would land last in ofo queue, stop pruning. */ + if (after(TCP_SKB_CB(in_skb)->seq, TCP_SKB_CB(skb)->seq)) + break; + pruned = true; prev = rb_prev(node); rb_erase(node, &tp->out_of_order_queue); - goal -= rb_to_skb(node)->truesize; - tcp_drop_reason(sk, rb_to_skb(node), - SKB_DROP_REASON_TCP_OFO_QUEUE_PRUNE); + goal -= skb->truesize; + tcp_drop_reason(sk, skb, SKB_DROP_REASON_TCP_OFO_QUEUE_PRUNE); + tp->ooo_last_skb = rb_to_skb(prev); if (!prev || goal <= 0) { if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf && !tcp_under_memory_pressure(sk)) @@ -5362,16 +5372,18 @@ static bool tcp_prune_ofo_queue(struct sock *sk) } node = prev; } while (node); - tp->ooo_last_skb = rb_to_skb(prev); - /* Reset SACK state. A conforming SACK implementation will - * do the same at a timeout based retransmit. When a connection - * is in a sad state like this, we care only about integrity - * of the connection not performance. - */ - if (tp->rx_opt.sack_ok) - tcp_sack_reset(&tp->rx_opt); - return true; + if (pruned) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_OFOPRUNED); + /* Reset SACK state. A conforming SACK implementation will + * do the same at a timeout based retransmit. When a connection + * is in a sad state like this, we care only about integrity + * of the connection not performance. + */ + if (tp->rx_opt.sack_ok) + tcp_sack_reset(&tp->rx_opt); + } + return pruned; } /* Reduce allocated memory if we can, trying to get @@ -5381,7 +5393,7 @@ static bool tcp_prune_ofo_queue(struct sock *sk) * until the socket owning process reads some of the data * to stabilize the situation. */ -static int tcp_prune_queue(struct sock *sk) +static int tcp_prune_queue(struct sock *sk, const struct sk_buff *in_skb) { struct tcp_sock *tp = tcp_sk(sk); @@ -5408,7 +5420,7 @@ static int tcp_prune_queue(struct sock *sk) /* Collapsing did not help, destructive actions follow. * This must not ever occur. */ - tcp_prune_ofo_queue(sk); + tcp_prune_ofo_queue(sk, in_skb); if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) return 0; @@ -6830,10 +6842,18 @@ static bool tcp_syn_flood_action(const struct sock *sk, const char *proto) #endif __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPREQQFULLDROP); - if (!queue->synflood_warned && syncookies != 2 && - xchg(&queue->synflood_warned, 1) == 0) - net_info_ratelimited("%s: Possible SYN flooding on port %d. %s. Check SNMP counters.\n", - proto, sk->sk_num, msg); + if (!READ_ONCE(queue->synflood_warned) && syncookies != 2 && + xchg(&queue->synflood_warned, 1) == 0) { + if (IS_ENABLED(CONFIG_IPV6) && sk->sk_family == AF_INET6) { + net_info_ratelimited("%s: Possible SYN flooding on port [%pI6c]:%u. %s.\n", + proto, inet6_rcv_saddr(sk), + sk->sk_num, msg); + } else { + net_info_ratelimited("%s: Possible SYN flooding on port %pI4:%u. %s.\n", + proto, &sk->sk_rcv_saddr, + sk->sk_num, msg); + } + } return want_cookie; } diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 87d440f47a70..8320d0ecb13a 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -199,15 +199,14 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, /* This will initiate an outgoing connection. */ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { - struct inet_bind_hashbucket *prev_addr_hashbucket = NULL; struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; struct inet_timewait_death_row *tcp_death_row; - __be32 daddr, nexthop, prev_sk_rcv_saddr; struct inet_sock *inet = inet_sk(sk); struct tcp_sock *tp = tcp_sk(sk); struct ip_options_rcu *inet_opt; struct net *net = sock_net(sk); __be16 orig_sport, orig_dport; + __be32 daddr, nexthop; struct flowi4 *fl4; struct rtable *rt; int err; @@ -251,24 +250,13 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; if (!inet->inet_saddr) { - if (inet_csk(sk)->icsk_bind2_hash) { - prev_addr_hashbucket = inet_bhashfn_portaddr(tcp_death_row->hashinfo, - sk, net, inet->inet_num); - prev_sk_rcv_saddr = sk->sk_rcv_saddr; - } - inet->inet_saddr = fl4->saddr; - } - - sk_rcv_saddr_set(sk, inet->inet_saddr); - - if (prev_addr_hashbucket) { - err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); + err = inet_bhash2_update_saddr(sk, &fl4->saddr, AF_INET); if (err) { - inet->inet_saddr = 0; - sk_rcv_saddr_set(sk, prev_sk_rcv_saddr); ip_rt_put(rt); return err; } + } else { + sk_rcv_saddr_set(sk, inet->inet_saddr); } if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) { @@ -343,6 +331,7 @@ failure: * if necessary. */ tcp_set_state(sk, TCP_CLOSE); + inet_bhash2_reset_saddr(sk); ip_rt_put(rt); sk->sk_route_caps = 0; inet->inet_dport = 0; @@ -1064,7 +1053,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req) * We need to maintain these in the sk structure. */ -DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); +DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ); EXPORT_SYMBOL(tcp_md5_needed); static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new) @@ -1172,10 +1161,25 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk, } EXPORT_SYMBOL(tcp_v4_md5_lookup); +static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_info *md5sig; + + md5sig = kmalloc(sizeof(*md5sig), gfp); + if (!md5sig) + return -ENOMEM; + + sk_gso_disable(sk); + INIT_HLIST_HEAD(&md5sig->head); + rcu_assign_pointer(tp->md5sig_info, md5sig); + return 0; +} + /* This can be called on a newly created socket, from other files */ -int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, - int family, u8 prefixlen, int l3index, u8 flags, - const u8 *newkey, u8 newkeylen, gfp_t gfp) +static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, u8 flags, + const u8 *newkey, u8 newkeylen, gfp_t gfp) { /* Add Key to the list */ struct tcp_md5sig_key *key; @@ -1204,15 +1208,6 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); - if (!md5sig) { - md5sig = kmalloc(sizeof(*md5sig), gfp); - if (!md5sig) - return -ENOMEM; - - sk_gso_disable(sk); - INIT_HLIST_HEAD(&md5sig->head); - rcu_assign_pointer(tp->md5sig_info, md5sig); - } key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO); if (!key) @@ -1234,8 +1229,59 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, hlist_add_head_rcu(&key->node, &md5sig->head); return 0; } + +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, u8 flags, + const u8 *newkey, u8 newkeylen) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { + if (tcp_md5sig_info_add(sk, GFP_KERNEL)) + return -ENOMEM; + + if (!static_branch_inc(&tcp_md5_needed.key)) { + struct tcp_md5sig_info *md5sig; + + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); + rcu_assign_pointer(tp->md5sig_info, NULL); + kfree_rcu(md5sig, rcu); + return -EUSERS; + } + } + + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags, + newkey, newkeylen, GFP_KERNEL); +} EXPORT_SYMBOL(tcp_md5_do_add); +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, + struct tcp_md5sig_key *key) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { + if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC))) + return -ENOMEM; + + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) { + struct tcp_md5sig_info *md5sig; + + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); + net_warn_ratelimited("Too many TCP-MD5 keys in the system\n"); + rcu_assign_pointer(tp->md5sig_info, NULL); + kfree_rcu(md5sig, rcu); + return -EUSERS; + } + } + + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, + key->flags, key->key, key->keylen, + sk_gfp_mask(sk, GFP_ATOMIC)); +} +EXPORT_SYMBOL(tcp_md5_key_copy); + int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, u8 prefixlen, int l3index, u8 flags) { @@ -1322,7 +1368,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, return -EINVAL; return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags, - cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); + cmd.tcpm_key, cmd.tcpm_keylen); } static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, @@ -1573,14 +1619,8 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, addr = (union tcp_md5_addr *)&newinet->inet_daddr; key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); if (key) { - /* - * We're using one, so create a matching key - * on the newsk structure. If we fail to get - * memory, then we end up not copying the key - * across. Shucks. - */ - tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags, - key->key, key->keylen, GFP_ATOMIC); + if (tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key)) + goto put_and_exit; sk_gso_disable(newsk); } #endif @@ -2272,6 +2312,7 @@ void tcp_v4_destroy_sock(struct sock *sk) tcp_clear_md5_list(sk); kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu); tp->md5sig_info = NULL; + static_branch_slow_dec_deferred(&tcp_md5_needed); } #endif @@ -2480,7 +2521,6 @@ static void *tcp_seek_last_pos(struct seq_file *seq) case TCP_SEQ_STATE_LISTENING: if (st->bucket > hinfo->lhash2_mask) break; - st->state = TCP_SEQ_STATE_LISTENING; rc = listening_get_first(seq); while (offset-- && rc && bucket == st->bucket) rc = listening_get_next(seq, rc); @@ -3218,6 +3258,14 @@ static int __net_init tcp_sk_init(struct net *net) net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 0; atomic_set(&net->ipv4.tfo_active_disable_times, 0); + /* Set default values for PLB */ + net->ipv4.sysctl_tcp_plb_enabled = 0; /* Disabled by default */ + net->ipv4.sysctl_tcp_plb_idle_rehash_rounds = 3; + net->ipv4.sysctl_tcp_plb_rehash_rounds = 12; + net->ipv4.sysctl_tcp_plb_suspend_rto_sec = 60; + /* Default congestion threshold for PLB to mark a round is 50% */ + net->ipv4.sysctl_tcp_plb_cong_thresh = (1 << TCP_PLB_SCALE) / 2; + /* Reno is always built in */ if (!net_eq(net, &init_net) && bpf_try_module_get(init_net.ipv4.tcp_congestion_control, diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index c375f603a16c..e002f2e1d4f2 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -240,6 +240,40 @@ kill: } EXPORT_SYMBOL(tcp_timewait_state_process); +static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw) +{ +#ifdef CONFIG_TCP_MD5SIG + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + + /* + * The timewait bucket does not have the key DB from the + * sock structure. We just make a quick copy of the + * md5 key being used (if indeed we are using one) + * so the timewait ack generating code has the key. + */ + tcptw->tw_md5_key = NULL; + if (!static_branch_unlikely(&tcp_md5_needed.key)) + return; + + key = tp->af_specific->md5_lookup(sk, sk); + if (key) { + tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); + if (!tcptw->tw_md5_key) + return; + if (!tcp_alloc_md5sig_pool()) + goto out_free; + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) + goto out_free; + } + return; +out_free: + WARN_ON_ONCE(1); + kfree(tcptw->tw_md5_key); + tcptw->tw_md5_key = NULL; +#endif +} + /* * Move a socket to time-wait or dead fin-wait-2 state. */ @@ -282,26 +316,7 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) } #endif -#ifdef CONFIG_TCP_MD5SIG - /* - * The timewait bucket does not have the key DB from the - * sock structure. We just make a quick copy of the - * md5 key being used (if indeed we are using one) - * so the timewait ack generating code has the key. - */ - do { - tcptw->tw_md5_key = NULL; - if (static_branch_unlikely(&tcp_md5_needed)) { - struct tcp_md5sig_key *key; - - key = tp->af_specific->md5_lookup(sk, sk); - if (key) { - tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); - BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); - } - } - } while (0); -#endif + tcp_time_wait_init(sk, tcptw); /* Get the TIME_WAIT timeout firing. */ if (timeo < rto) @@ -337,11 +352,13 @@ EXPORT_SYMBOL(tcp_time_wait); void tcp_twsk_destructor(struct sock *sk) { #ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed)) { + if (static_branch_unlikely(&tcp_md5_needed.key)) { struct tcp_timewait_sock *twsk = tcp_twsk(sk); - if (twsk->tw_md5_key) + if (twsk->tw_md5_key) { kfree_rcu(twsk->tw_md5_key, rcu); + static_branch_slow_dec_deferred(&tcp_md5_needed); + } } #endif } diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index c69f4d966024..71d01cf3c13e 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -766,7 +766,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, *md5 = NULL; #ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed) && + if (static_branch_unlikely(&tcp_md5_needed.key) && rcu_access_pointer(tp->md5sig_info)) { *md5 = tp->af_specific->md5_lookup(sk, sk); if (*md5) { @@ -922,7 +922,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb *md5 = NULL; #ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed) && + if (static_branch_unlikely(&tcp_md5_needed.key) && rcu_access_pointer(tp->md5sig_info)) { *md5 = tp->af_specific->md5_lookup(sk, sk); if (*md5) { @@ -1077,15 +1077,15 @@ static void tcp_tasklet_func(struct tasklet_struct *t) */ void tcp_release_cb(struct sock *sk) { - unsigned long flags, nflags; + unsigned long flags = smp_load_acquire(&sk->sk_tsq_flags); + unsigned long nflags; /* perform an atomic operation only if at least one flag is set */ do { - flags = sk->sk_tsq_flags; if (!(flags & TCP_DEFERRED_ALL)) return; nflags = flags & ~TCP_DEFERRED_ALL; - } while (cmpxchg(&sk->sk_tsq_flags, flags, nflags) != flags); + } while (!try_cmpxchg(&sk->sk_tsq_flags, &flags, nflags)); if (flags & TCPF_TSQ_DEFERRED) { tcp_tsq_write(sk); @@ -1139,6 +1139,8 @@ void tcp_wfree(struct sk_buff *skb) struct sock *sk = skb->sk; struct tcp_sock *tp = tcp_sk(sk); unsigned long flags, nval, oval; + struct tsq_tasklet *tsq; + bool empty; /* Keep one reference on sk_wmem_alloc. * Will be released by sk_free() from here or tcp_tasklet_func() @@ -1155,28 +1157,23 @@ void tcp_wfree(struct sk_buff *skb) if (refcount_read(&sk->sk_wmem_alloc) >= SKB_TRUESIZE(1) && this_cpu_ksoftirqd() == current) goto out; - for (oval = READ_ONCE(sk->sk_tsq_flags);; oval = nval) { - struct tsq_tasklet *tsq; - bool empty; - + oval = smp_load_acquire(&sk->sk_tsq_flags); + do { if (!(oval & TSQF_THROTTLED) || (oval & TSQF_QUEUED)) goto out; nval = (oval & ~TSQF_THROTTLED) | TSQF_QUEUED; - nval = cmpxchg(&sk->sk_tsq_flags, oval, nval); - if (nval != oval) - continue; + } while (!try_cmpxchg(&sk->sk_tsq_flags, &oval, nval)); - /* queue this socket to tasklet queue */ - local_irq_save(flags); - tsq = this_cpu_ptr(&tsq_tasklet); - empty = list_empty(&tsq->head); - list_add(&tp->tsq_node, &tsq->head); - if (empty) - tasklet_schedule(&tsq->tasklet); - local_irq_restore(flags); - return; - } + /* queue this socket to tasklet queue */ + local_irq_save(flags); + tsq = this_cpu_ptr(&tsq_tasklet); + empty = list_empty(&tsq->head); + list_add(&tp->tsq_node, &tsq->head); + if (empty) + tasklet_schedule(&tsq->tasklet); + local_irq_restore(flags); + return; out: sk_free(sk); } diff --git a/net/ipv4/tcp_plb.c b/net/ipv4/tcp_plb.c new file mode 100644 index 000000000000..4bcf7eff95e3 --- /dev/null +++ b/net/ipv4/tcp_plb.c @@ -0,0 +1,109 @@ +/* Protective Load Balancing (PLB) + * + * PLB was designed to reduce link load imbalance across datacenter + * switches. PLB is a host-based optimization; it leverages congestion + * signals from the transport layer to randomly change the path of the + * connection experiencing sustained congestion. PLB prefers to repath + * after idle periods to minimize packet reordering. It repaths by + * changing the IPv6 Flow Label on the packets of a connection, which + * datacenter switches include as part of ECMP/WCMP hashing. + * + * PLB is described in detail in: + * + * Mubashir Adnan Qureshi, Yuchung Cheng, Qianwen Yin, Qiaobin Fu, + * Gautam Kumar, Masoud Moshref, Junhua Yan, Van Jacobson, + * David Wetherall,Abdul Kabbani: + * "PLB: Congestion Signals are Simple and Effective for + * Network Load Balancing" + * In ACM SIGCOMM 2022, Amsterdam Netherlands. + * + */ + +#include <net/tcp.h> + +/* Called once per round-trip to update PLB state for a connection. */ +void tcp_plb_update_state(const struct sock *sk, struct tcp_plb_state *plb, + const int cong_ratio) +{ + struct net *net = sock_net(sk); + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + if (cong_ratio >= 0) { + if (cong_ratio < READ_ONCE(net->ipv4.sysctl_tcp_plb_cong_thresh)) + plb->consec_cong_rounds = 0; + else if (plb->consec_cong_rounds < + READ_ONCE(net->ipv4.sysctl_tcp_plb_rehash_rounds)) + plb->consec_cong_rounds++; + } +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state); + +/* Check whether recent congestion has been persistent enough to warrant + * a load balancing decision that switches the connection to another path. + */ +void tcp_plb_check_rehash(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + u32 max_suspend; + bool forced_rehash = false, idle_rehash = false; + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + forced_rehash = plb->consec_cong_rounds >= + READ_ONCE(net->ipv4.sysctl_tcp_plb_rehash_rounds); + /* If sender goes idle then we check whether to rehash. */ + idle_rehash = READ_ONCE(net->ipv4.sysctl_tcp_plb_idle_rehash_rounds) && + !tcp_sk(sk)->packets_out && + plb->consec_cong_rounds >= + READ_ONCE(net->ipv4.sysctl_tcp_plb_idle_rehash_rounds); + + if (!forced_rehash && !idle_rehash) + return; + + /* Note that tcp_jiffies32 can wrap; we detect wraps by checking for + * cases where the max suspension end is before the actual suspension + * end. We clear pause_until to 0 to indicate there is no recent + * RTO event that constrains PLB rehashing. + */ + max_suspend = 2 * READ_ONCE(net->ipv4.sysctl_tcp_plb_suspend_rto_sec) * HZ; + if (plb->pause_until && + (!before(tcp_jiffies32, plb->pause_until) || + before(tcp_jiffies32 + max_suspend, plb->pause_until))) + plb->pause_until = 0; + + if (plb->pause_until) + return; + + sk_rethink_txhash(sk); + plb->consec_cong_rounds = 0; + tcp_sk(sk)->plb_rehash++; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPPLBREHASH); +} +EXPORT_SYMBOL_GPL(tcp_plb_check_rehash); + +/* Upon RTO, disallow load balancing for a while, to avoid having load + * balancing decisions switch traffic to a black-holed path that was + * previously avoided with a sk_rethink_txhash() call at RTO time. + */ +void tcp_plb_update_state_upon_rto(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + u32 pause; + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + pause = READ_ONCE(net->ipv4.sysctl_tcp_plb_suspend_rto_sec) * HZ; + pause += get_random_u32_below(pause); + plb->pause_until = tcp_jiffies32 + pause; + + /* Reset PLB state upon RTO, since an RTO causes a sk_rethink_txhash() call + * that may switch this connection to a path with completely different + * congestion characteristics. + */ + plb->consec_cong_rounds = 0; +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state_upon_rto); diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index 9ae50b1bd844..2aa442128630 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -139,6 +139,10 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) if (sk->sk_socket) clear_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + err = -ENOTCONN; + if (!ulp_ops->clone && sk->sk_state == TCP_LISTEN) + goto out_err; + err = ulp_ops->init(sk); if (err) goto out_err; diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 6a320a614e54..9592fe3e444a 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -129,7 +129,12 @@ DEFINE_PER_CPU(int, udp_memory_per_cpu_fw_alloc); EXPORT_PER_CPU_SYMBOL_GPL(udp_memory_per_cpu_fw_alloc); #define MAX_UDP_PORTS 65536 -#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN) +#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN_PERNET) + +static struct udp_table *udp_get_table_prot(struct sock *sk) +{ + return sk->sk_prot->h.udp_table ? : sock_net(sk)->ipv4.udp_table; +} static int udp_lib_lport_inuse(struct net *net, __u16 num, const struct udp_hslot *hslot, @@ -232,16 +237,16 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) int udp_lib_get_port(struct sock *sk, unsigned short snum, unsigned int hash2_nulladdr) { + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2; - struct udp_table *udptable = sk->sk_prot->h.udp_table; - int error = 1; struct net *net = sock_net(sk); + int error = -EADDRINUSE; if (!snum) { + DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); + unsigned short first, last; int low, high, remaining; unsigned int rand; - unsigned short first, last; - DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); inet_get_local_port_range(net, &low, &high); remaining = (high - low) + 1; @@ -467,7 +472,7 @@ static struct sock *udp4_lookup_run_bpf(struct net *net, struct sock *sk, *reuse_sk; bool no_reuseport; - if (udptable != &udp_table) + if (udptable != net->ipv4.udp_table) return NULL; /* only UDP is supported */ no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP, saddr, sport, @@ -548,10 +553,11 @@ struct sock *udp4_lib_lookup_skb(const struct sk_buff *skb, __be16 sport, __be16 dport) { const struct iphdr *iph = ip_hdr(skb); + struct net *net = dev_net(skb->dev); - return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, + return __udp4_lib_lookup(net, iph->saddr, sport, iph->daddr, dport, inet_iif(skb), - inet_sdif(skb), &udp_table, NULL); + inet_sdif(skb), net->ipv4.udp_table, NULL); } /* Must be called under rcu_read_lock(). @@ -564,7 +570,7 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, struct sock *sk; sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport, - dif, 0, &udp_table, NULL); + dif, 0, net->ipv4.udp_table, NULL); if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) sk = NULL; return sk; @@ -784,7 +790,8 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) if (tunnel) { /* ...not for tunnels though: we don't have a sending socket */ if (udp_sk(sk)->encap_err_rcv) - udp_sk(sk)->encap_err_rcv(sk, skb, iph->ihl << 2); + udp_sk(sk)->encap_err_rcv(sk, skb, err, uh->dest, info, + (u8 *)(uh+1)); goto out; } if (!inet->recverr) { @@ -801,7 +808,7 @@ out: int udp_err(struct sk_buff *skb, u32 info) { - return __udp4_lib_err(skb, info, &udp_table); + return __udp4_lib_err(skb, info, dev_net(skb->dev)->ipv4.udp_table); } /* @@ -1448,7 +1455,7 @@ static void udp_rmem_release(struct sock *sk, int size, int partial, if (likely(partial)) { up->forward_deficit += size; size = up->forward_deficit; - if (size < (sk->sk_rcvbuf >> 2) && + if (size < READ_ONCE(up->forward_threshold) && !skb_queue_empty(&up->reader_queue)) return; } else { @@ -1622,7 +1629,7 @@ static void udp_destruct_sock(struct sock *sk) int udp_init_sock(struct sock *sk) { - skb_queue_head_init(&udp_sk(sk)->reader_queue); + udp_lib_init_sock(sk); sk->sk_destruct = udp_destruct_sock; set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); return 0; @@ -1998,7 +2005,7 @@ EXPORT_SYMBOL(udp_disconnect); void udp_lib_unhash(struct sock *sk) { if (sk_hashed(sk)) { - struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2; hslot = udp_hashslot(udptable, sock_net(sk), @@ -2029,7 +2036,7 @@ EXPORT_SYMBOL(udp_lib_unhash); void udp_lib_rehash(struct sock *sk, u16 newhash) { if (sk_hashed(sk)) { - struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2, *nhslot2; hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); @@ -2518,10 +2525,14 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, __be16 rmt_port, __be32 rmt_addr, int dif, int sdif) { - struct sock *sk, *result; + struct udp_table *udptable = net->ipv4.udp_table; unsigned short hnum = ntohs(loc_port); - unsigned int slot = udp_hashfn(net, hnum, udp_table.mask); - struct udp_hslot *hslot = &udp_table.hash[slot]; + struct sock *sk, *result; + struct udp_hslot *hslot; + unsigned int slot; + + slot = udp_hashfn(net, hnum, udptable->mask); + hslot = &udptable->hash[slot]; /* Do not bother scanning a too big list */ if (hslot->count > 10) @@ -2549,14 +2560,19 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, __be16 rmt_port, __be32 rmt_addr, int dif, int sdif) { - unsigned short hnum = ntohs(loc_port); - unsigned int hash2 = ipv4_portaddr_hash(net, loc_addr, hnum); - unsigned int slot2 = hash2 & udp_table.mask; - struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; + struct udp_table *udptable = net->ipv4.udp_table; INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr); - const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); + unsigned short hnum = ntohs(loc_port); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + __portpair ports; struct sock *sk; + hash2 = ipv4_portaddr_hash(net, loc_addr, hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + ports = INET_COMBINED_PORTS(rmt_port, hnum); + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { if (inet_match(net, sk, acookie, ports, dif, sdif)) return sk; @@ -2636,7 +2652,7 @@ int udp_v4_early_demux(struct sk_buff *skb) int udp_rcv(struct sk_buff *skb) { - return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP); + return __udp4_lib_rcv(skb, dev_net(skb->dev)->ipv4.udp_table, IPPROTO_UDP); } void udp_destroy_sock(struct sock *sk) @@ -2672,6 +2688,18 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, int err = 0; int is_udplite = IS_UDPLITE(sk); + if (level == SOL_SOCKET) { + err = sk_setsockopt(sk, level, optname, optval, optlen); + + if (optname == SO_RCVBUF || optname == SO_RCVBUFFORCE) { + sockopt_lock_sock(sk); + /* paired with READ_ONCE in udp_rmem_release() */ + WRITE_ONCE(up->forward_threshold, sk->sk_rcvbuf >> 2); + sockopt_release_sock(sk); + } + return err; + } + if (optlen < sizeof(int)) return -EINVAL; @@ -2785,7 +2813,7 @@ EXPORT_SYMBOL(udp_lib_setsockopt); int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, unsigned int optlen) { - if (level == SOL_UDP || level == SOL_UDPLITE) + if (level == SOL_UDP || level == SOL_UDPLITE || level == SOL_SOCKET) return udp_lib_setsockopt(sk, level, optname, optval, optlen, udp_push_pending_frames); @@ -2947,7 +2975,7 @@ struct proto udp_prot = { .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), .obj_size = sizeof(struct udp_sock), - .h.udp_table = &udp_table, + .h.udp_table = NULL, .diag_destroy = udp_abort, }; EXPORT_SYMBOL(udp_prot); @@ -2955,21 +2983,30 @@ EXPORT_SYMBOL(udp_prot); /* ------------------------------------------------------------------------ */ #ifdef CONFIG_PROC_FS +static struct udp_table *udp_get_table_afinfo(struct udp_seq_afinfo *afinfo, + struct net *net) +{ + return afinfo->udp_table ? : net->ipv4.udp_table; +} + static struct sock *udp_get_first(struct seq_file *seq, int start) { - struct sock *sk; - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; struct net *net = seq_file_net(seq); + struct udp_seq_afinfo *afinfo; + struct udp_table *udptable; + struct sock *sk; if (state->bpf_seq_afinfo) afinfo = state->bpf_seq_afinfo; else afinfo = pde_data(file_inode(seq->file)); - for (state->bucket = start; state->bucket <= afinfo->udp_table->mask; + udptable = udp_get_table_afinfo(afinfo, net); + + for (state->bucket = start; state->bucket <= udptable->mask; ++state->bucket) { - struct udp_hslot *hslot = &afinfo->udp_table->hash[state->bucket]; + struct udp_hslot *hslot = &udptable->hash[state->bucket]; if (hlist_empty(&hslot->head)) continue; @@ -2991,9 +3028,10 @@ found: static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) { - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; struct net *net = seq_file_net(seq); + struct udp_seq_afinfo *afinfo; + struct udp_table *udptable; if (state->bpf_seq_afinfo) afinfo = state->bpf_seq_afinfo; @@ -3007,8 +3045,11 @@ static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) sk->sk_family != afinfo->family))); if (!sk) { - if (state->bucket <= afinfo->udp_table->mask) - spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); + udptable = udp_get_table_afinfo(afinfo, net); + + if (state->bucket <= udptable->mask) + spin_unlock_bh(&udptable->hash[state->bucket].lock); + return udp_get_first(seq, state->bucket + 1); } return sk; @@ -3049,16 +3090,19 @@ EXPORT_SYMBOL(udp_seq_next); void udp_seq_stop(struct seq_file *seq, void *v) { - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; + struct udp_seq_afinfo *afinfo; + struct udp_table *udptable; if (state->bpf_seq_afinfo) afinfo = state->bpf_seq_afinfo; else afinfo = pde_data(file_inode(seq->file)); - if (state->bucket <= afinfo->udp_table->mask) - spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); + udptable = udp_get_table_afinfo(afinfo, seq_file_net(seq)); + + if (state->bucket <= udptable->mask) + spin_unlock_bh(&udptable->hash[state->bucket].lock); } EXPORT_SYMBOL(udp_seq_stop); @@ -3171,7 +3215,7 @@ EXPORT_SYMBOL(udp_seq_ops); static struct udp_seq_afinfo udp4_seq_afinfo = { .family = AF_INET, - .udp_table = &udp_table, + .udp_table = NULL, }; static int __net_init udp4_proc_init_net(struct net *net) @@ -3233,7 +3277,7 @@ void __init udp_table_init(struct udp_table *table, const char *name) &table->log, &table->mask, UDP_HTABLE_SIZE_MIN, - 64 * 1024); + UDP_HTABLE_SIZE_MAX); table->hash2 = table->hash + (table->mask + 1); for (i = 0; i <= table->mask; i++) { @@ -3258,7 +3302,7 @@ u32 udp_flow_hashrnd(void) } EXPORT_SYMBOL(udp_flow_hashrnd); -static int __net_init udp_sysctl_init(struct net *net) +static void __net_init udp_sysctl_init(struct net *net) { net->ipv4.sysctl_udp_rmem_min = PAGE_SIZE; net->ipv4.sysctl_udp_wmem_min = PAGE_SIZE; @@ -3266,12 +3310,103 @@ static int __net_init udp_sysctl_init(struct net *net) #ifdef CONFIG_NET_L3_MASTER_DEV net->ipv4.sysctl_udp_l3mdev_accept = 0; #endif +} + +static struct udp_table __net_init *udp_pernet_table_alloc(unsigned int hash_entries) +{ + struct udp_table *udptable; + int i; + + udptable = kmalloc(sizeof(*udptable), GFP_KERNEL); + if (!udptable) + goto out; + + udptable->hash = vmalloc_huge(hash_entries * 2 * sizeof(struct udp_hslot), + GFP_KERNEL_ACCOUNT); + if (!udptable->hash) + goto free_table; + + udptable->hash2 = udptable->hash + hash_entries; + udptable->mask = hash_entries - 1; + udptable->log = ilog2(hash_entries); + + for (i = 0; i < hash_entries; i++) { + INIT_HLIST_HEAD(&udptable->hash[i].head); + udptable->hash[i].count = 0; + spin_lock_init(&udptable->hash[i].lock); + + INIT_HLIST_HEAD(&udptable->hash2[i].head); + udptable->hash2[i].count = 0; + spin_lock_init(&udptable->hash2[i].lock); + } + + return udptable; + +free_table: + kfree(udptable); +out: + return NULL; +} + +static void __net_exit udp_pernet_table_free(struct net *net) +{ + struct udp_table *udptable = net->ipv4.udp_table; + + if (udptable == &udp_table) + return; + + kvfree(udptable->hash); + kfree(udptable); +} + +static void __net_init udp_set_table(struct net *net) +{ + struct udp_table *udptable; + unsigned int hash_entries; + struct net *old_net; + + if (net_eq(net, &init_net)) + goto fallback; + + old_net = current->nsproxy->net_ns; + hash_entries = READ_ONCE(old_net->ipv4.sysctl_udp_child_hash_entries); + if (!hash_entries) + goto fallback; + + /* Set min to keep the bitmap on stack in udp_lib_get_port() */ + if (hash_entries < UDP_HTABLE_SIZE_MIN_PERNET) + hash_entries = UDP_HTABLE_SIZE_MIN_PERNET; + else + hash_entries = roundup_pow_of_two(hash_entries); + + udptable = udp_pernet_table_alloc(hash_entries); + if (udptable) { + net->ipv4.udp_table = udptable; + } else { + pr_warn("Failed to allocate UDP hash table (entries: %u) " + "for a netns, fallback to the global one\n", + hash_entries); +fallback: + net->ipv4.udp_table = &udp_table; + } +} + +static int __net_init udp_pernet_init(struct net *net) +{ + udp_sysctl_init(net); + udp_set_table(net); return 0; } +static void __net_exit udp_pernet_exit(struct net *net) +{ + udp_pernet_table_free(net); +} + static struct pernet_operations __net_initdata udp_sysctl_ops = { - .init = udp_sysctl_init, + .init = udp_pernet_init, + .exit = udp_pernet_exit, }; #if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) @@ -3289,7 +3424,7 @@ static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux) return -ENOMEM; afinfo->family = AF_UNSPEC; - afinfo->udp_table = &udp_table; + afinfo->udp_table = NULL; st->bpf_seq_afinfo = afinfo; ret = bpf_iter_init_seq_net(priv_data, aux); if (ret) diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c index 1ed8c4d78e5c..de3f2d31f510 100644 --- a/net/ipv4/udp_diag.c +++ b/net/ipv4/udp_diag.c @@ -147,13 +147,13 @@ done: static void udp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { - udp_dump(&udp_table, skb, cb, r); + udp_dump(sock_net(cb->skb->sk)->ipv4.udp_table, skb, cb, r); } static int udp_diag_dump_one(struct netlink_callback *cb, const struct inet_diag_req_v2 *req) { - return udp_dump_one(&udp_table, cb, req); + return udp_dump_one(sock_net(cb->skb->sk)->ipv4.udp_table, cb, req); } static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, @@ -225,7 +225,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, static int udp_diag_destroy(struct sk_buff *in_skb, const struct inet_diag_req_v2 *req) { - return __udp_diag_destroy(in_skb, req, &udp_table); + return __udp_diag_destroy(in_skb, req, sock_net(in_skb->sk)->ipv4.udp_table); } static int udplite_diag_destroy(struct sk_buff *in_skb, diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c index 6d1a4bec2614..1f01e15ca24f 100644 --- a/net/ipv4/udp_offload.c +++ b/net/ipv4/udp_offload.c @@ -387,7 +387,8 @@ static struct sk_buff *udp4_ufo_fragment(struct sk_buff *skb, if (!pskb_may_pull(skb, sizeof(struct udphdr))) goto out; - if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) + if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4 && + !skb_gso_ok(skb, features | NETIF_F_GSO_ROBUST)) return __udp_gso_segment(skb, features, false); mss = skb_shinfo(skb)->gso_size; @@ -600,10 +601,11 @@ static struct sock *udp4_gro_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport) { const struct iphdr *iph = skb_gro_network_header(skb); + struct net *net = dev_net(skb->dev); - return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, + return __udp4_lib_lookup(net, iph->saddr, sport, iph->daddr, dport, inet_iif(skb), - inet_sdif(skb), &udp_table, NULL); + inet_sdif(skb), net->ipv4.udp_table, NULL); } INDIRECT_CALLABLE_SCOPE diff --git a/net/ipv4/udp_tunnel_core.c b/net/ipv4/udp_tunnel_core.c index 8242c8947340..5f8104cf082d 100644 --- a/net/ipv4/udp_tunnel_core.c +++ b/net/ipv4/udp_tunnel_core.c @@ -176,6 +176,7 @@ EXPORT_SYMBOL_GPL(udp_tunnel_xmit_skb); void udp_tunnel_sock_release(struct socket *sock) { rcu_assign_sk_user_data(sock->sk, NULL); + synchronize_rcu(); kernel_sock_shutdown(sock, SHUT_RDWR); sock_release(sock); } diff --git a/net/ipv4/udp_tunnel_nic.c b/net/ipv4/udp_tunnel_nic.c index bc3a043a5d5c..029219749785 100644 --- a/net/ipv4/udp_tunnel_nic.c +++ b/net/ipv4/udp_tunnel_nic.c @@ -624,6 +624,8 @@ __udp_tunnel_nic_dump_write(struct net_device *dev, unsigned int table, continue; nest = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY); + if (!nest) + return -EMSGSIZE; if (nla_put_be16(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT, utn->entries[table][j].port) || diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c index 9c3f5202a97b..faa47f9ea73a 100644 --- a/net/ipv6/addrconf.c +++ b/net/ipv6/addrconf.c @@ -104,7 +104,7 @@ static inline u32 cstamp_delta(unsigned long cstamp) static inline s32 rfc3315_s14_backoff_init(s32 irt) { /* multiply 'initial retransmission time' by 0.9 .. 1.1 */ - u64 tmp = (900000 + prandom_u32_max(200001)) * (u64)irt; + u64 tmp = get_random_u32_inclusive(900000, 1100000) * (u64)irt; do_div(tmp, 1000000); return (s32)tmp; } @@ -112,11 +112,11 @@ static inline s32 rfc3315_s14_backoff_init(s32 irt) static inline s32 rfc3315_s14_backoff_update(s32 rt, s32 mrt) { /* multiply 'retransmission timeout' by 1.9 .. 2.1 */ - u64 tmp = (1900000 + prandom_u32_max(200001)) * (u64)rt; + u64 tmp = get_random_u32_inclusive(1900000, 2100000) * (u64)rt; do_div(tmp, 1000000); if ((s32)tmp > mrt) { /* multiply 'maximum retransmission time' by 0.9 .. 1.1 */ - tmp = (900000 + prandom_u32_max(200001)) * (u64)mrt; + tmp = get_random_u32_inclusive(900000, 1100000) * (u64)mrt; do_div(tmp, 1000000); } return (s32)tmp; @@ -3127,17 +3127,17 @@ static void add_v4_addrs(struct inet6_dev *idev) offset = sizeof(struct in6_addr) - 4; memcpy(&addr.s6_addr32[3], idev->dev->dev_addr + offset, 4); - if (idev->dev->flags&IFF_POINTOPOINT) { + if (!(idev->dev->flags & IFF_POINTOPOINT) && idev->dev->type == ARPHRD_SIT) { + scope = IPV6_ADDR_COMPATv4; + plen = 96; + pflags |= RTF_NONEXTHOP; + } else { if (idev->cnf.addr_gen_mode == IN6_ADDR_GEN_MODE_NONE) return; addr.s6_addr32[0] = htonl(0xfe800000); scope = IFA_LINK; plen = 64; - } else { - scope = IPV6_ADDR_COMPATv4; - plen = 96; - pflags |= RTF_NONEXTHOP; } if (addr.s6_addr32[3]) { @@ -3320,7 +3320,7 @@ static void addrconf_addr_gen(struct inet6_dev *idev, bool prefix_route) return; /* no link local addresses on devices flagged as slaves */ - if (idev->dev->flags & IFF_SLAVE) + if (idev->dev->priv_flags & IFF_NO_ADDRCONF) return; ipv6_addr_set(&addr, htonl(0xFE800000), 0, 0, 0); @@ -3447,6 +3447,30 @@ static void addrconf_gre_config(struct net_device *dev) } #endif +static void addrconf_init_auto_addrs(struct net_device *dev) +{ + switch (dev->type) { +#if IS_ENABLED(CONFIG_IPV6_SIT) + case ARPHRD_SIT: + addrconf_sit_config(dev); + break; +#endif +#if IS_ENABLED(CONFIG_NET_IPGRE) || IS_ENABLED(CONFIG_IPV6_GRE) + case ARPHRD_IP6GRE: + case ARPHRD_IPGRE: + addrconf_gre_config(dev); + break; +#endif + case ARPHRD_LOOPBACK: + init_loopback(dev); + break; + + default: + addrconf_dev_config(dev); + break; + } +} + static int fixup_permanent_addr(struct net *net, struct inet6_dev *idev, struct inet6_ifaddr *ifp) @@ -3560,7 +3584,7 @@ static int addrconf_notify(struct notifier_block *this, unsigned long event, if (idev && idev->cnf.disable_ipv6) break; - if (dev->flags & IFF_SLAVE) { + if (dev->priv_flags & IFF_NO_ADDRCONF) { if (event == NETDEV_UP && !IS_ERR_OR_NULL(idev) && dev->flags & IFF_UP && dev->flags & IFF_MULTICAST) ipv6_mc_up(idev); @@ -3615,26 +3639,7 @@ static int addrconf_notify(struct notifier_block *this, unsigned long event, run_pending = 1; } - switch (dev->type) { -#if IS_ENABLED(CONFIG_IPV6_SIT) - case ARPHRD_SIT: - addrconf_sit_config(dev); - break; -#endif -#if IS_ENABLED(CONFIG_NET_IPGRE) || IS_ENABLED(CONFIG_IPV6_GRE) - case ARPHRD_IP6GRE: - case ARPHRD_IPGRE: - addrconf_gre_config(dev); - break; -#endif - case ARPHRD_LOOPBACK: - init_loopback(dev); - break; - - default: - addrconf_dev_config(dev); - break; - } + addrconf_init_auto_addrs(dev); if (!IS_ERR_OR_NULL(idev)) { if (run_pending) @@ -3967,7 +3972,7 @@ static void addrconf_dad_kick(struct inet6_ifaddr *ifp) if (ifp->flags & IFA_F_OPTIMISTIC) rand_num = 0; else - rand_num = prandom_u32_max(idev->cnf.rtr_solicit_delay ?: 1); + rand_num = get_random_u32_below(idev->cnf.rtr_solicit_delay ? : 1); nonce = 0; if (idev->cnf.enhanced_dad || @@ -6397,7 +6402,7 @@ static int addrconf_sysctl_addr_gen_mode(struct ctl_table *ctl, int write, if (idev->cnf.addr_gen_mode != new_val) { idev->cnf.addr_gen_mode = new_val; - addrconf_dev_config(idev->dev); + addrconf_init_auto_addrs(idev->dev); } } else if (&net->ipv6.devconf_all->addr_gen_mode == ctl->data) { struct net_device *dev; @@ -6408,7 +6413,7 @@ static int addrconf_sysctl_addr_gen_mode(struct ctl_table *ctl, int write, if (idev && idev->cnf.addr_gen_mode != new_val) { idev->cnf.addr_gen_mode = new_val; - addrconf_dev_config(idev->dev); + addrconf_init_auto_addrs(idev->dev); } } } diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index 024191004982..fee9163382c2 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -114,6 +114,7 @@ void inet6_sock_destruct(struct sock *sk) inet6_cleanup_sock(sk); inet_sock_destruct(sk); } +EXPORT_SYMBOL_GPL(inet6_sock_destruct); static int inet6_create(struct net *net, struct socket *sock, int protocol, int kern) @@ -409,10 +410,10 @@ static int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, /* Make sure we are allowed to bind here. */ if (snum || !(inet->bind_address_no_port || (flags & BIND_FORCE_ADDRESS_NO_PORT))) { - if (sk->sk_prot->get_port(sk, snum)) { + err = sk->sk_prot->get_port(sk, snum); + if (err) { sk->sk_ipv6only = saved_ipv6only; inet_reset_saddr(sk); - err = -EADDRINUSE; goto out; } if (!(flags & BIND_FROM_BPF)) { @@ -489,7 +490,7 @@ int inet6_release(struct socket *sock) } EXPORT_SYMBOL(inet6_release); -void inet6_destroy_sock(struct sock *sk) +void inet6_cleanup_sock(struct sock *sk) { struct ipv6_pinfo *np = inet6_sk(sk); struct sk_buff *skb; @@ -514,12 +515,6 @@ void inet6_destroy_sock(struct sock *sk) txopt_put(opt); } } -EXPORT_SYMBOL_GPL(inet6_destroy_sock); - -void inet6_cleanup_sock(struct sock *sk) -{ - inet6_destroy_sock(sk); -} EXPORT_SYMBOL_GPL(inet6_cleanup_sock); /* diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c index 5ecb56522f9d..e624497fa992 100644 --- a/net/ipv6/datagram.c +++ b/net/ipv6/datagram.c @@ -42,24 +42,29 @@ static void ip6_datagram_flow_key_init(struct flowi6 *fl6, struct sock *sk) { struct inet_sock *inet = inet_sk(sk); struct ipv6_pinfo *np = inet6_sk(sk); + int oif = sk->sk_bound_dev_if; memset(fl6, 0, sizeof(*fl6)); fl6->flowi6_proto = sk->sk_protocol; fl6->daddr = sk->sk_v6_daddr; fl6->saddr = np->saddr; - fl6->flowi6_oif = sk->sk_bound_dev_if; fl6->flowi6_mark = sk->sk_mark; fl6->fl6_dport = inet->inet_dport; fl6->fl6_sport = inet->inet_sport; fl6->flowlabel = np->flow_label; fl6->flowi6_uid = sk->sk_uid; - if (!fl6->flowi6_oif) - fl6->flowi6_oif = np->sticky_pktinfo.ipi6_ifindex; + if (!oif) + oif = np->sticky_pktinfo.ipi6_ifindex; - if (!fl6->flowi6_oif && ipv6_addr_is_multicast(&fl6->daddr)) - fl6->flowi6_oif = np->mcast_oif; + if (!oif) { + if (ipv6_addr_is_multicast(&fl6->daddr)) + oif = np->mcast_oif; + else + oif = np->ucast_oif; + } + fl6->flowi6_oif = oif; security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6)); } @@ -334,6 +339,7 @@ void ipv6_icmp_error(struct sock *sk, struct sk_buff *skb, int err, if (sock_queue_err_skb(sk, skb)) kfree_skb(skb); } +EXPORT_SYMBOL_GPL(ipv6_icmp_error); void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info) { @@ -771,7 +777,7 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk, } if (cmsg->cmsg_level == SOL_SOCKET) { - err = __sock_cmsg_send(sk, msg, cmsg, &ipc6->sockc); + err = __sock_cmsg_send(sk, cmsg, &ipc6->sockc); if (err) return err; continue; diff --git a/net/ipv6/esp6_offload.c b/net/ipv6/esp6_offload.c index 79d43548279c..75c02992c520 100644 --- a/net/ipv6/esp6_offload.c +++ b/net/ipv6/esp6_offload.c @@ -56,12 +56,11 @@ static struct sk_buff *esp6_gro_receive(struct list_head *head, __be32 seq; __be32 spi; int nhoff; - int err; if (!pskb_pull(skb, offset)) return NULL; - if ((err = xfrm_parse_spi(skb, IPPROTO_ESP, &spi, &seq)) != 0) + if (xfrm_parse_spi(skb, IPPROTO_ESP, &spi, &seq) != 0) goto out; xo = xfrm_offload(skb); @@ -346,6 +345,9 @@ static int esp6_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features xo->seq.low += skb_shinfo(skb)->gso_segs; } + if (xo->seq.low < seq) + xo->seq.hi++; + esp.seqno = cpu_to_be64(xo->seq.low + ((u64)xo->seq.hi << 32)); len = skb->len - sizeof(struct ipv6hdr); diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c index 413f66781e50..2438da5ff6da 100644 --- a/net/ipv6/ip6_fib.c +++ b/net/ipv6/ip6_fib.c @@ -91,13 +91,12 @@ static void fib6_walker_unlink(struct net *net, struct fib6_walker *w) static int fib6_new_sernum(struct net *net) { - int new, old; + int new, old = atomic_read(&net->ipv6.fib6_sernum); do { - old = atomic_read(&net->ipv6.fib6_sernum); new = old < INT_MAX ? old + 1 : 1; - } while (atomic_cmpxchg(&net->ipv6.fib6_sernum, - old, new) != old); + } while (!atomic_try_cmpxchg(&net->ipv6.fib6_sernum, &old, new)); + return new; } diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c index c035a96fba3a..89f5f0f3f5d6 100644 --- a/net/ipv6/ip6_gre.c +++ b/net/ipv6/ip6_gre.c @@ -870,26 +870,6 @@ static inline int ip6gre_xmit_ipv6(struct sk_buff *skb, struct net_device *dev) return 0; } -/** - * ip6gre_tnl_addr_conflict - compare packet addresses to tunnel's own - * @t: the outgoing tunnel device - * @hdr: IPv6 header from the incoming packet - * - * Description: - * Avoid trivial tunneling loop by checking that tunnel exit-point - * doesn't match source of incoming packet. - * - * Return: - * 1 if conflict, - * 0 else - **/ - -static inline bool ip6gre_tnl_addr_conflict(const struct ip6_tnl *t, - const struct ipv6hdr *hdr) -{ - return ipv6_addr_equal(&t->parms.raddr, &hdr->saddr); -} - static int ip6gre_xmit_other(struct sk_buff *skb, struct net_device *dev) { struct ip6_tnl *t = netdev_priv(dev); @@ -915,7 +895,6 @@ static netdev_tx_t ip6gre_tunnel_xmit(struct sk_buff *skb, struct net_device *dev) { struct ip6_tnl *t = netdev_priv(dev); - struct net_device_stats *stats = &t->dev->stats; __be16 payload_protocol; int ret; @@ -945,8 +924,8 @@ static netdev_tx_t ip6gre_tunnel_xmit(struct sk_buff *skb, tx_err: if (!t->parms.collect_md || !IS_ERR(skb_tunnel_info_txcheck(skb))) - stats->tx_errors++; - stats->tx_dropped++; + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return NETDEV_TX_OK; } @@ -957,7 +936,6 @@ static netdev_tx_t ip6erspan_tunnel_xmit(struct sk_buff *skb, struct ip_tunnel_info *tun_info = NULL; struct ip6_tnl *t = netdev_priv(dev); struct dst_entry *dst = skb_dst(skb); - struct net_device_stats *stats; bool truncate = false; int encap_limit = -1; __u8 dsfield = false; @@ -1106,10 +1084,9 @@ static netdev_tx_t ip6erspan_tunnel_xmit(struct sk_buff *skb, return NETDEV_TX_OK; tx_err: - stats = &t->dev->stats; if (!IS_ERR(tun_info)) - stats->tx_errors++; - stats->tx_dropped++; + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return NETDEV_TX_OK; } diff --git a/net/ipv6/ip6_offload.c b/net/ipv6/ip6_offload.c index 3ee345672849..00dc2e3b0184 100644 --- a/net/ipv6/ip6_offload.c +++ b/net/ipv6/ip6_offload.c @@ -77,7 +77,7 @@ static struct sk_buff *ipv6_gso_segment(struct sk_buff *skb, struct sk_buff *segs = ERR_PTR(-EINVAL); struct ipv6hdr *ipv6h; const struct net_offload *ops; - int proto, nexthdr; + int proto, err; struct frag_hdr *fptr; unsigned int payload_len; u8 *prevhdr; @@ -87,28 +87,9 @@ static struct sk_buff *ipv6_gso_segment(struct sk_buff *skb, bool gso_partial; skb_reset_network_header(skb); - nexthdr = ipv6_has_hopopt_jumbo(skb); - if (nexthdr) { - const int hophdr_len = sizeof(struct hop_jumbo_hdr); - int err; - - err = skb_cow_head(skb, 0); - if (err < 0) - return ERR_PTR(err); - - /* remove the HBH header. - * Layout: [Ethernet header][IPv6 header][HBH][TCP header] - */ - memmove(skb_mac_header(skb) + hophdr_len, - skb_mac_header(skb), - ETH_HLEN + sizeof(struct ipv6hdr)); - skb->data += hophdr_len; - skb->len -= hophdr_len; - skb->network_header += hophdr_len; - skb->mac_header += hophdr_len; - ipv6h = (struct ipv6hdr *)skb->data; - ipv6h->nexthdr = nexthdr; - } + err = ipv6_hopopt_jumbo_remove(skb); + if (err) + return ERR_PTR(err); nhoff = skb_network_header(skb) - skb_mac_header(skb); if (unlikely(!pskb_may_pull(skb, sizeof(*ipv6h)))) goto out; diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c index e19507614f64..c314fdde0097 100644 --- a/net/ipv6/ip6_output.c +++ b/net/ipv6/ip6_output.c @@ -547,7 +547,20 @@ int ip6_forward(struct sk_buff *skb) pneigh_lookup(&nd_tbl, net, &hdr->daddr, skb->dev, 0)) { int proxied = ip6_forward_proxy_check(skb); if (proxied > 0) { - hdr->hop_limit--; + /* It's tempting to decrease the hop limit + * here by 1, as we do at the end of the + * function too. + * + * But that would be incorrect, as proxying is + * not forwarding. The ip6_input function + * will handle this packet locally, and it + * depends on the hop limit being unchanged. + * + * One example is the NDP hop limit, that + * always has to stay 255, but other would be + * similar checks around RA packets, where the + * user can even change the desired limit. + */ return ip6_input(skb); } else if (proxied < 0) { __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); @@ -920,6 +933,9 @@ int ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, if (err < 0) goto fail; + /* We prevent @rt from being freed. */ + rcu_read_lock(); + for (;;) { /* Prepare header of the next frame, * before previous one went down. */ @@ -943,6 +959,7 @@ int ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, if (err == 0) { IP6_INC_STATS(net, ip6_dst_idev(&rt->dst), IPSTATS_MIB_FRAGOKS); + rcu_read_unlock(); return 0; } @@ -950,6 +967,7 @@ int ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, IP6_INC_STATS(net, ip6_dst_idev(&rt->dst), IPSTATS_MIB_FRAGFAILS); + rcu_read_unlock(); return err; slow_path_clean: diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c index 2fb4c6ad7243..47b6607a1370 100644 --- a/net/ipv6/ip6_tunnel.c +++ b/net/ipv6/ip6_tunnel.c @@ -803,8 +803,8 @@ static int __ip6_tnl_rcv(struct ip6_tnl *tunnel, struct sk_buff *skb, (tunnel->parms.i_flags & TUNNEL_CSUM)) || ((tpi->flags & TUNNEL_CSUM) && !(tunnel->parms.i_flags & TUNNEL_CSUM))) { - tunnel->dev->stats.rx_crc_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_crc_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } @@ -812,8 +812,8 @@ static int __ip6_tnl_rcv(struct ip6_tnl *tunnel, struct sk_buff *skb, if (!(tpi->flags & TUNNEL_SEQ) || (tunnel->i_seqno && (s32)(ntohl(tpi->seq) - tunnel->i_seqno) < 0)) { - tunnel->dev->stats.rx_fifo_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_fifo_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } tunnel->i_seqno = ntohl(tpi->seq) + 1; @@ -824,8 +824,8 @@ static int __ip6_tnl_rcv(struct ip6_tnl *tunnel, struct sk_buff *skb, /* Warning: All skb pointers will be invalidated! */ if (tunnel->dev->type == ARPHRD_ETHER) { if (!pskb_may_pull(skb, ETH_HLEN)) { - tunnel->dev->stats.rx_length_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_length_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } @@ -849,8 +849,8 @@ static int __ip6_tnl_rcv(struct ip6_tnl *tunnel, struct sk_buff *skb, &ipv6h->saddr, ipv6_get_dsfield(ipv6h)); if (err > 1) { - ++tunnel->dev->stats.rx_frame_errors; - ++tunnel->dev->stats.rx_errors; + DEV_STATS_INC(tunnel->dev, rx_frame_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } } @@ -1071,7 +1071,6 @@ int ip6_tnl_xmit(struct sk_buff *skb, struct net_device *dev, __u8 dsfield, { struct ip6_tnl *t = netdev_priv(dev); struct net *net = t->net; - struct net_device_stats *stats = &t->dev->stats; struct ipv6hdr *ipv6h; struct ipv6_tel_txoption opt; struct dst_entry *dst = NULL, *ndst = NULL; @@ -1166,7 +1165,7 @@ route_lookup: tdev = dst->dev; if (tdev == dev) { - stats->collisions++; + DEV_STATS_INC(dev, collisions); net_warn_ratelimited("%s: Local routing loop detected!\n", t->parms.name); goto tx_err_dst_release; @@ -1265,7 +1264,7 @@ route_lookup: ip6tunnel_xmit(NULL, skb, dev); return 0; tx_err_link_failure: - stats->tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); dst_link_failure(skb); tx_err_dst_release: dst_release(dst); @@ -1408,7 +1407,6 @@ static netdev_tx_t ip6_tnl_start_xmit(struct sk_buff *skb, struct net_device *dev) { struct ip6_tnl *t = netdev_priv(dev); - struct net_device_stats *stats = &t->dev->stats; u8 ipproto; int ret; @@ -1438,8 +1436,8 @@ ip6_tnl_start_xmit(struct sk_buff *skb, struct net_device *dev) return NETDEV_TX_OK; tx_err: - stats->tx_errors++; - stats->tx_dropped++; + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return NETDEV_TX_OK; } diff --git a/net/ipv6/ip6_vti.c b/net/ipv6/ip6_vti.c index 151337d7f67b..10b222865d46 100644 --- a/net/ipv6/ip6_vti.c +++ b/net/ipv6/ip6_vti.c @@ -317,7 +317,7 @@ static int vti6_input_proto(struct sk_buff *skb, int nexthdr, __be32 spi, ipv6h = ipv6_hdr(skb); if (!ip6_tnl_rcv_ctl(t, &ipv6h->daddr, &ipv6h->saddr)) { - t->dev->stats.rx_dropped++; + DEV_STATS_INC(t->dev, rx_dropped); rcu_read_unlock(); goto discard; } @@ -359,8 +359,8 @@ static int vti6_rcv_cb(struct sk_buff *skb, int err) dev = t->dev; if (err) { - dev->stats.rx_errors++; - dev->stats.rx_dropped++; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_dropped); return 0; } @@ -446,7 +446,6 @@ static int vti6_xmit(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) { struct ip6_tnl *t = netdev_priv(dev); - struct net_device_stats *stats = &t->dev->stats; struct dst_entry *dst = skb_dst(skb); struct net_device *tdev; struct xfrm_state *x; @@ -506,7 +505,7 @@ vti6_xmit(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) tdev = dst->dev; if (tdev == dev) { - stats->collisions++; + DEV_STATS_INC(dev, collisions); net_warn_ratelimited("%s: Local routing loop detected!\n", t->parms.name); goto tx_err_dst_release; @@ -544,7 +543,7 @@ xmit: return 0; tx_err_link_failure: - stats->tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); dst_link_failure(skb); tx_err_dst_release: dst_release(dst); @@ -555,7 +554,6 @@ static netdev_tx_t vti6_tnl_xmit(struct sk_buff *skb, struct net_device *dev) { struct ip6_tnl *t = netdev_priv(dev); - struct net_device_stats *stats = &t->dev->stats; struct flowi fl; int ret; @@ -591,8 +589,8 @@ vti6_tnl_xmit(struct sk_buff *skb, struct net_device *dev) return NETDEV_TX_OK; tx_err: - stats->tx_errors++; - stats->tx_dropped++; + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return NETDEV_TX_OK; } diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index facdc78a43e5..51cf37abd142 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -392,7 +392,7 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id) static void ip6mr_free_table(struct mr_table *mrt) { - del_timer_sync(&mrt->ipmr_expire_timer); + timer_shutdown_sync(&mrt->ipmr_expire_timer); mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC | MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC); rhltable_destroy(&mrt->mfc_hash); @@ -608,8 +608,8 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0) goto tx_err; - dev->stats.tx_bytes += skb->len; - dev->stats.tx_packets++; + DEV_STATS_ADD(dev, tx_bytes, skb->len); + DEV_STATS_INC(dev, tx_packets); rcu_read_lock(); ip6mr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num), MRT6MSG_WHOLEPKT); @@ -618,7 +618,7 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, return NETDEV_TX_OK; tx_err: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; } @@ -2044,8 +2044,8 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt, if (vif->flags & MIFF_REGISTER) { WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); - vif_dev->stats.tx_bytes += skb->len; - vif_dev->stats.tx_packets++; + DEV_STATS_ADD(vif_dev, tx_bytes, skb->len); + DEV_STATS_INC(vif_dev, tx_packets); ip6mr_cache_report(mrt, skb, vifi, MRT6MSG_WHOLEPKT); goto out_free; } diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c index 532f4478c884..9ce51680290b 100644 --- a/net/ipv6/ipv6_sockglue.c +++ b/net/ipv6/ipv6_sockglue.c @@ -1005,10 +1005,8 @@ unlock: return retv; e_inval: - sockopt_release_sock(sk); - if (needs_rtnl) - rtnl_unlock(); - return -EINVAL; + retv = -EINVAL; + goto unlock; } int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c index 7860383295d8..1c02160cf7a4 100644 --- a/net/ipv6/mcast.c +++ b/net/ipv6/mcast.c @@ -1050,7 +1050,7 @@ bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group, /* called with mc_lock */ static void mld_gq_start_work(struct inet6_dev *idev) { - unsigned long tv = prandom_u32_max(idev->mc_maxdelay); + unsigned long tv = get_random_u32_below(idev->mc_maxdelay); idev->mc_gq_running = 1; if (!mod_delayed_work(mld_wq, &idev->mc_gq_work, tv + 2)) @@ -1068,7 +1068,7 @@ static void mld_gq_stop_work(struct inet6_dev *idev) /* called with mc_lock */ static void mld_ifc_start_work(struct inet6_dev *idev, unsigned long delay) { - unsigned long tv = prandom_u32_max(delay); + unsigned long tv = get_random_u32_below(delay); if (!mod_delayed_work(mld_wq, &idev->mc_ifc_work, tv + 2)) in6_dev_hold(idev); @@ -1085,7 +1085,7 @@ static void mld_ifc_stop_work(struct inet6_dev *idev) /* called with mc_lock */ static void mld_dad_start_work(struct inet6_dev *idev, unsigned long delay) { - unsigned long tv = prandom_u32_max(delay); + unsigned long tv = get_random_u32_below(delay); if (!mod_delayed_work(mld_wq, &idev->mc_dad_work, tv + 2)) in6_dev_hold(idev); @@ -1130,7 +1130,7 @@ static void igmp6_group_queried(struct ifmcaddr6 *ma, unsigned long resptime) } if (delay >= resptime) - delay = prandom_u32_max(resptime); + delay = get_random_u32_below(resptime); if (!mod_delayed_work(mld_wq, &ma->mca_work, delay)) refcount_inc(&ma->mca_refcnt); @@ -2574,7 +2574,7 @@ static void igmp6_join_group(struct ifmcaddr6 *ma) igmp6_send(&ma->mca_addr, ma->idev->dev, ICMPV6_MGM_REPORT); - delay = prandom_u32_max(unsolicited_report_interval(ma->idev)); + delay = get_random_u32_below(unsolicited_report_interval(ma->idev)); if (cancel_delayed_work(&ma->mca_work)) { refcount_dec(&ma->mca_refcnt); diff --git a/net/ipv6/netfilter/nf_conntrack_reasm.c b/net/ipv6/netfilter/nf_conntrack_reasm.c index 38db0064d661..d13240f13607 100644 --- a/net/ipv6/netfilter/nf_conntrack_reasm.c +++ b/net/ipv6/netfilter/nf_conntrack_reasm.c @@ -253,7 +253,7 @@ static int nf_ct_frag6_queue(struct frag_queue *fq, struct sk_buff *skb, if (err) { if (err == IPFRAG_DUP) { /* No error for duplicates, pretend they got queued. */ - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_DUP_FRAG); return -EINPROGRESS; } goto insert_error; diff --git a/net/ipv6/netfilter/nft_dup_ipv6.c b/net/ipv6/netfilter/nft_dup_ipv6.c index 70a405b4006f..c82f3fdd4a65 100644 --- a/net/ipv6/netfilter/nft_dup_ipv6.c +++ b/net/ipv6/netfilter/nft_dup_ipv6.c @@ -50,7 +50,8 @@ static int nft_dup_ipv6_init(const struct nft_ctx *ctx, return err; } -static int nft_dup_ipv6_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_dup_ipv6_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_dup_ipv6 *priv = nft_expr_priv(expr); diff --git a/net/ipv6/output_core.c b/net/ipv6/output_core.c index 2685c3f15e9d..b5205311f372 100644 --- a/net/ipv6/output_core.c +++ b/net/ipv6/output_core.c @@ -15,13 +15,7 @@ static u32 __ipv6_select_ident(struct net *net, const struct in6_addr *dst, const struct in6_addr *src) { - u32 id; - - do { - id = get_random_u32(); - } while (!id); - - return id; + return get_random_u32_above(0); } /* This function exists only for tap drivers that must support broken diff --git a/net/ipv6/ping.c b/net/ipv6/ping.c index 86c26e48d065..808983bc2ec9 100644 --- a/net/ipv6/ping.c +++ b/net/ipv6/ping.c @@ -23,11 +23,6 @@ #include <linux/bpf-cgroup.h> #include <net/ping.h> -static void ping_v6_destroy(struct sock *sk) -{ - inet6_destroy_sock(sk); -} - /* Compatibility glue so we can support IPv6 when it's compiled as a module */ static int dummy_ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) @@ -205,7 +200,6 @@ struct proto pingv6_prot = { .owner = THIS_MODULE, .init = ping_init_sock, .close = ping_close, - .destroy = ping_v6_destroy, .pre_connect = ping_v6_pre_connect, .connect = ip6_datagram_connect_v6_only, .disconnect = __udp_disconnect, diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c index 722de9dd0ff7..ada087b50541 100644 --- a/net/ipv6/raw.c +++ b/net/ipv6/raw.c @@ -505,6 +505,7 @@ csum_copy_err: static int rawv6_push_pending_frames(struct sock *sk, struct flowi6 *fl6, struct raw6_sock *rp) { + struct ipv6_txoptions *opt; struct sk_buff *skb; int err = 0; int offset; @@ -522,6 +523,9 @@ static int rawv6_push_pending_frames(struct sock *sk, struct flowi6 *fl6, offset = rp->offset; total_len = inet_sk(sk)->cork.base.length; + opt = inet6_sk(sk)->cork.opt; + total_len -= opt ? opt->opt_flen : 0; + if (offset >= total_len - 1) { err = -EINVAL; ip6_flush_pending_frames(sk); @@ -1173,8 +1177,6 @@ static void raw6_destroy(struct sock *sk) lock_sock(sk); ip6_flush_pending_frames(sk); release_sock(sk); - - inet6_destroy_sock(sk); } static int rawv6_init_sk(struct sock *sk) diff --git a/net/ipv6/reassembly.c b/net/ipv6/reassembly.c index ff866f2a879e..5bc8a28e67f9 100644 --- a/net/ipv6/reassembly.c +++ b/net/ipv6/reassembly.c @@ -112,10 +112,14 @@ static int ip6_frag_queue(struct frag_queue *fq, struct sk_buff *skb, struct sk_buff *prev_tail; struct net_device *dev; int err = -ENOENT; + SKB_DR(reason); u8 ecn; - if (fq->q.flags & INET_FRAG_COMPLETE) + /* If reassembly is already done, @skb must be a duplicate frag. */ + if (fq->q.flags & INET_FRAG_COMPLETE) { + SKB_DR_SET(reason, DUP_FRAG); goto err; + } err = -EINVAL; offset = ntohs(fhdr->frag_off) & ~0x7; @@ -226,8 +230,9 @@ static int ip6_frag_queue(struct frag_queue *fq, struct sk_buff *skb, insert_error: if (err == IPFRAG_DUP) { - kfree_skb(skb); - return -EINVAL; + SKB_DR_SET(reason, DUP_FRAG); + err = -EINVAL; + goto err; } err = -EINVAL; __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), @@ -237,7 +242,7 @@ discard_fq: __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_REASMFAILS); err: - kfree_skb(skb); + kfree_skb_reason(skb, reason); return err; } diff --git a/net/ipv6/route.c b/net/ipv6/route.c index 2f355f0ec32a..e74e0361fd92 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -1713,7 +1713,7 @@ static int rt6_insert_exception(struct rt6_info *nrt, net->ipv6.rt6_stats->fib_rt_cache++; /* Randomize max depth to avoid some side channels attacks. */ - max_depth = FIB6_MAX_DEPTH + prandom_u32_max(FIB6_MAX_DEPTH); + max_depth = FIB6_MAX_DEPTH + get_random_u32_below(FIB6_MAX_DEPTH); while (bucket->depth > max_depth) rt6_exception_remove_oldest(bucket); diff --git a/net/ipv6/seg6_local.c b/net/ipv6/seg6_local.c index 8370726ae7bf..487f8e98deaa 100644 --- a/net/ipv6/seg6_local.c +++ b/net/ipv6/seg6_local.c @@ -1644,13 +1644,13 @@ static int put_nla_counters(struct sk_buff *skb, struct seg6_local_lwt *slwt) pcounters = per_cpu_ptr(slwt->pcpu_counters, i); do { - start = u64_stats_fetch_begin_irq(&pcounters->syncp); + start = u64_stats_fetch_begin(&pcounters->syncp); packets = u64_stats_read(&pcounters->packets); bytes = u64_stats_read(&pcounters->bytes); errors = u64_stats_read(&pcounters->errors); - } while (u64_stats_fetch_retry_irq(&pcounters->syncp, start)); + } while (u64_stats_fetch_retry(&pcounters->syncp, start)); counters.packets += packets; counters.bytes += bytes; diff --git a/net/ipv6/sit.c b/net/ipv6/sit.c index 5703d3cbea9b..70d81bba5093 100644 --- a/net/ipv6/sit.c +++ b/net/ipv6/sit.c @@ -694,7 +694,7 @@ static int ipip6_rcv(struct sk_buff *skb) skb->dev = tunnel->dev; if (packet_is_spoofed(skb, iph, tunnel)) { - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_errors); goto out; } @@ -714,8 +714,8 @@ static int ipip6_rcv(struct sk_buff *skb) net_info_ratelimited("non-ECT from %pI4 with TOS=%#x\n", &iph->saddr, iph->tos); if (err > 1) { - ++tunnel->dev->stats.rx_frame_errors; - ++tunnel->dev->stats.rx_errors; + DEV_STATS_INC(tunnel->dev, rx_frame_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto out; } } @@ -942,7 +942,7 @@ static netdev_tx_t ipip6_tunnel_xmit(struct sk_buff *skb, if (!rt) { rt = ip_route_output_flow(tunnel->net, &fl4, NULL); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } dst_cache_set_ip4(&tunnel->dst_cache, &rt->dst, fl4.saddr); @@ -950,14 +950,14 @@ static netdev_tx_t ipip6_tunnel_xmit(struct sk_buff *skb, if (rt->rt_type != RTN_UNICAST && rt->rt_type != RTN_LOCAL) { ip_rt_put(rt); - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } tdev = rt->dst.dev; if (tdev == dev) { ip_rt_put(rt); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } @@ -970,7 +970,7 @@ static netdev_tx_t ipip6_tunnel_xmit(struct sk_buff *skb, mtu = dst_mtu(&rt->dst) - t_hlen; if (mtu < IPV4_MIN_MTU) { - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); ip_rt_put(rt); goto tx_error; } @@ -1009,7 +1009,7 @@ static netdev_tx_t ipip6_tunnel_xmit(struct sk_buff *skb, struct sk_buff *new_skb = skb_realloc_headroom(skb, max_headroom); if (!new_skb) { ip_rt_put(rt); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return NETDEV_TX_OK; } @@ -1039,7 +1039,7 @@ tx_error_icmp: dst_link_failure(skb); tx_error: kfree_skb(skb); - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); return NETDEV_TX_OK; } @@ -1058,7 +1058,7 @@ static netdev_tx_t sit_tunnel_xmit__(struct sk_buff *skb, return NETDEV_TX_OK; tx_error: kfree_skb(skb); - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); return NETDEV_TX_OK; } @@ -1087,7 +1087,7 @@ static netdev_tx_t sit_tunnel_xmit(struct sk_buff *skb, return NETDEV_TX_OK; tx_err: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index 2a3f9296df1e..11b736a76bd7 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -292,24 +292,11 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; if (!saddr) { - struct inet_bind_hashbucket *prev_addr_hashbucket = NULL; - struct in6_addr prev_v6_rcv_saddr; - - if (icsk->icsk_bind2_hash) { - prev_addr_hashbucket = inet_bhashfn_portaddr(tcp_death_row->hashinfo, - sk, net, inet->inet_num); - prev_v6_rcv_saddr = sk->sk_v6_rcv_saddr; - } saddr = &fl6.saddr; - sk->sk_v6_rcv_saddr = *saddr; - if (prev_addr_hashbucket) { - err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); - if (err) { - sk->sk_v6_rcv_saddr = prev_v6_rcv_saddr; - goto failure; - } - } + err = inet_bhash2_update_saddr(sk, saddr, AF_INET6); + if (err) + goto failure; } /* set the source address */ @@ -359,6 +346,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, late_failure: tcp_set_state(sk, TCP_CLOSE); + inet_bhash2_reset_saddr(sk); failure: inet->inet_dport = 0; sk->sk_route_caps = 0; @@ -677,12 +665,11 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname, if (ipv6_addr_v4mapped(&sin6->sin6_addr)) return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3], AF_INET, prefixlen, l3index, flags, - cmd.tcpm_key, cmd.tcpm_keylen, - GFP_KERNEL); + cmd.tcpm_key, cmd.tcpm_keylen); return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr, AF_INET6, prefixlen, l3index, flags, - cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); + cmd.tcpm_key, cmd.tcpm_keylen); } static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp, @@ -1377,14 +1364,14 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff * /* Copy over the MD5 key from the original socket */ key = tcp_v6_md5_do_lookup(sk, &newsk->sk_v6_daddr, l3index); if (key) { - /* We're using one, so create a matching key - * on the newsk structure. If we fail to get - * memory, then we end up not copying the key - * across. Shucks. - */ - tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr, - AF_INET6, 128, l3index, key->flags, key->key, key->keylen, - sk_gfp_mask(sk, GFP_ATOMIC)); + const union tcp_md5_addr *addr; + + addr = (union tcp_md5_addr *)&newsk->sk_v6_daddr; + if (tcp_md5_key_copy(newsk, addr, AF_INET6, 128, l3index, key)) { + inet_csk_prepare_forced_close(newsk); + tcp_done(newsk); + goto out; + } } #endif @@ -1966,12 +1953,6 @@ static int tcp_v6_init_sock(struct sock *sk) return 0; } -static void tcp_v6_destroy_sock(struct sock *sk) -{ - tcp_v4_destroy_sock(sk); - inet6_destroy_sock(sk); -} - #ifdef CONFIG_PROC_FS /* Proc filesystem TCPv6 sock list dumping. */ static void get_openreq6(struct seq_file *seq, @@ -2164,7 +2145,7 @@ struct proto tcpv6_prot = { .accept = inet_csk_accept, .ioctl = tcp_ioctl, .init = tcp_v6_init_sock, - .destroy = tcp_v6_destroy_sock, + .destroy = tcp_v4_destroy_sock, .shutdown = tcp_shutdown, .setsockopt = tcp_setsockopt, .getsockopt = tcp_getsockopt, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index bc65e5b7195b..9fb2f33ee3a7 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -64,7 +64,7 @@ static void udpv6_destruct_sock(struct sock *sk) int udpv6_init_sock(struct sock *sk) { - skb_queue_head_init(&udp_sk(sk)->reader_queue); + udp_lib_init_sock(sk); sk->sk_destruct = udpv6_destruct_sock; set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); return 0; @@ -217,7 +217,7 @@ static inline struct sock *udp6_lookup_run_bpf(struct net *net, struct sock *sk, *reuse_sk; bool no_reuseport; - if (udptable != &udp_table) + if (udptable != net->ipv4.udp_table) return NULL; /* only UDP is supported */ no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP, saddr, sport, @@ -298,10 +298,11 @@ struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb, __be16 sport, __be16 dport) { const struct ipv6hdr *iph = ipv6_hdr(skb); + struct net *net = dev_net(skb->dev); - return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, + return __udp6_lib_lookup(net, &iph->saddr, sport, &iph->daddr, dport, inet6_iif(skb), - inet6_sdif(skb), &udp_table, NULL); + inet6_sdif(skb), net->ipv4.udp_table, NULL); } /* Must be called under rcu_read_lock(). @@ -314,7 +315,7 @@ struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be struct sock *sk; sk = __udp6_lib_lookup(net, saddr, sport, daddr, dport, - dif, 0, &udp_table, NULL); + dif, 0, net->ipv4.udp_table, NULL); if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) sk = NULL; return sk; @@ -632,7 +633,8 @@ int __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt, /* Tunnels don't have an application socket: don't pass errors back */ if (tunnel) { if (udp_sk(sk)->encap_err_rcv) - udp_sk(sk)->encap_err_rcv(sk, skb, offset); + udp_sk(sk)->encap_err_rcv(sk, skb, err, uh->dest, + ntohl(info), (u8 *)(uh+1)); goto out; } @@ -688,7 +690,8 @@ static __inline__ int udpv6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, u8 type, u8 code, int offset, __be32 info) { - return __udp6_lib_err(skb, opt, type, code, offset, info, &udp_table); + return __udp6_lib_err(skb, opt, type, code, offset, info, + dev_net(skb->dev)->ipv4.udp_table); } static int udpv6_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) @@ -1062,13 +1065,18 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net, __be16 rmt_port, const struct in6_addr *rmt_addr, int dif, int sdif) { + struct udp_table *udptable = net->ipv4.udp_table; unsigned short hnum = ntohs(loc_port); - unsigned int hash2 = ipv6_portaddr_hash(net, loc_addr, hnum); - unsigned int slot2 = hash2 & udp_table.mask; - struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; - const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + __portpair ports; struct sock *sk; + hash2 = ipv6_portaddr_hash(net, loc_addr, hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + ports = INET_COMBINED_PORTS(rmt_port, hnum); + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { if (sk->sk_state == TCP_ESTABLISHED && inet6_match(net, sk, rmt_addr, loc_addr, ports, dif, sdif)) @@ -1122,7 +1130,7 @@ void udp_v6_early_demux(struct sk_buff *skb) INDIRECT_CALLABLE_SCOPE int udpv6_rcv(struct sk_buff *skb) { - return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP); + return __udp6_lib_rcv(skb, dev_net(skb->dev)->ipv4.udp_table, IPPROTO_UDP); } /* @@ -1639,6 +1647,7 @@ do_confirm: err = 0; goto out; } +EXPORT_SYMBOL(udpv6_sendmsg); void udpv6_destroy_sock(struct sock *sk) { @@ -1662,8 +1671,6 @@ void udpv6_destroy_sock(struct sock *sk) udp_encap_disable(); } } - - inet6_destroy_sock(sk); } /* @@ -1672,7 +1679,7 @@ void udpv6_destroy_sock(struct sock *sk) int udpv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, unsigned int optlen) { - if (level == SOL_UDP || level == SOL_UDPLITE) + if (level == SOL_UDP || level == SOL_UDPLITE || level == SOL_SOCKET) return udp_lib_setsockopt(sk, level, optname, optval, optlen, udp_v6_push_pending_frames); @@ -1720,7 +1727,7 @@ EXPORT_SYMBOL(udp6_seq_ops); static struct udp_seq_afinfo udp6_seq_afinfo = { .family = AF_INET6, - .udp_table = &udp_table, + .udp_table = NULL, }; int __net_init udp6_proc_init(struct net *net) @@ -1770,7 +1777,7 @@ struct proto udpv6_prot = { .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), .obj_size = sizeof(struct udp6_sock), - .h.udp_table = &udp_table, + .h.udp_table = NULL, .diag_destroy = udp_abort, }; diff --git a/net/ipv6/udp_offload.c b/net/ipv6/udp_offload.c index 7720d04ed396..c39c1e32f980 100644 --- a/net/ipv6/udp_offload.c +++ b/net/ipv6/udp_offload.c @@ -42,7 +42,8 @@ static struct sk_buff *udp6_ufo_fragment(struct sk_buff *skb, if (!pskb_may_pull(skb, sizeof(struct udphdr))) goto out; - if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) + if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4 && + !skb_gso_ok(skb, features | NETIF_F_GSO_ROBUST)) return __udp_gso_segment(skb, features, true); mss = skb_shinfo(skb)->gso_size; @@ -116,10 +117,11 @@ static struct sock *udp6_gro_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport) { const struct ipv6hdr *iph = skb_gro_network_header(skb); + struct net *net = dev_net(skb->dev); - return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, + return __udp6_lib_lookup(net, &iph->saddr, sport, &iph->daddr, dport, inet6_iif(skb), - inet6_sdif(skb), &udp_table, NULL); + inet6_sdif(skb), net->ipv4.udp_table, NULL); } INDIRECT_CALLABLE_SCOPE diff --git a/net/ipv6/xfrm6_policy.c b/net/ipv6/xfrm6_policy.c index 4a4b0e49ec92..ea435eba3053 100644 --- a/net/ipv6/xfrm6_policy.c +++ b/net/ipv6/xfrm6_policy.c @@ -287,9 +287,13 @@ int __init xfrm6_init(void) if (ret) goto out_state; - register_pernet_subsys(&xfrm6_net_ops); + ret = register_pernet_subsys(&xfrm6_net_ops); + if (ret) + goto out_protocol; out: return ret; +out_protocol: + xfrm6_protocol_fini(); out_state: xfrm6_state_fini(); out_policy: diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c index a5004228111d..890a2423f559 100644 --- a/net/kcm/kcmsock.c +++ b/net/kcm/kcmsock.c @@ -222,7 +222,7 @@ static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head) struct sk_buff *skb; struct kcm_sock *kcm; - while ((skb = __skb_dequeue(head))) { + while ((skb = skb_dequeue(head))) { /* Reset destructor to avoid calling kcm_rcv_ready */ skb->destructor = sock_rfree; skb_orphan(skb); @@ -1085,53 +1085,17 @@ out_error: return err; } -static struct sk_buff *kcm_wait_data(struct sock *sk, int flags, - long timeo, int *err) -{ - struct sk_buff *skb; - - while (!(skb = skb_peek(&sk->sk_receive_queue))) { - if (sk->sk_err) { - *err = sock_error(sk); - return NULL; - } - - if (sock_flag(sk, SOCK_DONE)) - return NULL; - - if ((flags & MSG_DONTWAIT) || !timeo) { - *err = -EAGAIN; - return NULL; - } - - sk_wait_data(sk, &timeo, NULL); - - /* Handle signals */ - if (signal_pending(current)) { - *err = sock_intr_errno(timeo); - return NULL; - } - } - - return skb; -} - static int kcm_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { struct sock *sk = sock->sk; struct kcm_sock *kcm = kcm_sk(sk); int err = 0; - long timeo; struct strp_msg *stm; int copied = 0; struct sk_buff *skb; - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - - lock_sock(sk); - - skb = kcm_wait_data(sk, flags, timeo, &err); + skb = skb_recv_datagram(sk, flags, &err); if (!skb) goto out; @@ -1162,14 +1126,11 @@ msg_finished: /* Finished with message */ msg->msg_flags |= MSG_EOR; KCM_STATS_INCR(kcm->stats.rx_msgs); - skb_unlink(skb, &sk->sk_receive_queue); - kfree_skb(skb); } } out: - release_sock(sk); - + skb_free_datagram(sk, skb); return copied ? : err; } @@ -1179,7 +1140,6 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos, { struct sock *sk = sock->sk; struct kcm_sock *kcm = kcm_sk(sk); - long timeo; struct strp_msg *stm; int err = 0; ssize_t copied; @@ -1187,11 +1147,7 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos, /* Only support splice for SOCKSEQPACKET */ - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - - lock_sock(sk); - - skb = kcm_wait_data(sk, flags, timeo, &err); + skb = skb_recv_datagram(sk, flags, &err); if (!skb) goto err_out; @@ -1219,13 +1175,11 @@ static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos, * finish reading the message. */ - release_sock(sk); - + skb_free_datagram(sk, skb); return copied; err_out: - release_sock(sk); - + skb_free_datagram(sk, skb); return err; } diff --git a/net/key/af_key.c b/net/key/af_key.c index c85df5b958d2..2bdbcec781cd 100644 --- a/net/key/af_key.c +++ b/net/key/af_key.c @@ -1377,13 +1377,13 @@ static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_ max_spi = range->sadb_spirange_max; } - err = verify_spi_info(x->id.proto, min_spi, max_spi); + err = verify_spi_info(x->id.proto, min_spi, max_spi, NULL); if (err) { xfrm_state_put(x); return err; } - err = xfrm_alloc_spi(x, min_spi, max_spi); + err = xfrm_alloc_spi(x, min_spi, max_spi, NULL); resp_skb = err ? ERR_PTR(err) : pfkey_xfrm_state2msg(x); if (IS_ERR(resp_skb)) { @@ -2626,7 +2626,7 @@ static int pfkey_migrate(struct sock *sk, struct sk_buff *skb, } return xfrm_migrate(&sel, dir, XFRM_POLICY_TYPE_MAIN, m, i, - kma ? &k : NULL, net, NULL, 0); + kma ? &k : NULL, net, NULL, 0, NULL); out: return err; @@ -2905,7 +2905,7 @@ static int count_ah_combs(const struct xfrm_tmpl *t) break; if (!aalg->pfkey_supported) continue; - if (aalg_tmpl_set(t, aalg) && aalg->available) + if (aalg_tmpl_set(t, aalg)) sz += sizeof(struct sadb_comb); } return sz + sizeof(struct sadb_prop); @@ -2923,7 +2923,7 @@ static int count_esp_combs(const struct xfrm_tmpl *t) if (!ealg->pfkey_supported) continue; - if (!(ealg_tmpl_set(t, ealg) && ealg->available)) + if (!(ealg_tmpl_set(t, ealg))) continue; for (k = 1; ; k++) { @@ -2934,16 +2934,17 @@ static int count_esp_combs(const struct xfrm_tmpl *t) if (!aalg->pfkey_supported) continue; - if (aalg_tmpl_set(t, aalg) && aalg->available) + if (aalg_tmpl_set(t, aalg)) sz += sizeof(struct sadb_comb); } } return sz + sizeof(struct sadb_prop); } -static void dump_ah_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) +static int dump_ah_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) { struct sadb_prop *p; + int sz = 0; int i; p = skb_put(skb, sizeof(struct sadb_prop)); @@ -2971,13 +2972,17 @@ static void dump_ah_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) c->sadb_comb_soft_addtime = 20*60*60; c->sadb_comb_hard_usetime = 8*60*60; c->sadb_comb_soft_usetime = 7*60*60; + sz += sizeof(*c); } } + + return sz + sizeof(*p); } -static void dump_esp_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) +static int dump_esp_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) { struct sadb_prop *p; + int sz = 0; int i, k; p = skb_put(skb, sizeof(struct sadb_prop)); @@ -3019,8 +3024,11 @@ static void dump_esp_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) c->sadb_comb_soft_addtime = 20*60*60; c->sadb_comb_hard_usetime = 8*60*60; c->sadb_comb_soft_usetime = 7*60*60; + sz += sizeof(*c); } } + + return sz + sizeof(*p); } static int key_notify_policy_expire(struct xfrm_policy *xp, const struct km_event *c) @@ -3150,6 +3158,7 @@ static int pfkey_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *t, struct struct sadb_x_sec_ctx *sec_ctx; struct xfrm_sec_ctx *xfrm_ctx; int ctx_size = 0; + int alg_size = 0; sockaddr_size = pfkey_sockaddr_size(x->props.family); if (!sockaddr_size) @@ -3161,16 +3170,16 @@ static int pfkey_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *t, struct sizeof(struct sadb_x_policy); if (x->id.proto == IPPROTO_AH) - size += count_ah_combs(t); + alg_size = count_ah_combs(t); else if (x->id.proto == IPPROTO_ESP) - size += count_esp_combs(t); + alg_size = count_esp_combs(t); if ((xfrm_ctx = x->security)) { ctx_size = PFKEY_ALIGN8(xfrm_ctx->ctx_len); size += sizeof(struct sadb_x_sec_ctx) + ctx_size; } - skb = alloc_skb(size + 16, GFP_ATOMIC); + skb = alloc_skb(size + alg_size + 16, GFP_ATOMIC); if (skb == NULL) return -ENOMEM; @@ -3224,10 +3233,13 @@ static int pfkey_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *t, struct pol->sadb_x_policy_priority = xp->priority; /* Set sadb_comb's. */ + alg_size = 0; if (x->id.proto == IPPROTO_AH) - dump_ah_combs(skb, t); + alg_size = dump_ah_combs(skb, t); else if (x->id.proto == IPPROTO_ESP) - dump_esp_combs(skb, t); + alg_size = dump_esp_combs(skb, t); + + hdr->sadb_msg_len += alg_size / 8; /* security context */ if (xfrm_ctx) { @@ -3382,7 +3394,7 @@ static int pfkey_send_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, hdr->sadb_msg_len = size / sizeof(uint64_t); hdr->sadb_msg_errno = 0; hdr->sadb_msg_reserved = 0; - hdr->sadb_msg_seq = x->km.seq = get_acqseq(); + hdr->sadb_msg_seq = x->km.seq; hdr->sadb_msg_pid = 0; /* SA */ diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 7499c51b1850..03608d3ded4b 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq; /* per-net private data for this module */ static unsigned int l2tp_net_id; struct l2tp_net { - struct list_head l2tp_tunnel_list; - /* Lock for write access to l2tp_tunnel_list */ - spinlock_t l2tp_tunnel_list_lock; + /* Lock for write access to l2tp_tunnel_idr */ + spinlock_t l2tp_tunnel_idr_lock; + struct idr l2tp_tunnel_idr; struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2]; /* Lock for write access to l2tp_session_hlist */ spinlock_t l2tp_session_hlist_lock; @@ -208,13 +208,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) struct l2tp_tunnel *tunnel; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - if (tunnel->tunnel_id == tunnel_id && - refcount_inc_not_zero(&tunnel->ref_count)) { - rcu_read_unlock_bh(); - - return tunnel; - } + tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id); + if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; } rcu_read_unlock_bh(); @@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get); struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth) { - const struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_net *pn = l2tp_pernet(net); + unsigned long tunnel_id, tmp; struct l2tp_tunnel *tunnel; int count = 0; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - if (++count > nth && + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && ++count > nth && refcount_inc_not_zero(&tunnel->ref_count)) { rcu_read_unlock_bh(); return tunnel; @@ -1043,7 +1041,7 @@ static int l2tp_xmit_core(struct l2tp_session *session, struct sk_buff *skb, uns IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED); nf_reset_ct(skb); - bh_lock_sock(sk); + bh_lock_sock_nested(sk); if (sock_owned_by_user(sk)) { kfree_skb(skb); ret = NET_XMIT_DROP; @@ -1150,8 +1148,10 @@ static void l2tp_tunnel_destruct(struct sock *sk) } /* Remove hooks into tunnel socket */ + write_lock_bh(&sk->sk_callback_lock); sk->sk_destruct = tunnel->old_sk_destruct; sk->sk_user_data = NULL; + write_unlock_bh(&sk->sk_callback_lock); /* Call the original destructor */ if (sk->sk_destruct) @@ -1225,6 +1225,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk) l2tp_tunnel_delete(tunnel); } +static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel) +{ + struct l2tp_net *pn = l2tp_pernet(net); + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); +} + /* Workqueue tunnel deletion function */ static void l2tp_tunnel_del_work(struct work_struct *work) { @@ -1232,7 +1241,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work) del_work); struct sock *sk = tunnel->sock; struct socket *sock = sk->sk_socket; - struct l2tp_net *pn; l2tp_tunnel_closeall(tunnel); @@ -1246,12 +1254,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work) } } - /* Remove the tunnel struct from the tunnel list */ - pn = l2tp_pernet(tunnel->l2tp_net); - spin_lock_bh(&pn->l2tp_tunnel_list_lock); - list_del_rcu(&tunnel->list); - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - + l2tp_tunnel_remove(tunnel->l2tp_net, tunnel); /* drop initial ref */ l2tp_tunnel_dec_refcount(tunnel); @@ -1382,8 +1385,6 @@ out: return err; } -static struct lock_class_key l2tp_socket_class; - int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id, struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp) { @@ -1453,12 +1454,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net, int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, struct l2tp_tunnel_cfg *cfg) { - struct l2tp_tunnel *tunnel_walk; - struct l2tp_net *pn; + struct l2tp_net *pn = l2tp_pernet(net); + u32 tunnel_id = tunnel->tunnel_id; struct socket *sock; struct sock *sk; int ret; + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id, + GFP_ATOMIC); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); + if (ret) + return ret == -ENOSPC ? -EEXIST : ret; + if (tunnel->fd < 0) { ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id, tunnel->peer_tunnel_id, cfg, @@ -1469,30 +1477,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, sock = sockfd_lookup(tunnel->fd, &ret); if (!sock) goto err; - - ret = l2tp_validate_socket(sock->sk, net, tunnel->encap); - if (ret < 0) - goto err_sock; } - tunnel->l2tp_net = net; - pn = l2tp_pernet(net); - sk = sock->sk; - sock_hold(sk); - tunnel->sock = sk; - - spin_lock_bh(&pn->l2tp_tunnel_list_lock); - list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) { - if (tunnel_walk->tunnel_id == tunnel->tunnel_id) { - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - sock_put(sk); - ret = -EEXIST; - goto err_sock; - } - } - list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list); - spin_unlock_bh(&pn->l2tp_tunnel_list_lock); + lock_sock(sk); + write_lock_bh(&sk->sk_callback_lock); + ret = l2tp_validate_socket(sk, net, tunnel->encap); + if (ret < 0) + goto err_inval_sock; + rcu_assign_sk_user_data(sk, tunnel); + write_unlock_bh(&sk->sk_callback_lock); if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { struct udp_tunnel_sock_cfg udp_cfg = { @@ -1503,15 +1497,20 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, }; setup_udp_tunnel_sock(net, sock, &udp_cfg); - } else { - sk->sk_user_data = tunnel; } tunnel->old_sk_destruct = sk->sk_destruct; sk->sk_destruct = &l2tp_tunnel_destruct; - lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, - "l2tp_sock"); sk->sk_allocation = GFP_ATOMIC; + release_sock(sk); + + sock_hold(sk); + tunnel->sock = sk; + tunnel->l2tp_net = net; + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); trace_register_tunnel(tunnel); @@ -1520,12 +1519,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, return 0; -err_sock: +err_inval_sock: + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + if (tunnel->fd < 0) sock_release(sock); else sockfd_put(sock); err: + l2tp_tunnel_remove(net, tunnel); return ret; } EXPORT_SYMBOL_GPL(l2tp_tunnel_register); @@ -1639,8 +1642,8 @@ static __net_init int l2tp_init_net(struct net *net) struct l2tp_net *pn = net_generic(net, l2tp_net_id); int hash; - INIT_LIST_HEAD(&pn->l2tp_tunnel_list); - spin_lock_init(&pn->l2tp_tunnel_list_lock); + idr_init(&pn->l2tp_tunnel_idr); + spin_lock_init(&pn->l2tp_tunnel_idr_lock); for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]); @@ -1654,11 +1657,13 @@ static __net_exit void l2tp_exit_net(struct net *net) { struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_tunnel *tunnel = NULL; + unsigned long tunnel_id, tmp; int hash; rcu_read_lock_bh(); - list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { - l2tp_tunnel_delete(tunnel); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel) + l2tp_tunnel_delete(tunnel); } rcu_read_unlock_bh(); @@ -1668,6 +1673,7 @@ static __net_exit void l2tp_exit_net(struct net *net) for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash])); + idr_destroy(&pn->l2tp_tunnel_idr); } static struct pernet_operations l2tp_net_ops = { diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index 9dbd801ddb98..2478aa60145f 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -257,8 +257,6 @@ static void l2tp_ip6_destroy_sock(struct sock *sk) if (tunnel) l2tp_tunnel_delete(tunnel); - - inet6_destroy_sock(sk); } static int l2tp_ip6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) diff --git a/net/mac80211/agg-rx.c b/net/mac80211/agg-rx.c index 9414d3bbd65f..c6fa53230450 100644 --- a/net/mac80211/agg-rx.c +++ b/net/mac80211/agg-rx.c @@ -183,34 +183,15 @@ static void ieee80211_add_addbaext(struct ieee80211_sub_if_data *sdata, const struct ieee80211_addba_ext_ie *req, u16 buf_size) { - struct ieee80211_supported_band *sband; struct ieee80211_addba_ext_ie *resp; - const struct ieee80211_sta_he_cap *he_cap; - u8 frag_level, cap_frag_level; u8 *pos; - sband = ieee80211_get_sband(sdata); - if (!sband) - return; - he_cap = ieee80211_get_he_iftype_cap(sband, - ieee80211_vif_type_p2p(&sdata->vif)); - if (!he_cap) - return; - pos = skb_put_zero(skb, 2 + sizeof(struct ieee80211_addba_ext_ie)); *pos++ = WLAN_EID_ADDBA_EXT; *pos++ = sizeof(struct ieee80211_addba_ext_ie); resp = (struct ieee80211_addba_ext_ie *)pos; resp->data = req->data & IEEE80211_ADDBA_EXT_NO_FRAG; - frag_level = u32_get_bits(req->data, - IEEE80211_ADDBA_EXT_FRAG_LEVEL_MASK); - cap_frag_level = u32_get_bits(he_cap->he_cap_elem.mac_cap_info[0], - IEEE80211_HE_MAC_CAP0_DYNAMIC_FRAG_MASK); - if (frag_level > cap_frag_level) - frag_level = cap_frag_level; - resp->data |= u8_encode_bits(frag_level, - IEEE80211_ADDBA_EXT_FRAG_LEVEL_MASK); resp->data |= u8_encode_bits(buf_size >> IEEE80211_ADDBA_EXT_BUF_SIZE_SHIFT, IEEE80211_ADDBA_EXT_BUF_SIZE_MASK); } @@ -242,7 +223,7 @@ static void ieee80211_send_addba_resp(struct sta_info *sta, u8 *da, u16 tid, sdata->vif.type == NL80211_IFTYPE_MESH_POINT) memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); else if (sdata->vif.type == NL80211_IFTYPE_STATION) - memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(mgmt->bssid, sdata->vif.cfg.ap_addr, ETH_ALEN); else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) memcpy(mgmt->bssid, sdata->u.ibss.bssid, ETH_ALEN); @@ -297,9 +278,9 @@ void ___ieee80211_start_rx_ba_session(struct sta_info *sta, } if (!sta->sta.deflink.ht_cap.ht_supported && - sta->sdata->vif.bss_conf.chandef.chan->band != NL80211_BAND_6GHZ) { + !sta->sta.deflink.he_cap.has_he) { ht_dbg(sta->sdata, - "STA %pM erroneously requests BA session on tid %d w/o QoS\n", + "STA %pM erroneously requests BA session on tid %d w/o HT\n", sta->sta.addr, tid); /* send a response anyway, it's an error case if we get here */ goto end; diff --git a/net/mac80211/agg-tx.c b/net/mac80211/agg-tx.c index 07c892aa8c73..f9514bacbd4a 100644 --- a/net/mac80211/agg-tx.c +++ b/net/mac80211/agg-tx.c @@ -82,7 +82,7 @@ static void ieee80211_send_addba_request(struct ieee80211_sub_if_data *sdata, sdata->vif.type == NL80211_IFTYPE_MESH_POINT) memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); else if (sdata->vif.type == NL80211_IFTYPE_STATION) - memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(mgmt->bssid, sdata->vif.cfg.ap_addr, ETH_ALEN); else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) memcpy(mgmt->bssid, sdata->u.ibss.bssid, ETH_ALEN); @@ -491,7 +491,7 @@ void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid) { struct tid_ampdu_tx *tid_tx; struct ieee80211_local *local = sta->local; - struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_sub_if_data *sdata; struct ieee80211_ampdu_params params = { .sta = &sta->sta, .action = IEEE80211_AMPDU_TX_START, @@ -511,8 +511,6 @@ void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid) */ clear_bit(HT_AGG_STATE_WANT_START, &tid_tx->state); - ieee80211_agg_stop_txq(sta, tid); - /* * Make sure no packets are being processed. This ensures that * we have a valid starting sequence number and that in-flight @@ -521,6 +519,7 @@ void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid) */ synchronize_net(); + sdata = sta->sdata; params.ssn = sta->tid_seq[tid] >> 4; ret = drv_ampdu_action(local, sdata, ¶ms); tid_tx->ssn = params.ssn; @@ -534,6 +533,9 @@ void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid) */ set_bit(HT_AGG_STATE_DRV_READY, &tid_tx->state); } else if (ret) { + if (!sdata) + return; + ht_dbg(sdata, "BA request denied - HW unavailable for %pM tid %d\n", sta->sta.addr, tid); diff --git a/net/mac80211/airtime.c b/net/mac80211/airtime.c index 2e66598fac79..e8ebd343e2bf 100644 --- a/net/mac80211/airtime.c +++ b/net/mac80211/airtime.c @@ -452,6 +452,9 @@ static u32 ieee80211_get_rate_duration(struct ieee80211_hw *hw, (status->encoding == RX_ENC_HE && streams > 8))) return 0; + if (idx >= MCS_GROUP_RATES) + return 0; + duration = airtime_mcs_groups[group].duration[idx]; duration <<= airtime_mcs_groups[group].shift; *overhead = 36 + (streams << 2); diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c index 687b4c878d4a..672eff6f5d32 100644 --- a/net/mac80211/cfg.c +++ b/net/mac80211/cfg.c @@ -147,6 +147,7 @@ static int ieee80211_set_ap_mbssid_options(struct ieee80211_sub_if_data *sdata, link_conf->bssid_index = 0; link_conf->nontransmitted = false; link_conf->ema_ap = false; + link_conf->bssid_indicator = 0; if (sdata->vif.type != NL80211_IFTYPE_AP || !params.tx_wdev) return -EINVAL; @@ -576,7 +577,7 @@ static struct ieee80211_key * ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata, int link_id, u8 key_idx, bool pairwise, const u8 *mac_addr) { - struct ieee80211_local *local = sdata->local; + struct ieee80211_local *local __maybe_unused = sdata->local; struct ieee80211_link_data *link = &sdata->deflink; struct ieee80211_key *key; @@ -1511,6 +1512,12 @@ static int ieee80211_stop_ap(struct wiphy *wiphy, struct net_device *dev, kfree(link_conf->ftmr_params); link_conf->ftmr_params = NULL; + sdata->vif.mbssid_tx_vif = NULL; + link_conf->bssid_index = 0; + link_conf->nontransmitted = false; + link_conf->ema_ap = false; + link_conf->bssid_indicator = 0; + __sta_info_flush(sdata, true); ieee80211_free_keys(sdata, true); @@ -2554,47 +2561,50 @@ static int ieee80211_change_bss(struct wiphy *wiphy, struct bss_parameters *params) { struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link; struct ieee80211_supported_band *sband; u32 changed = 0; - if (!sdata_dereference(sdata->deflink.u.ap.beacon, sdata)) + link = ieee80211_link_or_deflink(sdata, params->link_id, true); + if (IS_ERR(link)) + return PTR_ERR(link); + + if (!sdata_dereference(link->u.ap.beacon, sdata)) return -ENOENT; - sband = ieee80211_get_sband(sdata); + sband = ieee80211_get_link_sband(link); if (!sband) return -EINVAL; if (params->use_cts_prot >= 0) { - sdata->vif.bss_conf.use_cts_prot = params->use_cts_prot; + link->conf->use_cts_prot = params->use_cts_prot; changed |= BSS_CHANGED_ERP_CTS_PROT; } if (params->use_short_preamble >= 0) { - sdata->vif.bss_conf.use_short_preamble = - params->use_short_preamble; + link->conf->use_short_preamble = params->use_short_preamble; changed |= BSS_CHANGED_ERP_PREAMBLE; } - if (!sdata->vif.bss_conf.use_short_slot && + if (!link->conf->use_short_slot && (sband->band == NL80211_BAND_5GHZ || sband->band == NL80211_BAND_6GHZ)) { - sdata->vif.bss_conf.use_short_slot = true; + link->conf->use_short_slot = true; changed |= BSS_CHANGED_ERP_SLOT; } if (params->use_short_slot_time >= 0) { - sdata->vif.bss_conf.use_short_slot = - params->use_short_slot_time; + link->conf->use_short_slot = params->use_short_slot_time; changed |= BSS_CHANGED_ERP_SLOT; } if (params->basic_rates) { - ieee80211_parse_bitrates(sdata->vif.bss_conf.chandef.width, + ieee80211_parse_bitrates(link->conf->chandef.width, wiphy->bands[sband->band], params->basic_rates, params->basic_rates_len, - &sdata->vif.bss_conf.basic_rates); + &link->conf->basic_rates); changed |= BSS_CHANGED_BASIC_RATES; - ieee80211_check_rate_mask(&sdata->deflink); + ieee80211_check_rate_mask(link); } if (params->ap_isolate >= 0) { @@ -2606,30 +2616,29 @@ static int ieee80211_change_bss(struct wiphy *wiphy, } if (params->ht_opmode >= 0) { - sdata->vif.bss_conf.ht_operation_mode = - (u16) params->ht_opmode; + link->conf->ht_operation_mode = (u16)params->ht_opmode; changed |= BSS_CHANGED_HT; } if (params->p2p_ctwindow >= 0) { - sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow &= + link->conf->p2p_noa_attr.oppps_ctwindow &= ~IEEE80211_P2P_OPPPS_CTWINDOW_MASK; - sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow |= + link->conf->p2p_noa_attr.oppps_ctwindow |= params->p2p_ctwindow & IEEE80211_P2P_OPPPS_CTWINDOW_MASK; changed |= BSS_CHANGED_P2P_PS; } if (params->p2p_opp_ps > 0) { - sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow |= + link->conf->p2p_noa_attr.oppps_ctwindow |= IEEE80211_P2P_OPPPS_ENABLE_BIT; changed |= BSS_CHANGED_P2P_PS; } else if (params->p2p_opp_ps == 0) { - sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow &= + link->conf->p2p_noa_attr.oppps_ctwindow &= ~IEEE80211_P2P_OPPPS_ENABLE_BIT; changed |= BSS_CHANGED_P2P_PS; } - ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + ieee80211_link_info_change_notify(sdata, link, changed); return 0; } @@ -4338,9 +4347,6 @@ static int ieee80211_get_txq_stats(struct wiphy *wiphy, struct ieee80211_sub_if_data *sdata; int ret = 0; - if (!local->ops->wake_tx_queue) - return 1; - spin_lock_bh(&local->fq.lock); rcu_read_lock(); diff --git a/net/mac80211/debugfs.c b/net/mac80211/debugfs.c index 78c7d60e8667..dfb9f55e2685 100644 --- a/net/mac80211/debugfs.c +++ b/net/mac80211/debugfs.c @@ -663,9 +663,7 @@ void debugfs_hw_add(struct ieee80211_local *local) DEBUGFS_ADD_MODE(force_tx_status, 0600); DEBUGFS_ADD_MODE(aql_enable, 0600); DEBUGFS_ADD(aql_pending); - - if (local->ops->wake_tx_queue) - DEBUGFS_ADD_MODE(aqm, 0600); + DEBUGFS_ADD_MODE(aqm, 0600); DEBUGFS_ADD_MODE(airtime_flags, 0600); diff --git a/net/mac80211/debugfs_netdev.c b/net/mac80211/debugfs_netdev.c index 5b014786fd2d..c87e1137e5da 100644 --- a/net/mac80211/debugfs_netdev.c +++ b/net/mac80211/debugfs_netdev.c @@ -677,8 +677,7 @@ static void add_common_files(struct ieee80211_sub_if_data *sdata) DEBUGFS_ADD(rc_rateidx_vht_mcs_mask_5ghz); DEBUGFS_ADD(hw_queues); - if (sdata->local->ops->wake_tx_queue && - sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && + if (sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && sdata->vif.type != NL80211_IFTYPE_NAN) DEBUGFS_ADD(aqm); } diff --git a/net/mac80211/debugfs_sta.c b/net/mac80211/debugfs_sta.c index d3397c1248d3..f1914bf39f0e 100644 --- a/net/mac80211/debugfs_sta.c +++ b/net/mac80211/debugfs_sta.c @@ -5,7 +5,7 @@ * Copyright 2007 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright(c) 2016 Intel Deutschland GmbH - * Copyright (C) 2018 - 2021 Intel Corporation + * Copyright (C) 2018 - 2022 Intel Corporation */ #include <linux/debugfs.h> @@ -167,7 +167,7 @@ static ssize_t sta_aqm_read(struct file *file, char __user *userbuf, continue; txqi = to_txq_info(sta->sta.txq[i]); p += scnprintf(p, bufsz + buf - p, - "%d %d %u %u %u %u %u %u %u %u %u 0x%lx(%s%s%s)\n", + "%d %d %u %u %u %u %u %u %u %u %u 0x%lx(%s%s%s%s)\n", txqi->txq.tid, txqi->txq.ac, txqi->tin.backlog_bytes, @@ -182,7 +182,8 @@ static ssize_t sta_aqm_read(struct file *file, char __user *userbuf, txqi->flags, test_bit(IEEE80211_TXQ_STOP, &txqi->flags) ? "STOP" : "RUN", test_bit(IEEE80211_TXQ_AMPDU, &txqi->flags) ? " AMPDU" : "", - test_bit(IEEE80211_TXQ_NO_AMSDU, &txqi->flags) ? " NO-AMSDU" : ""); + test_bit(IEEE80211_TXQ_NO_AMSDU, &txqi->flags) ? " NO-AMSDU" : "", + test_bit(IEEE80211_TXQ_DIRTY, &txqi->flags) ? " DIRTY" : ""); } rcu_read_unlock(); @@ -435,8 +436,29 @@ static ssize_t sta_agg_status_write(struct file *file, const char __user *userbu } STA_OPS_RW(agg_status); -static ssize_t sta_ht_capa_read(struct file *file, char __user *userbuf, - size_t count, loff_t *ppos) +/* link sta attributes */ +#define LINK_STA_OPS(name) \ +static const struct file_operations link_sta_ ##name## _ops = { \ + .read = link_sta_##name##_read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +static ssize_t link_sta_addr_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct link_sta_info *link_sta = file->private_data; + u8 mac[3 * ETH_ALEN + 1]; + + snprintf(mac, sizeof(mac), "%pM\n", link_sta->pub->addr); + + return simple_read_from_buffer(userbuf, count, ppos, mac, 3 * ETH_ALEN); +} + +LINK_STA_OPS(addr); + +static ssize_t link_sta_ht_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) { #define PRINT_HT_CAP(_cond, _str) \ do { \ @@ -446,8 +468,8 @@ static ssize_t sta_ht_capa_read(struct file *file, char __user *userbuf, char *buf, *p; int i; ssize_t bufsz = 512; - struct sta_info *sta = file->private_data; - struct ieee80211_sta_ht_cap *htc = &sta->sta.deflink.ht_cap; + struct link_sta_info *link_sta = file->private_data; + struct ieee80211_sta_ht_cap *htc = &link_sta->pub->ht_cap; ssize_t ret; buf = kzalloc(bufsz, GFP_KERNEL); @@ -524,14 +546,14 @@ static ssize_t sta_ht_capa_read(struct file *file, char __user *userbuf, kfree(buf); return ret; } -STA_OPS(ht_capa); +LINK_STA_OPS(ht_capa); -static ssize_t sta_vht_capa_read(struct file *file, char __user *userbuf, - size_t count, loff_t *ppos) +static ssize_t link_sta_vht_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) { char *buf, *p; - struct sta_info *sta = file->private_data; - struct ieee80211_sta_vht_cap *vhtc = &sta->sta.deflink.vht_cap; + struct link_sta_info *link_sta = file->private_data; + struct ieee80211_sta_vht_cap *vhtc = &link_sta->pub->vht_cap; ssize_t ret; ssize_t bufsz = 512; @@ -638,15 +660,15 @@ static ssize_t sta_vht_capa_read(struct file *file, char __user *userbuf, kfree(buf); return ret; } -STA_OPS(vht_capa); +LINK_STA_OPS(vht_capa); -static ssize_t sta_he_capa_read(struct file *file, char __user *userbuf, - size_t count, loff_t *ppos) +static ssize_t link_sta_he_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) { char *buf, *p; size_t buf_sz = PAGE_SIZE; - struct sta_info *sta = file->private_data; - struct ieee80211_sta_he_cap *hec = &sta->sta.deflink.he_cap; + struct link_sta_info *link_sta = file->private_data; + struct ieee80211_sta_he_cap *hec = &link_sta->pub->he_cap; struct ieee80211_he_mcs_nss_supp *nss = &hec->he_mcs_nss_supp; u8 ppe_size; u8 *cap; @@ -1011,7 +1033,7 @@ out: kfree(buf); return ret; } -STA_OPS(he_capa); +LINK_STA_OPS(he_capa); #define DEBUGFS_ADD(name) \ debugfs_create_file(#name, 0400, \ @@ -1048,18 +1070,11 @@ void ieee80211_sta_debugfs_add(struct sta_info *sta) DEBUGFS_ADD(num_ps_buf_frames); DEBUGFS_ADD(last_seq_ctrl); DEBUGFS_ADD(agg_status); - DEBUGFS_ADD(ht_capa); - DEBUGFS_ADD(vht_capa); - DEBUGFS_ADD(he_capa); - - DEBUGFS_ADD_COUNTER(rx_duplicates, deflink.rx_stats.num_duplicates); - DEBUGFS_ADD_COUNTER(rx_fragments, deflink.rx_stats.fragments); + /* FIXME: Kept here as the statistics are only done on the deflink */ DEBUGFS_ADD_COUNTER(tx_filtered, deflink.status_stats.filtered); - if (local->ops->wake_tx_queue) { - DEBUGFS_ADD(aqm); - DEBUGFS_ADD(airtime); - } + DEBUGFS_ADD(aqm); + DEBUGFS_ADD(airtime); if (wiphy_ext_feature_isset(local->hw.wiphy, NL80211_EXT_FEATURE_AQL)) @@ -1076,3 +1091,85 @@ void ieee80211_sta_debugfs_remove(struct sta_info *sta) debugfs_remove_recursive(sta->debugfs_dir); sta->debugfs_dir = NULL; } + +#undef DEBUGFS_ADD +#undef DEBUGFS_ADD_COUNTER + +#define DEBUGFS_ADD(name) \ + debugfs_create_file(#name, 0400, \ + link_sta->debugfs_dir, link_sta, &link_sta_ ##name## _ops) +#define DEBUGFS_ADD_COUNTER(name, field) \ + debugfs_create_ulong(#name, 0400, link_sta->debugfs_dir, &link_sta->field) + +void ieee80211_link_sta_debugfs_add(struct link_sta_info *link_sta) +{ + if (WARN_ON(!link_sta->sta->debugfs_dir)) + return; + + /* For non-MLO, leave the files in the main directory. */ + if (link_sta->sta->sta.valid_links) { + char link_dir_name[10]; + + snprintf(link_dir_name, sizeof(link_dir_name), + "link-%d", link_sta->link_id); + + link_sta->debugfs_dir = + debugfs_create_dir(link_dir_name, + link_sta->sta->debugfs_dir); + + DEBUGFS_ADD(addr); + } else { + if (WARN_ON(link_sta != &link_sta->sta->deflink)) + return; + + link_sta->debugfs_dir = link_sta->sta->debugfs_dir; + } + + DEBUGFS_ADD(ht_capa); + DEBUGFS_ADD(vht_capa); + DEBUGFS_ADD(he_capa); + + DEBUGFS_ADD_COUNTER(rx_duplicates, rx_stats.num_duplicates); + DEBUGFS_ADD_COUNTER(rx_fragments, rx_stats.fragments); +} + +void ieee80211_link_sta_debugfs_remove(struct link_sta_info *link_sta) +{ + if (!link_sta->debugfs_dir || !link_sta->sta->debugfs_dir) { + link_sta->debugfs_dir = NULL; + return; + } + + if (link_sta->debugfs_dir == link_sta->sta->debugfs_dir) { + WARN_ON(link_sta != &link_sta->sta->deflink); + link_sta->sta->debugfs_dir = NULL; + return; + } + + debugfs_remove_recursive(link_sta->debugfs_dir); + link_sta->debugfs_dir = NULL; +} + +void ieee80211_link_sta_debugfs_drv_add(struct link_sta_info *link_sta) +{ + if (WARN_ON(!link_sta->debugfs_dir)) + return; + + drv_link_sta_add_debugfs(link_sta->sta->local, link_sta->sta->sdata, + link_sta->pub, link_sta->debugfs_dir); +} + +void ieee80211_link_sta_debugfs_drv_remove(struct link_sta_info *link_sta) +{ + if (!link_sta->debugfs_dir) + return; + + if (WARN_ON(link_sta->debugfs_dir == link_sta->sta->debugfs_dir)) + return; + + /* Recreate the directory excluding the driver data */ + debugfs_remove_recursive(link_sta->debugfs_dir); + link_sta->debugfs_dir = NULL; + + ieee80211_link_sta_debugfs_add(link_sta); +} diff --git a/net/mac80211/debugfs_sta.h b/net/mac80211/debugfs_sta.h index d2e7c27ad6d1..cde8148bdb18 100644 --- a/net/mac80211/debugfs_sta.h +++ b/net/mac80211/debugfs_sta.h @@ -7,9 +7,21 @@ #ifdef CONFIG_MAC80211_DEBUGFS void ieee80211_sta_debugfs_add(struct sta_info *sta); void ieee80211_sta_debugfs_remove(struct sta_info *sta); + +void ieee80211_link_sta_debugfs_add(struct link_sta_info *link_sta); +void ieee80211_link_sta_debugfs_remove(struct link_sta_info *link_sta); + +void ieee80211_link_sta_debugfs_drv_add(struct link_sta_info *link_sta); +void ieee80211_link_sta_debugfs_drv_remove(struct link_sta_info *link_sta); #else static inline void ieee80211_sta_debugfs_add(struct sta_info *sta) {} static inline void ieee80211_sta_debugfs_remove(struct sta_info *sta) {} + +static inline void ieee80211_link_sta_debugfs_add(struct link_sta_info *link_sta) {} +static inline void ieee80211_link_sta_debugfs_remove(struct link_sta_info *link_sta) {} + +static inline void ieee80211_link_sta_debugfs_drv_add(struct link_sta_info *link_sta) {} +static inline void ieee80211_link_sta_debugfs_drv_remove(struct link_sta_info *link_sta) {} #endif #endif /* __MAC80211_DEBUGFS_STA_H */ diff --git a/net/mac80211/driver-ops.c b/net/mac80211/driver-ops.c index 5392ffa18270..cfb09e4aed4d 100644 --- a/net/mac80211/driver-ops.c +++ b/net/mac80211/driver-ops.c @@ -7,6 +7,7 @@ #include "ieee80211_i.h" #include "trace.h" #include "driver-ops.h" +#include "debugfs_sta.h" int drv_start(struct ieee80211_local *local) { @@ -391,6 +392,9 @@ int drv_ampdu_action(struct ieee80211_local *local, might_sleep(); + if (!sdata) + return -EIO; + sdata = get_bss_sdata(sdata); if (!check_sdata_in_driver(sdata)) return -EIO; @@ -497,6 +501,11 @@ int drv_change_sta_links(struct ieee80211_local *local, struct ieee80211_sta *sta, u16 old_links, u16 new_links) { + struct sta_info *info = container_of(sta, struct sta_info, sta); + struct link_sta_info *link_sta; + unsigned long links_to_add; + unsigned long links_to_rem; + unsigned int link_id; int ret = -EOPNOTSUPP; might_sleep(); @@ -510,11 +519,30 @@ int drv_change_sta_links(struct ieee80211_local *local, if (old_links == new_links) return 0; + links_to_add = ~old_links & new_links; + links_to_rem = old_links & ~new_links; + + for_each_set_bit(link_id, &links_to_rem, IEEE80211_MLD_MAX_NUM_LINKS) { + link_sta = rcu_dereference_protected(info->link[link_id], + lockdep_is_held(&local->sta_mtx)); + + ieee80211_link_sta_debugfs_drv_remove(link_sta); + } + trace_drv_change_sta_links(local, sdata, sta, old_links, new_links); if (local->ops->change_sta_links) ret = local->ops->change_sta_links(&local->hw, &sdata->vif, sta, old_links, new_links); trace_drv_return_int(local, ret); - return ret; + if (ret) + return ret; + + for_each_set_bit(link_id, &links_to_add, IEEE80211_MLD_MAX_NUM_LINKS) { + link_sta = rcu_dereference_protected(info->link[link_id], + lockdep_is_held(&local->sta_mtx)); + ieee80211_link_sta_debugfs_drv_add(link_sta); + } + + return 0; } diff --git a/net/mac80211/driver-ops.h b/net/mac80211/driver-ops.h index 81e40b0a3b16..5d13a3dfd366 100644 --- a/net/mac80211/driver-ops.h +++ b/net/mac80211/driver-ops.h @@ -480,6 +480,22 @@ static inline void drv_sta_add_debugfs(struct ieee80211_local *local, local->ops->sta_add_debugfs(&local->hw, &sdata->vif, sta, dir); } + +static inline void drv_link_sta_add_debugfs(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_sta *link_sta, + struct dentry *dir) +{ + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + if (local->ops->link_sta_add_debugfs) + local->ops->link_sta_add_debugfs(&local->hw, &sdata->vif, + link_sta, dir); +} #endif static inline void drv_sta_pre_rcu_remove(struct ieee80211_local *local, @@ -1183,7 +1199,7 @@ static inline void drv_wake_tx_queue(struct ieee80211_local *local, /* In reconfig don't transmit now, but mark for waking later */ if (local->in_reconfig) { - set_bit(IEEE80211_TXQ_STOP_NETIF_TX, &txq->flags); + set_bit(IEEE80211_TXQ_DIRTY, &txq->flags); return; } diff --git a/net/mac80211/ht.c b/net/mac80211/ht.c index 83bc41346ae7..5315ab750280 100644 --- a/net/mac80211/ht.c +++ b/net/mac80211/ht.c @@ -391,6 +391,37 @@ void ieee80211_ba_session_work(struct work_struct *work) tid_tx = sta->ampdu_mlme.tid_start_tx[tid]; if (!blocked && tid_tx) { + struct txq_info *txqi = to_txq_info(sta->sta.txq[tid]); + struct ieee80211_sub_if_data *sdata = + vif_to_sdata(txqi->txq.vif); + struct fq *fq = &sdata->local->fq; + + spin_lock_bh(&fq->lock); + + /* Allow only frags to be dequeued */ + set_bit(IEEE80211_TXQ_STOP, &txqi->flags); + + if (!skb_queue_empty(&txqi->frags)) { + /* Fragmented Tx is ongoing, wait for it to + * finish. Reschedule worker to retry later. + */ + + spin_unlock_bh(&fq->lock); + spin_unlock_bh(&sta->lock); + + /* Give the task working on the txq a chance + * to send out the queued frags + */ + synchronize_net(); + + mutex_unlock(&sta->ampdu_mlme.mtx); + + ieee80211_queue_work(&sdata->local->hw, work); + return; + } + + spin_unlock_bh(&fq->lock); + /* * Assign it over to the normal tid_tx array * where it "goes live". diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h index a842f2e1c230..d16606e84e22 100644 --- a/net/mac80211/ieee80211_i.h +++ b/net/mac80211/ieee80211_i.h @@ -390,6 +390,7 @@ struct ieee80211_mgd_auth_data { bool done, waiting; bool peer_confirmed; bool timeout_started; + int link_id; u8 ap_addr[ETH_ALEN] __aligned(2); @@ -412,6 +413,8 @@ struct ieee80211_mgd_assoc_data { u8 *elems; /* pointing to inside ie[] below */ ieee80211_conn_flags_t conn_flags; + + u16 status; } link[IEEE80211_MLD_MAX_NUM_LINKS]; u8 ap_addr[ETH_ALEN] __aligned(2); @@ -835,7 +838,7 @@ enum txq_info_flags { IEEE80211_TXQ_STOP, IEEE80211_TXQ_AMPDU, IEEE80211_TXQ_NO_AMSDU, - IEEE80211_TXQ_STOP_NETIF_TX, + IEEE80211_TXQ_DIRTY, }; /** @@ -1707,6 +1710,17 @@ struct ieee802_11_elems { u8 tx_pwr_env_num; u8 eht_cap_len; + /* mult-link element can be de-fragmented and thus u8 is not sufficient */ + size_t multi_link_len; + + /* + * store the per station profile pointer and length in case that the + * parsing also handled Multi-Link element parsing for a specific link + * ID. + */ + struct ieee80211_mle_per_sta_profile *prof; + size_t sta_prof_len; + /* whether a parse error occurred while retrieving these elements */ bool parse_error; @@ -2205,9 +2219,13 @@ static inline void ieee80211_tx_skb(struct ieee80211_sub_if_data *sdata, * represent a non-transmitting BSS in which case the data * for that non-transmitting BSS is returned * @link_id: the link ID to parse elements for, if a STA profile - * is present in the multi-link element, or -1 to ignore + * is present in the multi-link element, or -1 to ignore; + * note that the code currently assumes parsing an association + * (or re-association) response frame if this is given * @from_ap: frame is received from an AP (currently used only * for EHT capabilities parsing) + * @scratch_len: if non zero, specifies the requested length of the scratch + * buffer; otherwise, 'len' is used. */ struct ieee80211_elems_parse_params { const u8 *start; @@ -2218,6 +2236,7 @@ struct ieee80211_elems_parse_params { struct cfg80211_bss *bss; int link_id; bool from_ap; + size_t scratch_len; }; struct ieee802_11_elems * @@ -2288,7 +2307,6 @@ void ieee80211_wake_queue_by_reason(struct ieee80211_hw *hw, int queue, void ieee80211_stop_queue_by_reason(struct ieee80211_hw *hw, int queue, enum queue_stop_reason reason, bool refcounted); -void ieee80211_propagate_queue_wake(struct ieee80211_local *local, int queue); void ieee80211_add_pending_skb(struct ieee80211_local *local, struct sk_buff *skb); void ieee80211_add_pending_skbs(struct ieee80211_local *local, diff --git a/net/mac80211/iface.c b/net/mac80211/iface.c index dd9ac1f7d2ea..23ed13f15067 100644 --- a/net/mac80211/iface.c +++ b/net/mac80211/iface.c @@ -364,7 +364,9 @@ static int ieee80211_check_concurrent_iface(struct ieee80211_sub_if_data *sdata, /* No support for VLAN with MLO yet */ if (iftype == NL80211_IFTYPE_AP_VLAN && - nsdata->wdev.use_4addr) + sdata->wdev.use_4addr && + nsdata->vif.type == NL80211_IFTYPE_AP && + nsdata->vif.valid_links) return -EOPNOTSUPP; /* @@ -458,12 +460,6 @@ static void ieee80211_do_stop(struct ieee80211_sub_if_data *sdata, bool going_do if (cancel_scan) ieee80211_scan_cancel(local); - /* - * Stop TX on this interface first. - */ - if (!local->ops->wake_tx_queue && sdata->dev) - netif_tx_stop_all_queues(sdata->dev); - ieee80211_roc_purge(local, sdata); switch (sdata->vif.type) { @@ -811,13 +807,6 @@ static void ieee80211_uninit(struct net_device *dev) ieee80211_teardown_sdata(IEEE80211_DEV_TO_SUB_IF(dev)); } -static u16 ieee80211_netdev_select_queue(struct net_device *dev, - struct sk_buff *skb, - struct net_device *sb_dev) -{ - return ieee80211_select_queue(IEEE80211_DEV_TO_SUB_IF(dev), skb); -} - static void ieee80211_get_stats64(struct net_device *dev, struct rtnl_link_stats64 *stats) { @@ -831,7 +820,6 @@ static const struct net_device_ops ieee80211_dataif_ops = { .ndo_start_xmit = ieee80211_subif_start_xmit, .ndo_set_rx_mode = ieee80211_set_multicast_list, .ndo_set_mac_address = ieee80211_change_mac, - .ndo_select_queue = ieee80211_netdev_select_queue, .ndo_get_stats64 = ieee80211_get_stats64, }; @@ -939,7 +927,6 @@ static const struct net_device_ops ieee80211_dataif_8023_ops = { .ndo_start_xmit = ieee80211_subif_start_xmit_8023, .ndo_set_rx_mode = ieee80211_set_multicast_list, .ndo_set_mac_address = ieee80211_change_mac, - .ndo_select_queue = ieee80211_netdev_select_queue, .ndo_get_stats64 = ieee80211_get_stats64, .ndo_fill_forward_path = ieee80211_netdev_fill_forward_path, }; @@ -1441,35 +1428,6 @@ int ieee80211_do_open(struct wireless_dev *wdev, bool coming_up) ieee80211_recalc_ps(local); - if (sdata->vif.type == NL80211_IFTYPE_MONITOR || - sdata->vif.type == NL80211_IFTYPE_AP_VLAN || - local->ops->wake_tx_queue) { - /* XXX: for AP_VLAN, actually track AP queues */ - if (dev) - netif_tx_start_all_queues(dev); - } else if (dev) { - unsigned long flags; - int n_acs = IEEE80211_NUM_ACS; - int ac; - - if (local->hw.queues < IEEE80211_NUM_ACS) - n_acs = 1; - - spin_lock_irqsave(&local->queue_stop_reason_lock, flags); - if (sdata->vif.cab_queue == IEEE80211_INVAL_HW_QUEUE || - (local->queue_stop_reasons[sdata->vif.cab_queue] == 0 && - skb_queue_empty(&local->pending[sdata->vif.cab_queue]))) { - for (ac = 0; ac < n_acs; ac++) { - int ac_queue = sdata->vif.hw_queue[ac]; - - if (local->queue_stop_reasons[ac_queue] == 0 && - skb_queue_empty(&local->pending[ac_queue])) - netif_start_subqueue(dev, ac); - } - } - spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); - } - set_bit(SDATA_STATE_RUNNING, &sdata->state); return 0; @@ -1499,17 +1457,12 @@ static void ieee80211_if_setup(struct net_device *dev) { ether_setup(dev); dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->priv_flags |= IFF_NO_QUEUE; dev->netdev_ops = &ieee80211_dataif_ops; dev->needs_free_netdev = true; dev->priv_destructor = ieee80211_if_free; } -static void ieee80211_if_setup_no_queue(struct net_device *dev) -{ - ieee80211_if_setup(dev); - dev->priv_flags |= IFF_NO_QUEUE; -} - static void ieee80211_iface_process_skb(struct ieee80211_local *local, struct ieee80211_sub_if_data *sdata, struct sk_buff *skb) @@ -1898,8 +1851,7 @@ static int ieee80211_runtime_change_iftype(struct ieee80211_sub_if_data *sdata, ieee80211_stop_vif_queues(local, sdata, IEEE80211_QUEUE_STOP_REASON_IFTYPE_CHANGE); - synchronize_net(); - + /* do_stop will synchronize_rcu() first thing */ ieee80211_do_stop(sdata, false); ieee80211_teardown_sdata(sdata); @@ -2094,9 +2046,7 @@ int ieee80211_if_add(struct ieee80211_local *local, const char *name, struct net_device *ndev = NULL; struct ieee80211_sub_if_data *sdata = NULL; struct txq_info *txqi; - void (*if_setup)(struct net_device *dev); int ret, i; - int txqs = 1; ASSERT_RTNL(); @@ -2119,30 +2069,18 @@ int ieee80211_if_add(struct ieee80211_local *local, const char *name, sizeof(void *)); int txq_size = 0; - if (local->ops->wake_tx_queue && - type != NL80211_IFTYPE_AP_VLAN && + if (type != NL80211_IFTYPE_AP_VLAN && (type != NL80211_IFTYPE_MONITOR || (params->flags & MONITOR_FLAG_ACTIVE))) txq_size += sizeof(struct txq_info) + local->hw.txq_data_size; - if (local->ops->wake_tx_queue) { - if_setup = ieee80211_if_setup_no_queue; - } else { - if_setup = ieee80211_if_setup; - if (local->hw.queues >= IEEE80211_NUM_ACS) - txqs = IEEE80211_NUM_ACS; - } - ndev = alloc_netdev_mqs(size + txq_size, name, name_assign_type, - if_setup, txqs, 1); + ieee80211_if_setup, 1, 1); if (!ndev) return -ENOMEM; - if (!local->ops->wake_tx_queue && local->hw.wiphy->tx_queue_len) - ndev->tx_queue_len = local->hw.wiphy->tx_queue_len; - dev_net_set(ndev, wiphy_net(local->hw.wiphy)); ndev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); @@ -2242,6 +2180,7 @@ int ieee80211_if_add(struct ieee80211_local *local, const char *name, ndev->priv_flags |= IFF_LIVE_ADDR_CHANGE; ndev->hw_features |= ndev->features & MAC80211_SUPPORTED_FEATURES_TX; + sdata->vif.netdev_features = local->hw.netdev_features; netdev_set_default_ethtool_ops(ndev, &ieee80211_ethtool_ops); diff --git a/net/mac80211/link.c b/net/mac80211/link.c index e309708abae8..d1f5a9f7c647 100644 --- a/net/mac80211/link.c +++ b/net/mac80211/link.c @@ -357,6 +357,11 @@ static int _ieee80211_set_active_links(struct ieee80211_sub_if_data *sdata, list_for_each_entry(sta, &local->sta_list, list) { if (sdata != sta->sdata) continue; + + /* this is very temporary, but do it anyway */ + __ieee80211_sta_recalc_aggregates(sta, + old_active | active_links); + ret = drv_change_sta_links(local, sdata, &sta->sta, old_active, old_active | active_links); @@ -369,10 +374,22 @@ static int _ieee80211_set_active_links(struct ieee80211_sub_if_data *sdata, list_for_each_entry(sta, &local->sta_list, list) { if (sdata != sta->sdata) continue; + + __ieee80211_sta_recalc_aggregates(sta, active_links); + ret = drv_change_sta_links(local, sdata, &sta->sta, old_active | active_links, active_links); WARN_ON_ONCE(ret); + + /* + * Do it again, just in case - the driver might very + * well have called ieee80211_sta_recalc_aggregates() + * from there when filling in the new links, which + * would set it wrong since the vif's active links are + * not switched yet... + */ + __ieee80211_sta_recalc_aggregates(sta, active_links); } for_each_set_bit(link_id, &add, IEEE80211_MLD_MAX_NUM_LINKS) { diff --git a/net/mac80211/main.c b/net/mac80211/main.c index 02b5abc7326b..846528850612 100644 --- a/net/mac80211/main.c +++ b/net/mac80211/main.c @@ -630,7 +630,7 @@ struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, if (WARN_ON(!ops->tx || !ops->start || !ops->stop || !ops->config || !ops->add_interface || !ops->remove_interface || - !ops->configure_filter)) + !ops->configure_filter || !ops->wake_tx_queue)) return NULL; if (WARN_ON(ops->sta_state && (ops->sta_add || ops->sta_remove))) @@ -719,9 +719,7 @@ struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, if (!ops->set_key) wiphy->flags |= WIPHY_FLAG_IBSS_RSN; - if (ops->wake_tx_queue) - wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_TXQS); - + wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_TXQS); wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_RRM); wiphy->bss_priv_size = sizeof(struct ieee80211_bss); @@ -834,10 +832,7 @@ struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, atomic_set(&local->agg_queue_stop[i], 0); } tasklet_setup(&local->tx_pending_tasklet, ieee80211_tx_pending); - - if (ops->wake_tx_queue) - tasklet_setup(&local->wake_txqs_tasklet, ieee80211_wake_txqs); - + tasklet_setup(&local->wake_txqs_tasklet, ieee80211_wake_txqs); tasklet_setup(&local->tasklet, ieee80211_tasklet_handler); skb_queue_head_init(&local->skb_queue); @@ -1087,6 +1082,16 @@ int ieee80211_register_hw(struct ieee80211_hw *hw) channels += sband->n_channels; + /* + * Due to the way the aggregation code handles this and it + * being an HT capability, we can't really support delayed + * BA in MLO (yet). + */ + if (WARN_ON(sband->ht_cap.ht_supported && + (sband->ht_cap.cap & IEEE80211_HT_CAP_DELAY_BA) && + hw->wiphy->flags & WIPHY_FLAG_SUPPORTS_MLO)) + return -EINVAL; + if (max_bitrates < sband->n_bitrates) max_bitrates = sband->n_bitrates; supp_ht = supp_ht || sband->ht_cap.ht_supported; @@ -1155,6 +1160,8 @@ int ieee80211_register_hw(struct ieee80211_hw *hw) if (!local->int_scan_req) return -ENOMEM; + eth_broadcast_addr(local->int_scan_req->bssid); + for (band = 0; band < NUM_NL80211_BANDS; band++) { if (!local->hw.wiphy->bands[band]) continue; diff --git a/net/mac80211/mesh_pathtbl.c b/net/mac80211/mesh_pathtbl.c index 69d5e1ec6ede..3b81e6df3f34 100644 --- a/net/mac80211/mesh_pathtbl.c +++ b/net/mac80211/mesh_pathtbl.c @@ -512,7 +512,7 @@ static void mesh_path_free_rcu(struct mesh_table *tbl, mpath->flags |= MESH_PATH_RESOLVING | MESH_PATH_DELETED; mesh_gate_del(tbl, mpath); spin_unlock_bh(&mpath->state_lock); - del_timer_sync(&mpath->timer); + timer_shutdown_sync(&mpath->timer); atomic_dec(&sdata->u.mesh.mpaths); atomic_dec(&tbl->entries); mesh_path_flush_pending(mpath); diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c index d8484cd870de..0aee2392dd29 100644 --- a/net/mac80211/mlme.c +++ b/net/mac80211/mlme.c @@ -2717,18 +2717,10 @@ static u32 ieee80211_link_set_associated(struct ieee80211_link_data *link, } if (link->u.mgd.have_beacon) { - /* - * If the AP is buggy we may get here with no DTIM period - * known, so assume it's 1 which is the only safe assumption - * in that case, although if the TIM IE is broken powersave - * probably just won't work at all. - */ - bss_conf->dtim_period = link->u.mgd.dtim_period ?: 1; bss_conf->beacon_rate = bss->beacon_rate; changed |= BSS_CHANGED_BEACON_INFO; } else { bss_conf->beacon_rate = NULL; - bss_conf->dtim_period = 0; } /* Tell the driver to monitor connection quality (if supported) */ @@ -2754,7 +2746,8 @@ static void ieee80211_set_associated(struct ieee80211_sub_if_data *sdata, struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; struct ieee80211_link_data *link; - if (!cbss) + if (!cbss || + assoc_data->link[link_id].status != WLAN_STATUS_SUCCESS) continue; link = sdata_dereference(sdata->link[link_id], sdata); @@ -2782,7 +2775,8 @@ static void ieee80211_set_associated(struct ieee80211_sub_if_data *sdata, struct ieee80211_link_data *link; struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; - if (!cbss) + if (!cbss || + assoc_data->link[link_id].status != WLAN_STATUS_SUCCESS) continue; link = sdata_dereference(sdata->link[link_id], sdata); @@ -3868,9 +3862,15 @@ static void ieee80211_get_rates(struct ieee80211_supported_band *sband, } } -static bool ieee80211_twt_req_supported(const struct link_sta_info *link_sta, +static bool ieee80211_twt_req_supported(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct link_sta_info *link_sta, const struct ieee802_11_elems *elems) { + const struct ieee80211_sta_he_cap *own_he_cap = + ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + if (elems->ext_capab_len < 10) return false; @@ -3878,14 +3878,19 @@ static bool ieee80211_twt_req_supported(const struct link_sta_info *link_sta, return false; return link_sta->pub->he_cap.he_cap_elem.mac_cap_info[0] & - IEEE80211_HE_MAC_CAP0_TWT_RES; + IEEE80211_HE_MAC_CAP0_TWT_RES && + own_he_cap && + (own_he_cap->he_cap_elem.mac_cap_info[0] & + IEEE80211_HE_MAC_CAP0_TWT_REQ); } -static int ieee80211_recalc_twt_req(struct ieee80211_link_data *link, +static int ieee80211_recalc_twt_req(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + struct ieee80211_link_data *link, struct link_sta_info *link_sta, struct ieee802_11_elems *elems) { - bool twt = ieee80211_twt_req_supported(link_sta, elems); + bool twt = ieee80211_twt_req_supported(sdata, sband, link_sta, elems); if (link->conf->twt_requester != twt) { link->conf->twt_requester = twt; @@ -3923,11 +3928,11 @@ static bool ieee80211_assoc_config_link(struct ieee80211_link_data *link, struct ieee80211_mgd_assoc_data *assoc_data = sdata->u.mgd.assoc_data; struct ieee80211_bss_conf *bss_conf = link->conf; struct ieee80211_local *local = sdata->local; + unsigned int link_id = link->link_id; struct ieee80211_elems_parse_params parse_params = { .start = elem_start, .len = elem_len, - .bss = cbss, - .link_id = link == &sdata->deflink ? -1 : link->link_id, + .link_id = link_id == assoc_data->assoc_link_id ? -1 : link_id, .from_ap = true, }; bool is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ; @@ -3942,8 +3947,35 @@ static bool ieee80211_assoc_config_link(struct ieee80211_link_data *link, if (!elems) return false; - /* FIXME: use from STA profile element after parsing that */ - capab_info = le16_to_cpu(mgmt->u.assoc_resp.capab_info); + if (link_id == assoc_data->assoc_link_id) { + capab_info = le16_to_cpu(mgmt->u.assoc_resp.capab_info); + + /* + * we should not get to this flow unless the association was + * successful, so set the status directly to success + */ + assoc_data->link[link_id].status = WLAN_STATUS_SUCCESS; + } else if (!elems->prof) { + ret = false; + goto out; + } else { + const u8 *ptr = elems->prof->variable + + elems->prof->sta_info_len - 1; + + /* + * During parsing, we validated that these fields exist, + * otherwise elems->prof would have been set to NULL. + */ + capab_info = get_unaligned_le16(ptr); + assoc_data->link[link_id].status = get_unaligned_le16(ptr + 2); + + if (assoc_data->link[link_id].status != WLAN_STATUS_SUCCESS) { + link_info(link, "association response status code=%u\n", + assoc_data->link[link_id].status); + ret = true; + goto out; + } + } if (!is_s1g && !elems->supp_rates) { sdata_info(sdata, "no SuppRates element in AssocResp\n"); @@ -3984,6 +4016,7 @@ static bool ieee80211_assoc_config_link(struct ieee80211_link_data *link, parse_params.start = bss_ies->data; parse_params.len = bss_ies->len; + parse_params.bss = cbss; bss_elems = ieee802_11_parse_elems_full(&parse_params); if (!bss_elems) { ret = false; @@ -4099,7 +4132,8 @@ static bool ieee80211_assoc_config_link(struct ieee80211_link_data *link, else bss_conf->twt_protected = false; - *changed |= ieee80211_recalc_twt_req(link, link_sta, elems); + *changed |= ieee80211_recalc_twt_req(sdata, sband, link, + link_sta, elems); if (elems->eht_operation && elems->eht_cap && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_EHT)) { @@ -4864,6 +4898,7 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, unsigned int link_id; struct sta_info *sta; u64 changed[IEEE80211_MLD_MAX_NUM_LINKS] = {}; + u16 valid_links = 0; int err; mutex_lock(&sdata->local->sta_mtx); @@ -4876,8 +4911,6 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, goto out_err; if (sdata->vif.valid_links) { - u16 valid_links = 0; - for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { if (!assoc_data->link[link_id].bss) continue; @@ -4894,10 +4927,11 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, } for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; struct ieee80211_link_data *link; struct link_sta_info *link_sta; - if (!assoc_data->link[link_id].bss) + if (!cbss) continue; link = sdata_dereference(sdata->link[link_id], sdata); @@ -4906,28 +4940,36 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, if (sdata->vif.valid_links) link_info(link, - "local address %pM, AP link address %pM\n", + "local address %pM, AP link address %pM%s\n", link->conf->addr, - assoc_data->link[link_id].bss->bssid); + assoc_data->link[link_id].bss->bssid, + link_id == assoc_data->assoc_link_id ? + " (assoc)" : ""); link_sta = rcu_dereference_protected(sta->link[link_id], lockdep_is_held(&local->sta_mtx)); if (WARN_ON(!link_sta)) goto out_err; - if (link_id != assoc_data->assoc_link_id) { - struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + if (!link->u.mgd.have_beacon) { const struct cfg80211_bss_ies *ies; rcu_read_lock(); - ies = rcu_dereference(cbss->ies); + ies = rcu_dereference(cbss->beacon_ies); + if (ies) + link->u.mgd.have_beacon = true; + else + ies = rcu_dereference(cbss->ies); ieee80211_get_dtim(ies, &link->conf->sync_dtim_count, &link->u.mgd.dtim_period); - link->conf->dtim_period = link->u.mgd.dtim_period ?: 1; link->conf->beacon_int = cbss->beacon_interval; rcu_read_unlock(); + } + + link->conf->dtim_period = link->u.mgd.dtim_period ?: 1; + if (link_id != assoc_data->assoc_link_id) { err = ieee80211_prep_channel(sdata, link, cbss, &link->u.mgd.conn_flags); if (err) { @@ -4947,6 +4989,12 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, &changed[link_id])) goto out_err; + if (assoc_data->link[link_id].status != WLAN_STATUS_SUCCESS) { + valid_links &= ~BIT(link_id); + ieee80211_sta_remove_link(sta, link_id); + continue; + } + if (link_id != assoc_data->assoc_link_id) { err = ieee80211_sta_activate_link(sta, link_id); if (err) @@ -4954,6 +5002,9 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, } } + /* links might have changed due to rejected ones, set them again */ + ieee80211_vif_set_links(sdata, valid_links); + rate_control_rate_init(sta); if (ifmgd->flags & IEEE80211_STA_MFP_ENABLED) { @@ -5033,6 +5084,7 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, struct cfg80211_rx_assoc_resp resp = { .uapsd_queues = -1, }; + u8 ap_mld_addr[ETH_ALEN] __aligned(2); unsigned int link_id; sdata_assert_lock(sdata); @@ -5187,10 +5239,13 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, link = sdata_dereference(sdata->link[link_id], sdata); if (!link) continue; + if (!assoc_data->link[link_id].bss) continue; + resp.links[link_id].bss = assoc_data->link[link_id].bss; resp.links[link_id].addr = link->conf->addr; + resp.links[link_id].status = assoc_data->link[link_id].status; /* get uapsd queues configuration - same for all links */ resp.uapsd_queues = 0; @@ -5199,6 +5254,11 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, resp.uapsd_queues |= ieee80211_ac_to_qos_mask[ac]; } + if (sdata->vif.valid_links) { + ether_addr_copy(ap_mld_addr, sdata->vif.cfg.ap_addr); + resp.ap_mld_addr = ap_mld_addr; + } + ieee80211_destroy_assoc_data(sdata, status_code == WLAN_STATUS_SUCCESS ? ASSOC_SUCCESS : @@ -5208,8 +5268,6 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, resp.len = len; resp.req_ies = ifmgd->assoc_req_ies; resp.req_ies_len = ifmgd->assoc_req_ies_len; - if (sdata->vif.valid_links) - resp.ap_mld_addr = sdata->vif.cfg.ap_addr; cfg80211_rx_assoc_resp(sdata->dev, &resp); notify_driver: drv_mgd_complete_tx(sdata->local, sdata, &info); @@ -5432,6 +5490,7 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_link_data *link, struct ieee802_11_elems *elems; struct ieee80211_local *local = sdata->local; struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_supported_band *sband; struct ieee80211_channel *chan; struct link_sta_info *link_sta; struct sta_info *sta; @@ -5694,7 +5753,12 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_link_data *link, goto free; } - changed |= ieee80211_recalc_twt_req(link, link_sta, elems); + if (WARN_ON(!link->conf->chandef.chan)) + goto free; + + sband = local->hw.wiphy->bands[link->conf->chandef.chan->band]; + + changed |= ieee80211_recalc_twt_req(sdata, sband, link, link_sta, elems); if (ieee80211_config_bw(link, elems->ht_cap_elem, elems->vht_cap_elem, elems->ht_operation, @@ -6640,6 +6704,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, req->ap_mld_addr ?: req->bss->bssid, ETH_ALEN); auth_data->bss = req->bss; + auth_data->link_id = req->link_id; if (req->auth_data_len >= 4) { if (req->auth_type == NL80211_AUTHTYPE_SAE) { @@ -6658,7 +6723,8 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, * removal and re-addition of the STA entry in * ieee80211_prep_connection(). */ - cont_auth = ifmgd->auth_data && req->bss == ifmgd->auth_data->bss; + cont_auth = ifmgd->auth_data && req->bss == ifmgd->auth_data->bss && + ifmgd->auth_data->link_id == req->link_id; if (req->ie && req->ie_len) { memcpy(&auth_data->data[auth_data->data_len], @@ -6982,7 +7048,8 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata, /* keep sta info, bssid if matching */ match = ether_addr_equal(ifmgd->auth_data->ap_addr, - assoc_data->ap_addr); + assoc_data->ap_addr) && + ifmgd->auth_data->link_id == req->link_id; ieee80211_destroy_auth_data(sdata, match); } diff --git a/net/mac80211/rc80211_minstrel_ht.c b/net/mac80211/rc80211_minstrel_ht.c index 3d91b98db099..762346598338 100644 --- a/net/mac80211/rc80211_minstrel_ht.c +++ b/net/mac80211/rc80211_minstrel_ht.c @@ -1963,9 +1963,6 @@ minstrel_ht_alloc(struct ieee80211_hw *hw) /* safe default, does not necessarily have to match hw properties */ mp->max_retry = 7; - if (hw->max_rates >= 4) - mp->has_mrr = true; - mp->hw = hw; mp->update_interval = HZ / 20; diff --git a/net/mac80211/rc80211_minstrel_ht.h b/net/mac80211/rc80211_minstrel_ht.h index 1766ff0c78d3..4be0401f7721 100644 --- a/net/mac80211/rc80211_minstrel_ht.h +++ b/net/mac80211/rc80211_minstrel_ht.h @@ -74,7 +74,6 @@ struct minstrel_priv { struct ieee80211_hw *hw; - bool has_mrr; unsigned int cw_min; unsigned int cw_max; unsigned int max_retry; diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c index f99416d2e144..c6562a6d2503 100644 --- a/net/mac80211/rx.c +++ b/net/mac80211/rx.c @@ -1571,9 +1571,6 @@ static void sta_ps_start(struct sta_info *sta) ieee80211_clear_fast_xmit(sta); - if (!sta->sta.txq[0]) - return; - for (tid = 0; tid < IEEE80211_NUM_TIDS; tid++) { struct ieee80211_txq *txq = sta->sta.txq[tid]; struct txq_info *txqi = to_txq_info(txq); @@ -2406,7 +2403,6 @@ static int ieee80211_802_1x_port_control(struct ieee80211_rx_data *rx) static int ieee80211_drop_unencrypted(struct ieee80211_rx_data *rx, __le16 fc) { - struct ieee80211_hdr *hdr = (void *)rx->skb->data; struct sk_buff *skb = rx->skb; struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); @@ -2417,31 +2413,6 @@ static int ieee80211_drop_unencrypted(struct ieee80211_rx_data *rx, __le16 fc) if (status->flag & RX_FLAG_DECRYPTED) return 0; - /* check mesh EAPOL frames first */ - if (unlikely(rx->sta && ieee80211_vif_is_mesh(&rx->sdata->vif) && - ieee80211_is_data(fc))) { - struct ieee80211s_hdr *mesh_hdr; - u16 hdr_len = ieee80211_hdrlen(fc); - u16 ethertype_offset; - __be16 ethertype; - - if (!ether_addr_equal(hdr->addr1, rx->sdata->vif.addr)) - goto drop_check; - - /* make sure fixed part of mesh header is there, also checks skb len */ - if (!pskb_may_pull(rx->skb, hdr_len + 6)) - goto drop_check; - - mesh_hdr = (struct ieee80211s_hdr *)(skb->data + hdr_len); - ethertype_offset = hdr_len + ieee80211_get_mesh_hdrlen(mesh_hdr) + - sizeof(rfc1042_header); - - if (skb_copy_bits(rx->skb, ethertype_offset, ðertype, 2) == 0 && - ethertype == rx->sdata->control_port_protocol) - return 0; - } - -drop_check: /* Drop unencrypted frames if key is set. */ if (unlikely(!ieee80211_has_protected(fc) && !ieee80211_is_any_nullfunc(fc) && @@ -2895,8 +2866,16 @@ ieee80211_rx_h_mesh_fwding(struct ieee80211_rx_data *rx) hdr = (struct ieee80211_hdr *) skb->data; mesh_hdr = (struct ieee80211s_hdr *) (skb->data + hdrlen); - if (ieee80211_drop_unencrypted(rx, hdr->frame_control)) - return RX_DROP_MONITOR; + if (ieee80211_drop_unencrypted(rx, hdr->frame_control)) { + int offset = hdrlen + ieee80211_get_mesh_hdrlen(mesh_hdr) + + sizeof(rfc1042_header); + __be16 ethertype; + + if (!ether_addr_equal(hdr->addr1, rx->sdata->vif.addr) || + skb_copy_bits(rx->skb, offset, ðertype, 2) != 0 || + ethertype != rx->sdata->control_port_protocol) + return RX_DROP_MONITOR; + } /* frame is in RMC, don't forward */ if (ieee80211_is_data(hdr->frame_control) && @@ -4070,6 +4049,58 @@ static void ieee80211_invoke_rx_handlers(struct ieee80211_rx_data *rx) #undef CALL_RXH } +static bool +ieee80211_rx_is_valid_sta_link_id(struct ieee80211_sta *sta, u8 link_id) +{ + if (!sta->mlo) + return false; + + return !!(sta->valid_links & BIT(link_id)); +} + +static bool ieee80211_rx_data_set_link(struct ieee80211_rx_data *rx, + u8 link_id) +{ + rx->link_id = link_id; + rx->link = rcu_dereference(rx->sdata->link[link_id]); + + if (!rx->sta) + return rx->link; + + if (!ieee80211_rx_is_valid_sta_link_id(&rx->sta->sta, link_id)) + return false; + + rx->link_sta = rcu_dereference(rx->sta->link[link_id]); + + return rx->link && rx->link_sta; +} + +static bool ieee80211_rx_data_set_sta(struct ieee80211_rx_data *rx, + struct ieee80211_sta *pubsta, + int link_id) +{ + struct sta_info *sta; + + sta = container_of(pubsta, struct sta_info, sta); + + rx->link_id = link_id; + rx->sta = sta; + + if (sta) { + rx->local = sta->sdata->local; + if (!rx->sdata) + rx->sdata = sta->sdata; + rx->link_sta = &sta->deflink; + } + + if (link_id < 0) + rx->link = &rx->sdata->deflink; + else if (!ieee80211_rx_data_set_link(rx, link_id)) + return false; + + return true; +} + /* * This function makes calls into the RX path, therefore * it has to be invoked under RCU read lock. @@ -4078,16 +4109,19 @@ void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid) { struct sk_buff_head frames; struct ieee80211_rx_data rx = { - .sta = sta, - .sdata = sta->sdata, - .local = sta->local, /* This is OK -- must be QoS data frame */ .security_idx = tid, .seqno_idx = tid, - .link_id = -1, }; struct tid_ampdu_rx *tid_agg_rx; - u8 link_id; + int link_id = -1; + + /* FIXME: statistics won't be right with this */ + if (sta->sta.valid_links) + link_id = ffs(sta->sta.valid_links) - 1; + + if (!ieee80211_rx_data_set_sta(&rx, &sta->sta, link_id)) + return; tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); if (!tid_agg_rx) @@ -4107,10 +4141,6 @@ void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid) }; drv_event_callback(rx.local, rx.sdata, &event); } - /* FIXME: statistics won't be right with this */ - link_id = sta->sta.valid_links ? ffs(sta->sta.valid_links) - 1 : 0; - rx.link = rcu_dereference(sta->sdata->link[link_id]); - rx.link_sta = rcu_dereference(sta->link[link_id]); ieee80211_rx_handlers(&rx, &frames); } @@ -4126,7 +4156,6 @@ void ieee80211_mark_rx_ba_filtered_frames(struct ieee80211_sta *pubsta, u8 tid, /* This is OK -- must be QoS data frame */ .security_idx = tid, .seqno_idx = tid, - .link_id = -1, }; int i, diff; @@ -4137,10 +4166,8 @@ void ieee80211_mark_rx_ba_filtered_frames(struct ieee80211_sta *pubsta, u8 tid, sta = container_of(pubsta, struct sta_info, sta); - rx.sta = sta; - rx.sdata = sta->sdata; - rx.link = &rx.sdata->deflink; - rx.local = sta->local; + if (!ieee80211_rx_data_set_sta(&rx, pubsta, -1)) + return; rcu_read_lock(); tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); @@ -4527,15 +4554,6 @@ void ieee80211_check_fast_rx_iface(struct ieee80211_sub_if_data *sdata) mutex_unlock(&local->sta_mtx); } -static bool -ieee80211_rx_is_valid_sta_link_id(struct ieee80211_sta *sta, u8 link_id) -{ - if (!sta->mlo) - return false; - - return !!(sta->valid_links & BIT(link_id)); -} - static void ieee80211_rx_8023(struct ieee80211_rx_data *rx, struct ieee80211_fast_rx *fast_rx, int orig_len) @@ -4646,7 +4664,6 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx, struct sk_buff *skb = rx->skb; struct ieee80211_hdr *hdr = (void *)skb->data; struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); - struct sta_info *sta = rx->sta; int orig_len = skb->len; int hdrlen = ieee80211_hdrlen(hdr->frame_control); int snap_offs = hdrlen; @@ -4658,7 +4675,6 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx, u8 da[ETH_ALEN]; u8 sa[ETH_ALEN]; } addrs __aligned(2); - struct link_sta_info *link_sta; struct ieee80211_sta_rx_stats *stats; /* for parallel-rx, we need to have DUP_VALIDATED, otherwise we write @@ -4761,18 +4777,10 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx, drop: dev_kfree_skb(skb); - if (rx->link_id >= 0) { - link_sta = rcu_dereference(sta->link[rx->link_id]); - if (!link_sta) - return true; - } else { - link_sta = &sta->deflink; - } - if (fast_rx->uses_rss) - stats = this_cpu_ptr(link_sta->pcpu_rx_stats); + stats = this_cpu_ptr(rx->link_sta->pcpu_rx_stats); else - stats = &link_sta->rx_stats; + stats = &rx->link_sta->rx_stats; stats->dropped++; return true; @@ -4790,8 +4798,8 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx, struct ieee80211_local *local = rx->local; struct ieee80211_sub_if_data *sdata = rx->sdata; struct ieee80211_hdr *hdr = (void *)skb->data; - struct link_sta_info *link_sta = NULL; - struct ieee80211_link_data *link; + struct link_sta_info *link_sta = rx->link_sta; + struct ieee80211_link_data *link = rx->link; rx->skb = skb; @@ -4813,35 +4821,6 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx, if (!ieee80211_accept_frame(rx)) return false; - if (rx->link_id >= 0) { - link = rcu_dereference(rx->sdata->link[rx->link_id]); - - /* we might race link removal */ - if (!link) - return true; - rx->link = link; - - if (rx->sta) { - rx->link_sta = - rcu_dereference(rx->sta->link[rx->link_id]); - if (!rx->link_sta) - return true; - } - } else { - if (rx->sta) - rx->link_sta = &rx->sta->deflink; - - rx->link = &sdata->deflink; - } - - if (unlikely(!is_multicast_ether_addr(hdr->addr1) && - rx->link_id >= 0 && rx->sta && rx->sta->sta.mlo)) { - link_sta = rcu_dereference(rx->sta->link[rx->link_id]); - - if (WARN_ON_ONCE(!link_sta)) - return true; - } - if (!consume) { struct skb_shared_hwtstamps *shwt; @@ -4859,9 +4838,12 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx, */ shwt = skb_hwtstamps(rx->skb); shwt->hwtstamp = skb_hwtstamps(skb)->hwtstamp; + + /* Update the hdr pointer to the new skb for translation below */ + hdr = (struct ieee80211_hdr *)rx->skb->data; } - if (unlikely(link_sta)) { + if (unlikely(rx->sta && rx->sta->sta.mlo)) { /* translate to MLD addresses */ if (ether_addr_equal(link->conf->addr, hdr->addr1)) ether_addr_copy(hdr->addr1, rx->sdata->vif.addr); @@ -4891,6 +4873,7 @@ static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw, struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); struct ieee80211_fast_rx *fast_rx; struct ieee80211_rx_data rx; + int link_id = -1; memset(&rx, 0, sizeof(rx)); rx.skb = skb; @@ -4907,12 +4890,8 @@ static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw, if (!pubsta) goto drop; - rx.sta = container_of(pubsta, struct sta_info, sta); - rx.sdata = rx.sta->sdata; - - if (status->link_valid && - !ieee80211_rx_is_valid_sta_link_id(pubsta, status->link_id)) - goto drop; + if (status->link_valid) + link_id = status->link_id; /* * TODO: Should the frame be dropped if the right link_id is not @@ -4921,19 +4900,8 @@ static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw, * link_id is used only for stats purpose and updating the stats on * the deflink is fine? */ - if (status->link_valid) - rx.link_id = status->link_id; - - if (rx.link_id >= 0) { - struct ieee80211_link_data *link; - - link = rcu_dereference(rx.sdata->link[rx.link_id]); - if (!link) - goto drop; - rx.link = link; - } else { - rx.link = &rx.sdata->deflink; - } + if (!ieee80211_rx_data_set_sta(&rx, pubsta, link_id)) + goto drop; fast_rx = rcu_dereference(rx.sta->fast_rx); if (!fast_rx) @@ -4951,6 +4919,8 @@ static bool ieee80211_rx_for_interface(struct ieee80211_rx_data *rx, { struct link_sta_info *link_sta; struct ieee80211_hdr *hdr = (void *)skb->data; + struct sta_info *sta; + int link_id = -1; /* * Look up link station first, in case there's a @@ -4960,24 +4930,19 @@ static bool ieee80211_rx_for_interface(struct ieee80211_rx_data *rx, */ link_sta = link_sta_info_get_bss(rx->sdata, hdr->addr2); if (link_sta) { - rx->sta = link_sta->sta; - rx->link_id = link_sta->link_id; + sta = link_sta->sta; + link_id = link_sta->link_id; } else { struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); - rx->sta = sta_info_get_bss(rx->sdata, hdr->addr2); - if (rx->sta) { - if (status->link_valid && - !ieee80211_rx_is_valid_sta_link_id(&rx->sta->sta, - status->link_id)) - return false; - - rx->link_id = status->link_valid ? status->link_id : -1; - } else { - rx->link_id = -1; - } + sta = sta_info_get_bss(rx->sdata, hdr->addr2); + if (status->link_valid) + link_id = status->link_id; } + if (!ieee80211_rx_data_set_sta(rx, &sta->sta, link_id)) + return false; + return ieee80211_prepare_and_rx_handle(rx, skb, consume); } @@ -5036,19 +5001,15 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw, if (ieee80211_is_data(fc)) { struct sta_info *sta, *prev_sta; - u8 link_id = status->link_id; + int link_id = -1; - if (pubsta) { - rx.sta = container_of(pubsta, struct sta_info, sta); - rx.sdata = rx.sta->sdata; + if (status->link_valid) + link_id = status->link_id; - if (status->link_valid && - !ieee80211_rx_is_valid_sta_link_id(pubsta, link_id)) + if (pubsta) { + if (!ieee80211_rx_data_set_sta(&rx, pubsta, link_id)) goto out; - if (status->link_valid) - rx.link_id = status->link_id; - /* * In MLO connection, fetch the link_id using addr2 * when the driver does not pass link_id in status. @@ -5066,7 +5027,7 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw, if (!link_sta) goto out; - rx.link_id = link_sta->link_id; + ieee80211_rx_data_set_link(&rx, link_sta->link_id); } if (ieee80211_prepare_and_rx_handle(&rx, skb, true)) @@ -5082,30 +5043,27 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw, continue; } - if ((status->link_valid && - !ieee80211_rx_is_valid_sta_link_id(&prev_sta->sta, - link_id)) || - (!status->link_valid && prev_sta->sta.mlo)) + rx.sdata = prev_sta->sdata; + if (!ieee80211_rx_data_set_sta(&rx, &prev_sta->sta, + link_id)) + goto out; + + if (!status->link_valid && prev_sta->sta.mlo) continue; - rx.link_id = status->link_valid ? link_id : -1; - rx.sta = prev_sta; - rx.sdata = prev_sta->sdata; ieee80211_prepare_and_rx_handle(&rx, skb, false); prev_sta = sta; } if (prev_sta) { - if ((status->link_valid && - !ieee80211_rx_is_valid_sta_link_id(&prev_sta->sta, - link_id)) || - (!status->link_valid && prev_sta->sta.mlo)) + rx.sdata = prev_sta->sdata; + if (!ieee80211_rx_data_set_sta(&rx, &prev_sta->sta, + link_id)) goto out; - rx.link_id = status->link_valid ? link_id : -1; - rx.sta = prev_sta; - rx.sdata = prev_sta->sdata; + if (!status->link_valid && prev_sta->sta.mlo) + goto out; if (ieee80211_prepare_and_rx_handle(&rx, skb, true)) return; diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c index cebfd148bb40..04e0f132b1d9 100644 --- a/net/mac80211/sta_info.c +++ b/net/mac80211/sta_info.c @@ -140,17 +140,15 @@ static void __cleanup_single_sta(struct sta_info *sta) atomic_dec(&ps->num_sta_ps); } - if (sta->sta.txq[0]) { - for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { - struct txq_info *txqi; + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + struct txq_info *txqi; - if (!sta->sta.txq[i]) - continue; + if (!sta->sta.txq[i]) + continue; - txqi = to_txq_info(sta->sta.txq[i]); + txqi = to_txq_info(sta->sta.txq[i]); - ieee80211_txq_purge(local, txqi); - } + ieee80211_txq_purge(local, txqi); } for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { @@ -366,6 +364,9 @@ static void sta_remove_link(struct sta_info *sta, unsigned int link_id, if (unhash) link_sta_info_hash_del(sta->local, link_sta); + if (test_sta_flag(sta, WLAN_STA_INSERTED)) + ieee80211_link_sta_debugfs_remove(link_sta); + if (link_sta != &sta->deflink) alloc = container_of(link_sta, typeof(*alloc), info); @@ -425,8 +426,7 @@ void sta_info_free(struct ieee80211_local *local, struct sta_info *sta) sta_dbg(sta->sdata, "Destroyed STA %pM\n", sta->sta.addr); - if (sta->sta.txq[0]) - kfree(to_txq_info(sta->sta.txq[0])); + kfree(to_txq_info(sta->sta.txq[0])); kfree(rcu_dereference_raw(sta->sta.rates)); #ifdef CONFIG_MAC80211_MESH kfree(sta->mesh); @@ -511,6 +511,7 @@ static void sta_info_add_link(struct sta_info *sta, link_info->sta = sta; link_info->link_id = link_id; link_info->pub = link_sta; + link_info->pub->sta = &sta->sta; link_sta->link_id = link_id; rcu_assign_pointer(sta->link[link_id], link_info); rcu_assign_pointer(sta->sta.link[link_id], link_sta); @@ -527,6 +528,8 @@ __sta_info_alloc(struct ieee80211_sub_if_data *sdata, struct ieee80211_local *local = sdata->local; struct ieee80211_hw *hw = &local->hw; struct sta_info *sta; + void *txq_data; + int size; int i; sta = kzalloc(sizeof(*sta) + hw->sta_data_size, gfp); @@ -596,21 +599,18 @@ __sta_info_alloc(struct ieee80211_sub_if_data *sdata, sta->last_connected = ktime_get_seconds(); - if (local->ops->wake_tx_queue) { - void *txq_data; - int size = sizeof(struct txq_info) + - ALIGN(hw->txq_data_size, sizeof(void *)); + size = sizeof(struct txq_info) + + ALIGN(hw->txq_data_size, sizeof(void *)); - txq_data = kcalloc(ARRAY_SIZE(sta->sta.txq), size, gfp); - if (!txq_data) - goto free; + txq_data = kcalloc(ARRAY_SIZE(sta->sta.txq), size, gfp); + if (!txq_data) + goto free; - for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { - struct txq_info *txq = txq_data + i * size; + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + struct txq_info *txq = txq_data + i * size; - /* might not do anything for the bufferable MMPDU TXQ */ - ieee80211_txq_init(sdata, sta, txq, i); - } + /* might not do anything for the (bufferable) MMPDU TXQ */ + ieee80211_txq_init(sdata, sta, txq, i); } if (sta_prepare_rate_control(local, sta, gfp)) @@ -684,8 +684,7 @@ __sta_info_alloc(struct ieee80211_sub_if_data *sdata, return sta; free_txq: - if (sta->sta.txq[0]) - kfree(to_txq_info(sta->sta.txq[0])); + kfree(to_txq_info(sta->sta.txq[0])); free: sta_info_free_link(&sta->deflink); #ifdef CONFIG_MAC80211_MESH @@ -874,6 +873,26 @@ static int sta_info_insert_finish(struct sta_info *sta) __acquires(RCU) ieee80211_sta_debugfs_add(sta); rate_control_add_sta_debugfs(sta); + if (sta->sta.valid_links) { + int i; + + for (i = 0; i < ARRAY_SIZE(sta->link); i++) { + struct link_sta_info *link_sta; + + link_sta = rcu_dereference_protected(sta->link[i], + lockdep_is_held(&local->sta_mtx)); + + if (!link_sta) + continue; + + ieee80211_link_sta_debugfs_add(link_sta); + if (sdata->vif.active_links & BIT(i)) + ieee80211_link_sta_debugfs_drv_add(link_sta); + } + } else { + ieee80211_link_sta_debugfs_add(&sta->deflink); + ieee80211_link_sta_debugfs_drv_add(&sta->deflink); + } sinfo->generation = local->sta_generation; cfg80211_new_sta(sdata->dev, sta->sta.addr, sinfo, GFP_KERNEL); @@ -1958,9 +1977,6 @@ ieee80211_sta_ps_deliver_response(struct sta_info *sta, * TIM recalculation. */ - if (!sta->sta.txq[0]) - return; - for (tid = 0; tid < ARRAY_SIZE(sta->sta.txq); tid++) { if (!sta->sta.txq[tid] || !(driver_release_tids & BIT(tid)) || @@ -2127,22 +2143,30 @@ void ieee80211_sta_register_airtime(struct ieee80211_sta *pubsta, u8 tid, } EXPORT_SYMBOL(ieee80211_sta_register_airtime); -void ieee80211_sta_recalc_aggregates(struct ieee80211_sta *pubsta) +void __ieee80211_sta_recalc_aggregates(struct sta_info *sta, u16 active_links) { - struct sta_info *sta = container_of(pubsta, struct sta_info, sta); - struct ieee80211_link_sta *link_sta; - int link_id, i; bool first = true; + int link_id; - if (!pubsta->valid_links || !pubsta->mlo) { - pubsta->cur = &pubsta->deflink.agg; + if (!sta->sta.valid_links || !sta->sta.mlo) { + sta->sta.cur = &sta->sta.deflink.agg; return; } rcu_read_lock(); - for_each_sta_active_link(&sta->sdata->vif, pubsta, link_sta, link_id) { + for (link_id = 0; link_id < ARRAY_SIZE((sta)->link); link_id++) { + struct ieee80211_link_sta *link_sta; + int i; + + if (!(active_links & BIT(link_id))) + continue; + + link_sta = rcu_dereference(sta->sta.link[link_id]); + if (!link_sta) + continue; + if (first) { - sta->cur = pubsta->deflink.agg; + sta->cur = sta->sta.deflink.agg; first = false; continue; } @@ -2161,7 +2185,14 @@ void ieee80211_sta_recalc_aggregates(struct ieee80211_sta *pubsta) } rcu_read_unlock(); - pubsta->cur = &sta->cur; + sta->sta.cur = &sta->cur; +} + +void ieee80211_sta_recalc_aggregates(struct ieee80211_sta *pubsta) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + __ieee80211_sta_recalc_aggregates(sta, sta->sdata->vif.active_links); } EXPORT_SYMBOL(ieee80211_sta_recalc_aggregates); @@ -2396,9 +2427,9 @@ static inline u64 sta_get_tidstats_msdu(struct ieee80211_sta_rx_stats *rxstats, u64 value; do { - start = u64_stats_fetch_begin_irq(&rxstats->syncp); + start = u64_stats_fetch_begin(&rxstats->syncp); value = rxstats->msdu[tid]; - } while (u64_stats_fetch_retry_irq(&rxstats->syncp, start)); + } while (u64_stats_fetch_retry(&rxstats->syncp, start)); return value; } @@ -2445,7 +2476,7 @@ static void sta_set_tidstats(struct sta_info *sta, tidstats->tx_msdu_failed = sta->deflink.status_stats.msdu_failed[tid]; } - if (local->ops->wake_tx_queue && tid < IEEE80211_NUM_TIDS) { + if (tid < IEEE80211_NUM_TIDS) { spin_lock_bh(&local->fq.lock); rcu_read_lock(); @@ -2464,9 +2495,9 @@ static inline u64 sta_get_stats_bytes(struct ieee80211_sta_rx_stats *rxstats) u64 value; do { - start = u64_stats_fetch_begin_irq(&rxstats->syncp); + start = u64_stats_fetch_begin(&rxstats->syncp); value = rxstats->bytes; - } while (u64_stats_fetch_retry_irq(&rxstats->syncp, start)); + } while (u64_stats_fetch_retry(&rxstats->syncp, start)); return value; } @@ -2773,9 +2804,6 @@ unsigned long ieee80211_sta_last_active(struct sta_info *sta) static void sta_update_codel_params(struct sta_info *sta, u32 thr) { - if (!sta->sdata->local->ops->wake_tx_queue) - return; - if (thr && thr < STA_SLOW_THRESHOLD * sta->local->num_sta) { sta->cparams.target = MS2TIME(50); sta->cparams.interval = MS2TIME(300); @@ -2823,6 +2851,8 @@ int ieee80211_sta_allocate_link(struct sta_info *sta, unsigned int link_id) sta_info_add_link(sta, link_id, &alloc->info, &alloc->sta); + ieee80211_link_sta_debugfs_add(&alloc->info); + return 0; } diff --git a/net/mac80211/sta_info.h b/net/mac80211/sta_info.h index 2517ea714dc4..69820b551668 100644 --- a/net/mac80211/sta_info.h +++ b/net/mac80211/sta_info.h @@ -513,6 +513,7 @@ struct ieee80211_fragment_cache { * @status_stats.avg_ack_signal: average ACK signal * @cur_max_bandwidth: maximum bandwidth to use for TX to the station, * taken from HT/VHT capabilities or VHT operating mode notification + * @debugfs_dir: debug filesystem directory dentry * @pub: public (driver visible) link STA data * TODO Move other link params from sta_info as required for MLD operation */ @@ -560,6 +561,10 @@ struct link_sta_info { enum ieee80211_sta_rx_bandwidth cur_max_bandwidth; +#ifdef CONFIG_MAC80211_DEBUGFS + struct dentry *debugfs_dir; +#endif + struct ieee80211_link_sta *pub; }; @@ -922,6 +927,8 @@ void ieee80211_sta_set_max_amsdu_subframes(struct sta_info *sta, const u8 *ext_capab, unsigned int ext_capab_len); +void __ieee80211_sta_recalc_aggregates(struct sta_info *sta, u16 active_links); + enum sta_stats_type { STA_STATS_RATE_TYPE_INVALID = 0, STA_STATS_RATE_TYPE_LEGACY, diff --git a/net/mac80211/tdls.c b/net/mac80211/tdls.c index f4b4d25eef95..b255f3b5bf01 100644 --- a/net/mac80211/tdls.c +++ b/net/mac80211/tdls.c @@ -1016,7 +1016,6 @@ ieee80211_tdls_prep_mgmt_packet(struct wiphy *wiphy, struct net_device *dev, skb->priority = 256 + 5; break; } - skb_set_queue_mapping(skb, ieee80211_select_queue(sdata, skb)); /* * Set the WLAN_TDLS_TEARDOWN flag to indicate a teardown in progress. diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c index 874f2a4d831d..defe97a31724 100644 --- a/net/mac80211/tx.c +++ b/net/mac80211/tx.c @@ -1129,7 +1129,6 @@ static bool ieee80211_tx_prep_agg(struct ieee80211_tx_data *tx, struct sk_buff *purge_skb = NULL; if (test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) { - info->flags |= IEEE80211_TX_CTL_AMPDU; reset_agg_timer = true; } else if (test_bit(HT_AGG_STATE_WANT_START, &tid_tx->state)) { /* @@ -1161,7 +1160,6 @@ static bool ieee80211_tx_prep_agg(struct ieee80211_tx_data *tx, if (!tid_tx) { /* do nothing, let packet pass through */ } else if (test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) { - info->flags |= IEEE80211_TX_CTL_AMPDU; reset_agg_timer = true; } else { queued = true; @@ -1343,7 +1341,7 @@ static struct txq_info *ieee80211_get_txq(struct ieee80211_local *local, return NULL; txq = sta->sta.txq[tid]; - } else if (vif) { + } else { txq = vif->txq; } @@ -1355,7 +1353,11 @@ static struct txq_info *ieee80211_get_txq(struct ieee80211_local *local, static void ieee80211_set_skb_enqueue_time(struct sk_buff *skb) { - IEEE80211_SKB_CB(skb)->control.enqueue_time = codel_get_time(); + struct sk_buff *next; + codel_time_t now = codel_get_time(); + + skb_list_walk_safe(skb, skb, next) + IEEE80211_SKB_CB(skb)->control.enqueue_time = now; } static u32 codel_skb_len_func(const struct sk_buff *skb) @@ -1599,9 +1601,6 @@ int ieee80211_txq_setup_flows(struct ieee80211_local *local) bool supp_vht = false; enum nl80211_band band; - if (!local->ops->wake_tx_queue) - return 0; - ret = fq_init(fq, 4096); if (ret) return ret; @@ -1649,9 +1648,6 @@ void ieee80211_txq_teardown_flows(struct ieee80211_local *local) { struct fq *fq = &local->fq; - if (!local->ops->wake_tx_queue) - return; - kfree(local->cvars); local->cvars = NULL; @@ -1668,8 +1664,7 @@ static bool ieee80211_queue_skb(struct ieee80211_local *local, struct ieee80211_vif *vif; struct txq_info *txqi; - if (!local->ops->wake_tx_queue || - sdata->vif.type == NL80211_IFTYPE_MONITOR) + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) return false; if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) @@ -2973,7 +2968,7 @@ static struct sk_buff *ieee80211_build_hdr(struct ieee80211_sub_if_data *sdata, if (pre_conf_link_id != link_id && link_id != IEEE80211_LINK_UNSPECIFIED) { -#ifdef CPTCFG_MAC80211_VERBOSE_DEBUG +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG net_info_ratelimited("%s: dropped frame to %pM with bad link ID request (%d vs. %d)\n", sdata->name, hdr.addr1, pre_conf_link_id, link_id); @@ -3585,55 +3580,79 @@ ieee80211_xmit_fast_finish(struct ieee80211_sub_if_data *sdata, return TX_CONTINUE; } -static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, - struct sta_info *sta, - struct ieee80211_fast_tx *fast_tx, - struct sk_buff *skb) +static netdev_features_t +ieee80211_sdata_netdev_features(struct ieee80211_sub_if_data *sdata) { - struct ieee80211_local *local = sdata->local; - u16 ethertype = (skb->data[12] << 8) | skb->data[13]; - int extra_head = fast_tx->hdr_len - (ETH_HLEN - 2); - int hw_headroom = sdata->local->hw.extra_tx_headroom; - struct ethhdr eth; - struct ieee80211_tx_info *info; - struct ieee80211_hdr *hdr = (void *)fast_tx->hdr; - struct ieee80211_tx_data tx; - ieee80211_tx_result r; - struct tid_ampdu_tx *tid_tx = NULL; - u8 tid = IEEE80211_NUM_TIDS; + if (sdata->vif.type != NL80211_IFTYPE_AP_VLAN) + return sdata->vif.netdev_features; - /* control port protocol needs a lot of special handling */ - if (cpu_to_be16(ethertype) == sdata->control_port_protocol) - return false; + if (!sdata->bss) + return 0; - /* only RFC 1042 SNAP */ - if (ethertype < ETH_P_802_3_MIN) - return false; + sdata = container_of(sdata->bss, struct ieee80211_sub_if_data, u.ap); + return sdata->vif.netdev_features; +} - /* don't handle TX status request here either */ - if (skb->sk && skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS) - return false; +static struct sk_buff * +ieee80211_tx_skb_fixup(struct sk_buff *skb, netdev_features_t features) +{ + if (skb_is_gso(skb)) { + struct sk_buff *segs; - if (hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) { - tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; - tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); - if (tid_tx) { - if (!test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) - return false; - if (tid_tx->timeout) - tid_tx->last_tx = jiffies; - } + segs = skb_gso_segment(skb, features); + if (!segs) + return skb; + if (IS_ERR(segs)) + goto free; + + consume_skb(skb); + return segs; } - /* after this point (skb is modified) we cannot return false */ + if (skb_needs_linearize(skb, features) && __skb_linearize(skb)) + goto free; + + if (skb->ip_summed == CHECKSUM_PARTIAL) { + int ofs = skb_checksum_start_offset(skb); + + if (skb->encapsulation) + skb_set_inner_transport_header(skb, ofs); + else + skb_set_transport_header(skb, ofs); + + if (skb_csum_hwoffload_help(skb, features)) + goto free; + } + + skb_mark_not_on_list(skb); + return skb; + +free: + kfree_skb(skb); + return NULL; +} + +static void __ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_fast_tx *fast_tx, + struct sk_buff *skb, u8 tid, bool ampdu) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_hdr *hdr = (void *)fast_tx->hdr; + struct ieee80211_tx_info *info; + struct ieee80211_tx_data tx; + ieee80211_tx_result r; + int hw_headroom = sdata->local->hw.extra_tx_headroom; + int extra_head = fast_tx->hdr_len - (ETH_HLEN - 2); + struct ethhdr eth; skb = skb_share_check(skb, GFP_ATOMIC); if (unlikely(!skb)) - return true; + return; if ((hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) && ieee80211_amsdu_aggregate(sdata, sta, fast_tx, skb)) - return true; + return; /* will not be crypto-handled beyond what we do here, so use false * as the may-encrypt argument for the resize to not account for @@ -3642,10 +3661,8 @@ static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, if (unlikely(ieee80211_skb_resize(sdata, skb, max_t(int, extra_head + hw_headroom - skb_headroom(skb), 0), - ENCRYPT_NO))) { - kfree_skb(skb); - return true; - } + ENCRYPT_NO))) + goto free; memcpy(ð, skb->data, ETH_HLEN - 2); hdr = skb_push(skb, extra_head); @@ -3658,8 +3675,7 @@ static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, info->band = fast_tx->band; info->control.vif = &sdata->vif; info->flags = IEEE80211_TX_CTL_FIRST_FRAGMENT | - IEEE80211_TX_CTL_DONTFRAG | - (tid_tx ? IEEE80211_TX_CTL_AMPDU : 0); + IEEE80211_TX_CTL_DONTFRAG; info->control.flags = IEEE80211_TX_CTRL_FAST_XMIT | u32_encode_bits(IEEE80211_LINK_UNSPECIFIED, IEEE80211_TX_CTRL_MLO_LINK); @@ -3683,16 +3699,14 @@ static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, tx.key = fast_tx->key; if (ieee80211_queue_skb(local, sdata, sta, skb)) - return true; + return; tx.skb = skb; r = ieee80211_xmit_fast_finish(sdata, sta, fast_tx->pn_offs, fast_tx->key, &tx); tx.skb = NULL; - if (r == TX_DROP) { - kfree_skb(skb); - return true; - } + if (r == TX_DROP) + goto free; if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) sdata = container_of(sdata->bss, @@ -3700,6 +3714,56 @@ static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, __skb_queue_tail(&tx.skbs, skb); ieee80211_tx_frags(local, &sdata->vif, sta, &tx.skbs, false); + return; + +free: + kfree_skb(skb); +} + +static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_fast_tx *fast_tx, + struct sk_buff *skb) +{ + u16 ethertype = (skb->data[12] << 8) | skb->data[13]; + struct ieee80211_hdr *hdr = (void *)fast_tx->hdr; + struct tid_ampdu_tx *tid_tx = NULL; + struct sk_buff *next; + u8 tid = IEEE80211_NUM_TIDS; + + /* control port protocol needs a lot of special handling */ + if (cpu_to_be16(ethertype) == sdata->control_port_protocol) + return false; + + /* only RFC 1042 SNAP */ + if (ethertype < ETH_P_802_3_MIN) + return false; + + /* don't handle TX status request here either */ + if (skb->sk && skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS) + return false; + + if (hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) { + tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); + if (tid_tx) { + if (!test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) + return false; + if (tid_tx->timeout) + tid_tx->last_tx = jiffies; + } + } + + /* after this point (skb is modified) we cannot return false */ + skb = ieee80211_tx_skb_fixup(skb, ieee80211_sdata_netdev_features(sdata)); + if (!skb) + return true; + + skb_list_walk_safe(skb, skb, next) { + skb_mark_not_on_list(skb); + __ieee80211_xmit_fast(sdata, sta, fast_tx, skb, tid, tid_tx); + } + return true; } @@ -3716,6 +3780,8 @@ struct sk_buff *ieee80211_tx_dequeue(struct ieee80211_hw *hw, struct ieee80211_tx_data tx; ieee80211_tx_result r; struct ieee80211_vif *vif = txq->vif; + int q = vif->hw_queue[txq->ac]; + bool q_stopped; WARN_ON_ONCE(softirq_count() == 0); @@ -3723,17 +3789,18 @@ struct sk_buff *ieee80211_tx_dequeue(struct ieee80211_hw *hw, return NULL; begin: - spin_lock_bh(&fq->lock); - - if (test_bit(IEEE80211_TXQ_STOP, &txqi->flags) || - test_bit(IEEE80211_TXQ_STOP_NETIF_TX, &txqi->flags)) - goto out; + spin_lock(&local->queue_stop_reason_lock); + q_stopped = local->queue_stop_reasons[q]; + spin_unlock(&local->queue_stop_reason_lock); - if (vif->txqs_stopped[txq->ac]) { - set_bit(IEEE80211_TXQ_STOP_NETIF_TX, &txqi->flags); - goto out; + if (unlikely(q_stopped)) { + /* mark for waking later */ + set_bit(IEEE80211_TXQ_DIRTY, &txqi->flags); + return NULL; } + spin_lock_bh(&fq->lock); + /* Make sure fragments stay together. */ skb = __skb_dequeue(&txqi->frags); if (unlikely(skb)) { @@ -3743,6 +3810,9 @@ begin: IEEE80211_SKB_CB(skb)->control.flags &= ~IEEE80211_TX_INTCFL_NEED_TXPROCESSING; } else { + if (unlikely(test_bit(IEEE80211_TXQ_STOP, &txqi->flags))) + goto out; + skb = fq_tin_dequeue(fq, tin, fq_tin_dequeue_func); } @@ -3793,9 +3863,8 @@ begin: } if (test_bit(IEEE80211_TXQ_AMPDU, &txqi->flags)) - info->flags |= IEEE80211_TX_CTL_AMPDU; - else - info->flags &= ~IEEE80211_TX_CTL_AMPDU; + info->flags |= (IEEE80211_TX_CTL_AMPDU | + IEEE80211_TX_CTL_DONTFRAG); if (info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP) { if (!ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) { @@ -4184,12 +4253,7 @@ void __ieee80211_subif_start_xmit(struct sk_buff *skb, if (IS_ERR(sta)) sta = NULL; - if (local->ops->wake_tx_queue) { - u16 queue = __ieee80211_select_queue(sdata, sta, skb); - skb_set_queue_mapping(skb, queue); - skb_get_hash(skb); - } - + skb_set_queue_mapping(skb, ieee80211_select_queue(sdata, sta, skb)); ieee80211_aggr_check(sdata, sta, skb); sk_pacing_shift_update(skb->sk, sdata->local->hw.tx_sk_pacing_shift); @@ -4204,31 +4268,14 @@ void __ieee80211_subif_start_xmit(struct sk_buff *skb, goto out; } - if (skb_is_gso(skb)) { - struct sk_buff *segs; - - segs = skb_gso_segment(skb, 0); - if (IS_ERR(segs)) { - goto out_free; - } else if (segs) { - consume_skb(skb); - skb = segs; - } - } else { - /* we cannot process non-linear frames on this path */ - if (skb_linearize(skb)) - goto out_free; - - /* the frame could be fragmented, software-encrypted, and other - * things so we cannot really handle checksum offload with it - - * fix it up in software before we handle anything else. - */ - if (skb->ip_summed == CHECKSUM_PARTIAL) { - skb_set_transport_header(skb, - skb_checksum_start_offset(skb)); - if (skb_checksum_help(skb)) - goto out_free; - } + /* the frame could be fragmented, software-encrypted, and other + * things so we cannot really handle checksum or GSO offload. + * fix it up in software before we handle anything else. + */ + skb = ieee80211_tx_skb_fixup(skb, 0); + if (!skb) { + len = 0; + goto out; } skb_list_walk_safe(skb, skb, next) { @@ -4446,9 +4493,11 @@ normal: return NETDEV_TX_OK; } -static bool ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, - struct sk_buff *skb, struct sta_info *sta, - bool txpending) + + +static bool __ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, struct sta_info *sta, + bool txpending) { struct ieee80211_local *local = sdata->local; struct ieee80211_tx_control control = {}; @@ -4457,14 +4506,6 @@ static bool ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, unsigned long flags; int q = info->hw_queue; - if (sta) - sk_pacing_shift_update(skb->sk, local->hw.tx_sk_pacing_shift); - - ieee80211_tpt_led_trig_tx(local, skb->len); - - if (ieee80211_queue_skb(local, sdata, sta, skb)) - return true; - spin_lock_irqsave(&local->queue_stop_reason_lock, flags); if (local->queue_stop_reasons[q] || @@ -4491,6 +4532,26 @@ static bool ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, return true; } +static bool ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, struct sta_info *sta, + bool txpending) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *next; + bool ret = true; + + if (ieee80211_queue_skb(local, sdata, sta, skb)) + return true; + + skb_list_walk_safe(skb, skb, next) { + skb_mark_not_on_list(skb); + if (!__ieee80211_tx_8023(sdata, skb, sta, txpending)) + ret = false; + } + + return ret; +} + static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, struct net_device *dev, struct sta_info *sta, struct ieee80211_key *key, struct sk_buff *skb) @@ -4498,13 +4559,13 @@ static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, struct ieee80211_tx_info *info; struct ieee80211_local *local = sdata->local; struct tid_ampdu_tx *tid_tx; + struct sk_buff *seg, *next; + unsigned int skbs = 0, len = 0; + u16 queue; u8 tid; - if (local->ops->wake_tx_queue) { - u16 queue = __ieee80211_select_queue(sdata, sta, skb); - skb_set_queue_mapping(skb, queue); - skb_get_hash(skb); - } + queue = ieee80211_select_queue(sdata, sta, skb); + skb_set_queue_mapping(skb, queue); if (unlikely(test_bit(SCAN_SW_SCANNING, &local->scanning)) && test_bit(SDATA_STATE_OFFCHANNEL, &sdata->state)) @@ -4514,9 +4575,6 @@ static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, if (unlikely(!skb)) return; - info = IEEE80211_SKB_CB(skb); - memset(info, 0, sizeof(*info)); - ieee80211_aggr_check(sdata, sta, skb); tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; @@ -4530,22 +4588,18 @@ static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, return; } - info->flags |= IEEE80211_TX_CTL_AMPDU; if (tid_tx->timeout) tid_tx->last_tx = jiffies; } - if (unlikely(skb->sk && - skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS)) - info->ack_frame_id = ieee80211_store_ack_skb(local, skb, - &info->flags, NULL); - - info->hw_queue = sdata->vif.hw_queue[skb_get_queue_mapping(skb)]; + skb = ieee80211_tx_skb_fixup(skb, ieee80211_sdata_netdev_features(sdata)); + if (!skb) + return; - dev_sw_netstats_tx_add(dev, 1, skb->len); + info = IEEE80211_SKB_CB(skb); + memset(info, 0, sizeof(*info)); - sta->deflink.tx_stats.bytes[skb_get_queue_mapping(skb)] += skb->len; - sta->deflink.tx_stats.packets[skb_get_queue_mapping(skb)]++; + info->hw_queue = sdata->vif.hw_queue[queue]; if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) sdata = container_of(sdata->bss, @@ -4557,6 +4611,24 @@ static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, if (key) info->control.hw_key = &key->conf; + skb_list_walk_safe(skb, seg, next) { + skbs++; + len += seg->len; + if (seg != skb) + memcpy(IEEE80211_SKB_CB(seg), info, sizeof(*info)); + } + + if (unlikely(skb->sk && + skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS)) + info->ack_frame_id = ieee80211_store_ack_skb(local, skb, + &info->flags, NULL); + + dev_sw_netstats_tx_add(dev, skbs, len); + sta->deflink.tx_stats.packets[queue] += skbs; + sta->deflink.tx_stats.bytes[queue] += len; + + ieee80211_tpt_led_trig_tx(local, len); + ieee80211_tx_8023(sdata, skb, sta, false); return; @@ -4598,6 +4670,7 @@ netdev_tx_t ieee80211_subif_start_xmit_8023(struct sk_buff *skb, key->conf.cipher == WLAN_CIPHER_SUITE_TKIP)) goto skip_offload; + sk_pacing_shift_update(skb->sk, sdata->local->hw.tx_sk_pacing_shift); ieee80211_8023_xmit(sdata, dev, sta, key, skb); goto out; @@ -4758,9 +4831,6 @@ void ieee80211_tx_pending(struct tasklet_struct *t) if (!txok) break; } - - if (skb_queue_empty(&local->pending[i])) - ieee80211_propagate_queue_wake(local, i); } spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); @@ -4793,9 +4863,9 @@ static void __ieee80211_beacon_add_tim(struct ieee80211_sub_if_data *sdata, ps->dtim_count--; } - tim = pos = skb_put(skb, 6); + tim = pos = skb_put(skb, 5); *pos++ = WLAN_EID_TIM; - *pos++ = 4; + *pos++ = 3; *pos++ = ps->dtim_count; *pos++ = link_conf->dtim_period; @@ -4826,13 +4896,17 @@ static void __ieee80211_beacon_add_tim(struct ieee80211_sub_if_data *sdata, /* Bitmap control */ *pos++ = n1 | aid0; /* Part Virt Bitmap */ - skb_put(skb, n2 - n1); - memcpy(pos, ps->tim + n1, n2 - n1 + 1); + skb_put_data(skb, ps->tim + n1, n2 - n1 + 1); tim[1] = n2 - n1 + 4; } else { *pos++ = aid0; /* Bitmap control */ - *pos++ = 0; /* Part Virt Bitmap */ + + if (ieee80211_get_link_sband(link)->band != NL80211_BAND_S1GHZ) { + tim[1] = 4; + /* Part Virt Bitmap */ + skb_put_u8(skb, 0); + } } } @@ -5953,10 +6027,9 @@ int ieee80211_tx_control_port(struct wiphy *wiphy, struct net_device *dev, } if (!IS_ERR(sta)) { - u16 queue = __ieee80211_select_queue(sdata, sta, skb); + u16 queue = ieee80211_select_queue(sdata, sta, skb); skb_set_queue_mapping(skb, queue); - skb_get_hash(skb); /* * for MLO STA, the SA should be the AP MLD address, but diff --git a/net/mac80211/util.c b/net/mac80211/util.c index b512cb37aafb..261ac667887f 100644 --- a/net/mac80211/util.c +++ b/net/mac80211/util.c @@ -288,6 +288,42 @@ __le16 ieee80211_ctstoself_duration(struct ieee80211_hw *hw, } EXPORT_SYMBOL(ieee80211_ctstoself_duration); +static void wake_tx_push_queue(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_txq *queue) +{ + struct ieee80211_tx_control control = { + .sta = queue->sta, + }; + struct sk_buff *skb; + + while (1) { + skb = ieee80211_tx_dequeue(&local->hw, queue); + if (!skb) + break; + + drv_tx(local, &control, skb); + } +} + +/* wake_tx_queue handler for driver not implementing a custom one*/ +void ieee80211_handle_wake_tx_queue(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata = vif_to_sdata(txq->vif); + struct ieee80211_txq *queue; + + /* Use ieee80211_next_txq() for airtime fairness accounting */ + ieee80211_txq_schedule_start(hw, txq->ac); + while ((queue = ieee80211_next_txq(hw, txq->ac))) { + wake_tx_push_queue(local, sdata, queue); + ieee80211_return_txq(hw, queue, false); + } + ieee80211_txq_schedule_end(hw, txq->ac); +} +EXPORT_SYMBOL(ieee80211_handle_wake_tx_queue); + static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) { struct ieee80211_local *local = sdata->local; @@ -301,8 +337,6 @@ static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) local_bh_disable(); spin_lock(&fq->lock); - sdata->vif.txqs_stopped[ac] = false; - if (!test_bit(SDATA_STATE_RUNNING, &sdata->state)) goto out; @@ -324,7 +358,7 @@ static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) if (ac != txq->ac) continue; - if (!test_and_clear_bit(IEEE80211_TXQ_STOP_NETIF_TX, + if (!test_and_clear_bit(IEEE80211_TXQ_DIRTY, &txqi->flags)) continue; @@ -339,7 +373,7 @@ static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) txqi = to_txq_info(vif->txq); - if (!test_and_clear_bit(IEEE80211_TXQ_STOP_NETIF_TX, &txqi->flags) || + if (!test_and_clear_bit(IEEE80211_TXQ_DIRTY, &txqi->flags) || (ps && atomic_read(&ps->num_sta_ps)) || ac != vif->txq->ac) goto out; @@ -400,39 +434,6 @@ void ieee80211_wake_txqs(struct tasklet_struct *t) spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); } -void ieee80211_propagate_queue_wake(struct ieee80211_local *local, int queue) -{ - struct ieee80211_sub_if_data *sdata; - int n_acs = IEEE80211_NUM_ACS; - - if (local->ops->wake_tx_queue) - return; - - if (local->hw.queues < IEEE80211_NUM_ACS) - n_acs = 1; - - list_for_each_entry_rcu(sdata, &local->interfaces, list) { - int ac; - - if (!sdata->dev) - continue; - - if (sdata->vif.cab_queue != IEEE80211_INVAL_HW_QUEUE && - local->queue_stop_reasons[sdata->vif.cab_queue] != 0) - continue; - - for (ac = 0; ac < n_acs; ac++) { - int ac_queue = sdata->vif.hw_queue[ac]; - - if (ac_queue == queue || - (sdata->vif.cab_queue == queue && - local->queue_stop_reasons[ac_queue] == 0 && - skb_queue_empty(&local->pending[ac_queue]))) - netif_wake_subqueue(sdata->dev, ac); - } - } -} - static void __ieee80211_wake_queue(struct ieee80211_hw *hw, int queue, enum queue_stop_reason reason, bool refcounted, @@ -463,11 +464,7 @@ static void __ieee80211_wake_queue(struct ieee80211_hw *hw, int queue, /* someone still has this queue stopped */ return; - if (skb_queue_empty(&local->pending[queue])) { - rcu_read_lock(); - ieee80211_propagate_queue_wake(local, queue); - rcu_read_unlock(); - } else + if (!skb_queue_empty(&local->pending[queue])) tasklet_schedule(&local->tx_pending_tasklet); /* @@ -477,12 +474,10 @@ static void __ieee80211_wake_queue(struct ieee80211_hw *hw, int queue, * release someone's lock, but it is fine because all the callers of * __ieee80211_wake_queue call it right before releasing the lock. */ - if (local->ops->wake_tx_queue) { - if (reason == IEEE80211_QUEUE_STOP_REASON_DRIVER) - tasklet_schedule(&local->wake_txqs_tasklet); - else - _ieee80211_wake_txqs(local, flags); - } + if (reason == IEEE80211_QUEUE_STOP_REASON_DRIVER) + tasklet_schedule(&local->wake_txqs_tasklet); + else + _ieee80211_wake_txqs(local, flags); } void ieee80211_wake_queue_by_reason(struct ieee80211_hw *hw, int queue, @@ -510,8 +505,6 @@ static void __ieee80211_stop_queue(struct ieee80211_hw *hw, int queue, bool refcounted) { struct ieee80211_local *local = hw_to_local(hw); - struct ieee80211_sub_if_data *sdata; - int n_acs = IEEE80211_NUM_ACS; trace_stop_queue(local, queue, reason); @@ -523,33 +516,7 @@ static void __ieee80211_stop_queue(struct ieee80211_hw *hw, int queue, else local->q_stop_reasons[queue][reason]++; - if (__test_and_set_bit(reason, &local->queue_stop_reasons[queue])) - return; - - if (local->hw.queues < IEEE80211_NUM_ACS) - n_acs = 1; - - rcu_read_lock(); - list_for_each_entry_rcu(sdata, &local->interfaces, list) { - int ac; - - if (!sdata->dev) - continue; - - for (ac = 0; ac < n_acs; ac++) { - if (sdata->vif.hw_queue[ac] == queue || - sdata->vif.cab_queue == queue) { - if (!local->ops->wake_tx_queue) { - netif_stop_subqueue(sdata->dev, ac); - continue; - } - spin_lock(&local->fq.lock); - sdata->vif.txqs_stopped[ac] = true; - spin_unlock(&local->fq.lock); - } - } - } - rcu_read_unlock(); + set_bit(reason, &local->queue_stop_reasons[queue]); } void ieee80211_stop_queue_by_reason(struct ieee80211_hw *hw, int queue, @@ -1026,8 +993,10 @@ ieee80211_parse_extension_element(u32 *crc, elems->eht_operation = data; break; case WLAN_EID_EXT_EHT_MULTI_LINK: - if (ieee80211_mle_size_ok(data, len)) + if (ieee80211_mle_size_ok(data, len)) { elems->multi_link = (void *)data; + elems->multi_link_len = len; + } break; } } @@ -1499,6 +1468,145 @@ static size_t ieee802_11_find_bssid_profile(const u8 *start, size_t len, return found ? profile_len : 0; } +static void ieee80211_defragment_element(struct ieee802_11_elems *elems, + void **elem_ptr, size_t *len, + size_t total_len, u8 frag_id) +{ + u8 *data = *elem_ptr, *pos, *start; + const struct element *elem; + + /* + * Since 'data' points to the data of the element, not the element + * itself, allow 254 in case it was an extended element where the + * extended ID isn't part of the data we see here and thus not part of + * 'len' either. + */ + if (!data || (*len != 254 && *len != 255)) + return; + + start = elems->scratch_pos; + + if (WARN_ON(*len > (elems->scratch + elems->scratch_len - + elems->scratch_pos))) + return; + + memcpy(elems->scratch_pos, data, *len); + elems->scratch_pos += *len; + + pos = data + *len; + total_len -= *len; + for_each_element(elem, pos, total_len) { + if (elem->id != frag_id) + break; + + if (WARN_ON(elem->datalen > + (elems->scratch + elems->scratch_len - + elems->scratch_pos))) + return; + + memcpy(elems->scratch_pos, elem->data, elem->datalen); + elems->scratch_pos += elem->datalen; + + *len += elem->datalen; + } + + *elem_ptr = start; +} + +static void ieee80211_mle_get_sta_prof(struct ieee802_11_elems *elems, + u8 link_id) +{ + const struct ieee80211_multi_link_elem *ml = elems->multi_link; + size_t ml_len = elems->multi_link_len; + const struct element *sub; + + if (!ml || !ml_len) + return; + + if (le16_get_bits(ml->control, IEEE80211_ML_CONTROL_TYPE) != + IEEE80211_ML_CONTROL_TYPE_BASIC) + return; + + for_each_mle_subelement(sub, (u8 *)ml, ml_len) { + struct ieee80211_mle_per_sta_profile *prof = (void *)sub->data; + u16 control; + + if (sub->id != IEEE80211_MLE_SUBELEM_PER_STA_PROFILE) + continue; + + if (!ieee80211_mle_sta_prof_size_ok(sub->data, sub->datalen)) + return; + + control = le16_to_cpu(prof->control); + + if (link_id != u16_get_bits(control, + IEEE80211_MLE_STA_CONTROL_LINK_ID)) + continue; + + if (!(control & IEEE80211_MLE_STA_CONTROL_COMPLETE_PROFILE)) + return; + + elems->prof = prof; + elems->sta_prof_len = sub->datalen; + + /* the sub element can be fragmented */ + ieee80211_defragment_element(elems, (void **)&elems->prof, + &elems->sta_prof_len, + ml_len - (sub->data - (u8 *)ml), + IEEE80211_MLE_SUBELEM_FRAGMENT); + return; + } +} + +static void ieee80211_mle_parse_link(struct ieee802_11_elems *elems, + struct ieee80211_elems_parse_params *params) +{ + struct ieee80211_mle_per_sta_profile *prof; + struct ieee80211_elems_parse_params sub = { + .action = params->action, + .from_ap = params->from_ap, + .link_id = -1, + }; + const struct element *non_inherit = NULL; + const u8 *end; + + if (params->link_id == -1) + return; + + ieee80211_defragment_element(elems, (void **)&elems->multi_link, + &elems->multi_link_len, + elems->total_len - ((u8 *)elems->multi_link - + elems->ie_start), + WLAN_EID_FRAGMENT); + + ieee80211_mle_get_sta_prof(elems, params->link_id); + prof = elems->prof; + + if (!prof) + return; + + /* check if we have the 4 bytes for the fixed part in assoc response */ + if (elems->sta_prof_len < sizeof(*prof) + prof->sta_info_len - 1 + 4) { + elems->prof = NULL; + elems->sta_prof_len = 0; + return; + } + + /* + * Skip the capability information and the status code that are expected + * as part of the station profile in association response frames. Note + * the -1 is because the 'sta_info_len' is accounted to as part of the + * per-STA profile, but not part of the 'u8 variable[]' portion. + */ + sub.start = prof->variable + prof->sta_info_len - 1 + 4; + end = (const u8 *)prof + elems->sta_prof_len; + sub.len = end - sub.start; + + non_inherit = cfg80211_find_ext_elem(WLAN_EID_EXT_NON_INHERITANCE, + sub.start, sub.len); + _ieee802_11_parse_elems_full(&sub, elems, non_inherit); +} + struct ieee802_11_elems * ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params) { @@ -1506,7 +1614,7 @@ ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params) const struct element *non_inherit = NULL; u8 *nontransmitted_profile; int nontransmitted_profile_len = 0; - size_t scratch_len = params->len; + size_t scratch_len = params->scratch_len ?: 3 * params->len; elems = kzalloc(sizeof(*elems) + scratch_len, GFP_ATOMIC); if (!elems) @@ -1541,6 +1649,8 @@ ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params) _ieee802_11_parse_elems_full(&sub, elems, NULL); } + ieee80211_mle_parse_link(elems, params); + if (elems->tim && !elems->parse_error) { const struct ieee80211_tim_ie *tim_ie = elems->tim; diff --git a/net/mac80211/wme.c b/net/mac80211/wme.c index ecc1de2e68a5..a12c63638680 100644 --- a/net/mac80211/wme.c +++ b/net/mac80211/wme.c @@ -122,6 +122,9 @@ u16 ieee80211_select_queue_80211(struct ieee80211_sub_if_data *sdata, struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); u8 *p; + /* Ensure hash is set prior to potential SW encryption */ + skb_get_hash(skb); + if ((info->control.flags & IEEE80211_TX_CTRL_DONT_REORDER) || local->hw.queues < IEEE80211_NUM_ACS) return 0; @@ -141,12 +144,15 @@ u16 ieee80211_select_queue_80211(struct ieee80211_sub_if_data *sdata, return ieee80211_downgrade_queue(sdata, NULL, skb); } -u16 __ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, - struct sta_info *sta, struct sk_buff *skb) +u16 ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) { struct mac80211_qos_map *qos_map; bool qos; + /* Ensure hash is set prior to potential SW encryption */ + skb_get_hash(skb); + /* all mesh/ocb stations are required to support WME */ if (sta && (sdata->vif.type == NL80211_IFTYPE_MESH_POINT || sdata->vif.type == NL80211_IFTYPE_OCB)) @@ -176,59 +182,6 @@ u16 __ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, return ieee80211_downgrade_queue(sdata, sta, skb); } - -/* Indicate which queue to use. */ -u16 ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, - struct sk_buff *skb) -{ - struct ieee80211_local *local = sdata->local; - struct sta_info *sta = NULL; - const u8 *ra = NULL; - u16 ret; - - /* when using iTXQ, we can do this later */ - if (local->ops->wake_tx_queue) - return 0; - - if (local->hw.queues < IEEE80211_NUM_ACS || skb->len < 6) { - skb->priority = 0; /* required for correct WPA/11i MIC */ - return 0; - } - - rcu_read_lock(); - switch (sdata->vif.type) { - case NL80211_IFTYPE_AP_VLAN: - sta = rcu_dereference(sdata->u.vlan.sta); - if (sta) - break; - fallthrough; - case NL80211_IFTYPE_AP: - ra = skb->data; - break; - case NL80211_IFTYPE_STATION: - /* might be a TDLS station */ - sta = sta_info_get(sdata, skb->data); - if (sta) - break; - - ra = sdata->deflink.u.mgd.bssid; - break; - case NL80211_IFTYPE_ADHOC: - ra = skb->data; - break; - default: - break; - } - - if (!sta && ra && !is_multicast_ether_addr(ra)) - sta = sta_info_get(sdata, ra); - - ret = __ieee80211_select_queue(sdata, sta, skb); - - rcu_read_unlock(); - return ret; -} - /** * ieee80211_set_qos_hdr - Fill in the QoS header if there is one. * diff --git a/net/mac80211/wme.h b/net/mac80211/wme.h index 2e3dec0b6087..81f0039527a9 100644 --- a/net/mac80211/wme.h +++ b/net/mac80211/wme.h @@ -13,10 +13,8 @@ u16 ieee80211_select_queue_80211(struct ieee80211_sub_if_data *sdata, struct sk_buff *skb, struct ieee80211_hdr *hdr); -u16 __ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, - struct sta_info *sta, struct sk_buff *skb); u16 ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, - struct sk_buff *skb); + struct sta_info *sta, struct sk_buff *skb); void ieee80211_set_qos_hdr(struct ieee80211_sub_if_data *sdata, struct sk_buff *skb); diff --git a/net/mac802154/cfg.c b/net/mac802154/cfg.c index 1e4a9f74ed43..dc2d918fac68 100644 --- a/net/mac802154/cfg.c +++ b/net/mac802154/cfg.c @@ -46,7 +46,7 @@ static int ieee802154_suspend(struct wpan_phy *wpan_phy) if (!local->open_count) goto suspend; - ieee802154_stop_queue(&local->hw); + ieee802154_sync_and_hold_queue(local); synchronize_net(); /* stop hardware - this must stop RX */ @@ -67,12 +67,12 @@ static int ieee802154_resume(struct wpan_phy *wpan_phy) goto wake_up; /* restart hardware */ - ret = drv_start(local); + ret = drv_start(local, local->phy->filtering, &local->addr_filt); if (ret) return ret; wake_up: - ieee802154_wake_queue(&local->hw); + ieee802154_release_queue(local); local->suspended = false; return 0; } diff --git a/net/mac802154/driver-ops.h b/net/mac802154/driver-ops.h index d23f0db98015..a7af3f0ddb3e 100644 --- a/net/mac802154/driver-ops.h +++ b/net/mac802154/driver-ops.h @@ -24,203 +24,290 @@ drv_xmit_sync(struct ieee802154_local *local, struct sk_buff *skb) return local->ops->xmit_sync(&local->hw, skb); } -static inline int drv_start(struct ieee802154_local *local) +static inline int drv_set_pan_id(struct ieee802154_local *local, __le16 pan_id) { + struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - trace_802154_drv_start(local); - local->started = true; - smp_mb(); - ret = local->ops->start(&local->hw); + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.pan_id = pan_id; + + trace_802154_drv_set_pan_id(local, pan_id); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_PANID_CHANGED); trace_802154_drv_return_int(local, ret); return ret; } -static inline void drv_stop(struct ieee802154_local *local) +static inline int +drv_set_extended_addr(struct ieee802154_local *local, __le64 extended_addr) { - might_sleep(); + struct ieee802154_hw_addr_filt filt; + int ret; - trace_802154_drv_stop(local); - local->ops->stop(&local->hw); - trace_802154_drv_return_void(local); + might_sleep(); - /* sync away all work on the tasklet before clearing started */ - tasklet_disable(&local->tasklet); - tasklet_enable(&local->tasklet); + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } - barrier(); + filt.ieee_addr = extended_addr; - local->started = false; + trace_802154_drv_set_extended_addr(local, extended_addr); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_IEEEADDR_CHANGED); + trace_802154_drv_return_int(local, ret); + return ret; } static inline int -drv_set_channel(struct ieee802154_local *local, u8 page, u8 channel) +drv_set_short_addr(struct ieee802154_local *local, __le16 short_addr) { + struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - trace_802154_drv_set_channel(local, page, channel); - ret = local->ops->set_channel(&local->hw, page, channel); + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.short_addr = short_addr; + + trace_802154_drv_set_short_addr(local, short_addr); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_SADDR_CHANGED); trace_802154_drv_return_int(local, ret); return ret; } -static inline int drv_set_tx_power(struct ieee802154_local *local, s32 mbm) +static inline int +drv_set_pan_coord(struct ieee802154_local *local, bool is_coord) { + struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - if (!local->ops->set_txpower) { + if (!local->ops->set_hw_addr_filt) { WARN_ON(1); return -EOPNOTSUPP; } - trace_802154_drv_set_tx_power(local, mbm); - ret = local->ops->set_txpower(&local->hw, mbm); + filt.pan_coord = is_coord; + + trace_802154_drv_set_pan_coord(local, is_coord); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_PANC_CHANGED); trace_802154_drv_return_int(local, ret); return ret; } -static inline int drv_set_cca_mode(struct ieee802154_local *local, - const struct wpan_phy_cca *cca) +static inline int +drv_set_promiscuous_mode(struct ieee802154_local *local, bool on) { int ret; might_sleep(); - if (!local->ops->set_cca_mode) { + if (!local->ops->set_promiscuous_mode) { WARN_ON(1); return -EOPNOTSUPP; } - trace_802154_drv_set_cca_mode(local, cca); - ret = local->ops->set_cca_mode(&local->hw, cca); + trace_802154_drv_set_promiscuous_mode(local, on); + ret = local->ops->set_promiscuous_mode(&local->hw, on); trace_802154_drv_return_int(local, ret); return ret; } -static inline int drv_set_lbt_mode(struct ieee802154_local *local, bool mode) +static inline int drv_start(struct ieee802154_local *local, + enum ieee802154_filtering_level level, + const struct ieee802154_hw_addr_filt *addr_filt) { int ret; might_sleep(); - if (!local->ops->set_lbt) { + /* setup receive mode parameters e.g. address mode */ + if (local->hw.flags & IEEE802154_HW_AFILT) { + ret = drv_set_pan_id(local, addr_filt->pan_id); + if (ret < 0) + return ret; + + ret = drv_set_short_addr(local, addr_filt->short_addr); + if (ret < 0) + return ret; + + ret = drv_set_extended_addr(local, addr_filt->ieee_addr); + if (ret < 0) + return ret; + } + + switch (level) { + case IEEE802154_FILTERING_NONE: + fallthrough; + case IEEE802154_FILTERING_1_FCS: + fallthrough; + case IEEE802154_FILTERING_2_PROMISCUOUS: + /* TODO: Requires a different receive mode setup e.g. + * at86rf233 hardware. + */ + fallthrough; + case IEEE802154_FILTERING_3_SCAN: + if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { + ret = drv_set_promiscuous_mode(local, true); + if (ret < 0) + return ret; + } else { + return -EOPNOTSUPP; + } + + /* In practice other filtering levels can be requested, but as + * for now most hardware/drivers only support + * IEEE802154_FILTERING_NONE, we fallback to this actual + * filtering level in hardware and make our own additional + * filtering in mac802154 receive path. + * + * TODO: Move this logic to the device drivers as hardware may + * support more higher level filters. Hardware may also require + * a different order how register are set, which could currently + * be buggy, so all received parameters need to be moved to the + * start() callback and let the driver go into the mode before + * it will turn on receive handling. + */ + local->phy->filtering = IEEE802154_FILTERING_NONE; + break; + case IEEE802154_FILTERING_4_FRAME_FIELDS: + /* Do not error out if IEEE802154_HW_PROMISCUOUS because we + * expect the hardware to operate at the level + * IEEE802154_FILTERING_4_FRAME_FIELDS anyway. + */ + if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { + ret = drv_set_promiscuous_mode(local, false); + if (ret < 0) + return ret; + } + + local->phy->filtering = IEEE802154_FILTERING_4_FRAME_FIELDS; + break; + default: WARN_ON(1); - return -EOPNOTSUPP; + return -EINVAL; } - trace_802154_drv_set_lbt_mode(local, mode); - ret = local->ops->set_lbt(&local->hw, mode); + trace_802154_drv_start(local); + local->started = true; + smp_mb(); + ret = local->ops->start(&local->hw); trace_802154_drv_return_int(local, ret); return ret; } +static inline void drv_stop(struct ieee802154_local *local) +{ + might_sleep(); + + trace_802154_drv_stop(local); + local->ops->stop(&local->hw); + trace_802154_drv_return_void(local); + + /* sync away all work on the tasklet before clearing started */ + tasklet_disable(&local->tasklet); + tasklet_enable(&local->tasklet); + + barrier(); + + local->started = false; +} + static inline int -drv_set_cca_ed_level(struct ieee802154_local *local, s32 mbm) +drv_set_channel(struct ieee802154_local *local, u8 page, u8 channel) { int ret; might_sleep(); - if (!local->ops->set_cca_ed_level) { - WARN_ON(1); - return -EOPNOTSUPP; - } - - trace_802154_drv_set_cca_ed_level(local, mbm); - ret = local->ops->set_cca_ed_level(&local->hw, mbm); + trace_802154_drv_set_channel(local, page, channel); + ret = local->ops->set_channel(&local->hw, page, channel); trace_802154_drv_return_int(local, ret); return ret; } -static inline int drv_set_pan_id(struct ieee802154_local *local, __le16 pan_id) +static inline int drv_set_tx_power(struct ieee802154_local *local, s32 mbm) { - struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - if (!local->ops->set_hw_addr_filt) { + if (!local->ops->set_txpower) { WARN_ON(1); return -EOPNOTSUPP; } - filt.pan_id = pan_id; - - trace_802154_drv_set_pan_id(local, pan_id); - ret = local->ops->set_hw_addr_filt(&local->hw, &filt, - IEEE802154_AFILT_PANID_CHANGED); + trace_802154_drv_set_tx_power(local, mbm); + ret = local->ops->set_txpower(&local->hw, mbm); trace_802154_drv_return_int(local, ret); return ret; } -static inline int -drv_set_extended_addr(struct ieee802154_local *local, __le64 extended_addr) +static inline int drv_set_cca_mode(struct ieee802154_local *local, + const struct wpan_phy_cca *cca) { - struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - if (!local->ops->set_hw_addr_filt) { + if (!local->ops->set_cca_mode) { WARN_ON(1); return -EOPNOTSUPP; } - filt.ieee_addr = extended_addr; - - trace_802154_drv_set_extended_addr(local, extended_addr); - ret = local->ops->set_hw_addr_filt(&local->hw, &filt, - IEEE802154_AFILT_IEEEADDR_CHANGED); + trace_802154_drv_set_cca_mode(local, cca); + ret = local->ops->set_cca_mode(&local->hw, cca); trace_802154_drv_return_int(local, ret); return ret; } -static inline int -drv_set_short_addr(struct ieee802154_local *local, __le16 short_addr) +static inline int drv_set_lbt_mode(struct ieee802154_local *local, bool mode) { - struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - if (!local->ops->set_hw_addr_filt) { + if (!local->ops->set_lbt) { WARN_ON(1); return -EOPNOTSUPP; } - filt.short_addr = short_addr; - - trace_802154_drv_set_short_addr(local, short_addr); - ret = local->ops->set_hw_addr_filt(&local->hw, &filt, - IEEE802154_AFILT_SADDR_CHANGED); + trace_802154_drv_set_lbt_mode(local, mode); + ret = local->ops->set_lbt(&local->hw, mode); trace_802154_drv_return_int(local, ret); return ret; } static inline int -drv_set_pan_coord(struct ieee802154_local *local, bool is_coord) +drv_set_cca_ed_level(struct ieee802154_local *local, s32 mbm) { - struct ieee802154_hw_addr_filt filt; int ret; might_sleep(); - if (!local->ops->set_hw_addr_filt) { + if (!local->ops->set_cca_ed_level) { WARN_ON(1); return -EOPNOTSUPP; } - filt.pan_coord = is_coord; - - trace_802154_drv_set_pan_coord(local, is_coord); - ret = local->ops->set_hw_addr_filt(&local->hw, &filt, - IEEE802154_AFILT_PANC_CHANGED); + trace_802154_drv_set_cca_ed_level(local, mbm); + ret = local->ops->set_cca_ed_level(&local->hw, mbm); trace_802154_drv_return_int(local, ret); return ret; } @@ -264,22 +351,4 @@ drv_set_max_frame_retries(struct ieee802154_local *local, s8 max_frame_retries) return ret; } -static inline int -drv_set_promiscuous_mode(struct ieee802154_local *local, bool on) -{ - int ret; - - might_sleep(); - - if (!local->ops->set_promiscuous_mode) { - WARN_ON(1); - return -EOPNOTSUPP; - } - - trace_802154_drv_set_promiscuous_mode(local, on); - ret = local->ops->set_promiscuous_mode(&local->hw, on); - trace_802154_drv_return_int(local, ret); - return ret; -} - #endif /* __MAC802154_DRIVER_OPS */ diff --git a/net/mac802154/ieee802154_i.h b/net/mac802154/ieee802154_i.h index 1381e6a5e180..509e0172fe82 100644 --- a/net/mac802154/ieee802154_i.h +++ b/net/mac802154/ieee802154_i.h @@ -26,6 +26,8 @@ struct ieee802154_local { struct ieee802154_hw hw; const struct ieee802154_ops *ops; + /* hardware address filter */ + struct ieee802154_hw_addr_filt addr_filt; /* ieee802154 phy */ struct wpan_phy *phy; @@ -55,7 +57,7 @@ struct ieee802154_local { struct sk_buff_head skb_queue; struct sk_buff *tx_skb; - struct work_struct tx_work; + struct work_struct sync_tx_work; /* A negative Linux error code or a null/positive MLME error status */ int tx_result; }; @@ -82,6 +84,16 @@ struct ieee802154_sub_if_data { struct ieee802154_local *local; struct net_device *dev; + /* Each interface starts and works in nominal state at a given filtering + * level given by iface_default_filtering, which is set once for all at + * the interface creation and should not evolve over time. For some MAC + * operations however, the filtering level may change temporarily, as + * reflected in the required_filtering field. The actual filtering at + * the PHY level may be different and is shown in struct wpan_phy. + */ + enum ieee802154_filtering_level iface_default_filtering; + enum ieee802154_filtering_level required_filtering; + unsigned long state; char name[IFNAMSIZ]; @@ -123,13 +135,53 @@ ieee802154_sdata_running(struct ieee802154_sub_if_data *sdata) extern struct ieee802154_mlme_ops mac802154_mlme_wpan; void ieee802154_rx(struct ieee802154_local *local, struct sk_buff *skb); -void ieee802154_xmit_worker(struct work_struct *work); +void ieee802154_xmit_sync_worker(struct work_struct *work); +int ieee802154_sync_and_hold_queue(struct ieee802154_local *local); +int ieee802154_mlme_op_pre(struct ieee802154_local *local); +int ieee802154_mlme_tx(struct ieee802154_local *local, + struct ieee802154_sub_if_data *sdata, + struct sk_buff *skb); +void ieee802154_mlme_op_post(struct ieee802154_local *local); +int ieee802154_mlme_tx_one(struct ieee802154_local *local, + struct ieee802154_sub_if_data *sdata, + struct sk_buff *skb); netdev_tx_t ieee802154_monitor_start_xmit(struct sk_buff *skb, struct net_device *dev); netdev_tx_t ieee802154_subif_start_xmit(struct sk_buff *skb, struct net_device *dev); enum hrtimer_restart ieee802154_xmit_ifs_timer(struct hrtimer *timer); +/** + * ieee802154_hold_queue - hold ieee802154 queue + * @local: main mac object + * + * Hold a queue by incrementing an atomic counter and requesting the netif + * queues to be stopped. The queues cannot be woken up while the counter has not + * been reset with as any ieee802154_release_queue() calls as needed. + */ +void ieee802154_hold_queue(struct ieee802154_local *local); + +/** + * ieee802154_release_queue - release ieee802154 queue + * @local: main mac object + * + * Release a queue which is held by decrementing an atomic counter and wake it + * up only if the counter reaches 0. + */ +void ieee802154_release_queue(struct ieee802154_local *local); + +/** + * ieee802154_disable_queue - disable ieee802154 queue + * @local: main mac object + * + * When trying to sync the Tx queue, we cannot just stop the queue + * (which is basically a bit being set without proper lock handling) + * because it would be racy. We actually need to call netif_tx_disable() + * instead, which is done by this helper. Restarting the queue can + * however still be done with a regular wake call. + */ +void ieee802154_disable_queue(struct ieee802154_local *local); + /* MIB callbacks */ void mac802154_dev_set_page_channel(struct net_device *dev, u8 page, u8 chan); diff --git a/net/mac802154/iface.c b/net/mac802154/iface.c index 500ed1b81250..ac0b28025fb0 100644 --- a/net/mac802154/iface.c +++ b/net/mac802154/iface.c @@ -147,25 +147,12 @@ static int ieee802154_setup_hw(struct ieee802154_sub_if_data *sdata) struct wpan_dev *wpan_dev = &sdata->wpan_dev; int ret; - if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { - ret = drv_set_promiscuous_mode(local, - wpan_dev->promiscuous_mode); - if (ret < 0) - return ret; - } + sdata->required_filtering = sdata->iface_default_filtering; if (local->hw.flags & IEEE802154_HW_AFILT) { - ret = drv_set_pan_id(local, wpan_dev->pan_id); - if (ret < 0) - return ret; - - ret = drv_set_extended_addr(local, wpan_dev->extended_addr); - if (ret < 0) - return ret; - - ret = drv_set_short_addr(local, wpan_dev->short_addr); - if (ret < 0) - return ret; + local->addr_filt.pan_id = wpan_dev->pan_id; + local->addr_filt.ieee_addr = wpan_dev->extended_addr; + local->addr_filt.short_addr = wpan_dev->short_addr; } if (local->hw.flags & IEEE802154_HW_LBT) { @@ -206,7 +193,8 @@ static int mac802154_slave_open(struct net_device *dev) if (res) goto err; - res = drv_start(local); + res = drv_start(local, sdata->required_filtering, + &local->addr_filt); if (res) goto err; } @@ -223,15 +211,16 @@ err: static int ieee802154_check_mac_settings(struct ieee802154_local *local, - struct wpan_dev *wpan_dev, - struct wpan_dev *nwpan_dev) + struct ieee802154_sub_if_data *sdata, + struct ieee802154_sub_if_data *nsdata) { + struct wpan_dev *nwpan_dev = &nsdata->wpan_dev; + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + ASSERT_RTNL(); - if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { - if (wpan_dev->promiscuous_mode != nwpan_dev->promiscuous_mode) - return -EBUSY; - } + if (sdata->iface_default_filtering != nsdata->iface_default_filtering) + return -EBUSY; if (local->hw.flags & IEEE802154_HW_AFILT) { if (wpan_dev->pan_id != nwpan_dev->pan_id || @@ -265,7 +254,6 @@ ieee802154_check_concurrent_iface(struct ieee802154_sub_if_data *sdata, enum nl802154_iftype iftype) { struct ieee802154_local *local = sdata->local; - struct wpan_dev *wpan_dev = &sdata->wpan_dev; struct ieee802154_sub_if_data *nsdata; /* we hold the RTNL here so can safely walk the list */ @@ -273,20 +261,19 @@ ieee802154_check_concurrent_iface(struct ieee802154_sub_if_data *sdata, if (nsdata != sdata && ieee802154_sdata_running(nsdata)) { int ret; - /* TODO currently we don't support multiple node types - * we need to run skb_clone at rx path. Check if there - * exist really an use case if we need to support - * multiple node types at the same time. + /* TODO currently we don't support multiple node/coord + * types we need to run skb_clone at rx path. Check if + * there exist really an use case if we need to support + * multiple node/coord types at the same time. */ - if (wpan_dev->iftype == NL802154_IFTYPE_NODE && - nsdata->wpan_dev.iftype == NL802154_IFTYPE_NODE) + if (sdata->wpan_dev.iftype != NL802154_IFTYPE_MONITOR && + nsdata->wpan_dev.iftype != NL802154_IFTYPE_MONITOR) return -EBUSY; /* check all phy mac sublayer settings are the same. * We have only one phy, different values makes trouble. */ - ret = ieee802154_check_mac_settings(local, wpan_dev, - &nsdata->wpan_dev); + ret = ieee802154_check_mac_settings(local, sdata, nsdata); if (ret < 0) return ret; } @@ -577,6 +564,7 @@ ieee802154_setup_sdata(struct ieee802154_sub_if_data *sdata, wpan_dev->short_addr = cpu_to_le16(IEEE802154_ADDR_BROADCAST); switch (type) { + case NL802154_IFTYPE_COORD: case NL802154_IFTYPE_NODE: ieee802154_be64_to_le64(&wpan_dev->extended_addr, sdata->dev->dev_addr); @@ -586,7 +574,7 @@ ieee802154_setup_sdata(struct ieee802154_sub_if_data *sdata, sdata->dev->priv_destructor = mac802154_wpan_free; sdata->dev->netdev_ops = &mac802154_wpan_ops; sdata->dev->ml_priv = &mac802154_mlme_wpan; - wpan_dev->promiscuous_mode = false; + sdata->iface_default_filtering = IEEE802154_FILTERING_4_FRAME_FIELDS; wpan_dev->header_ops = &ieee802154_header_ops; mutex_init(&sdata->sec_mtx); @@ -600,7 +588,7 @@ ieee802154_setup_sdata(struct ieee802154_sub_if_data *sdata, case NL802154_IFTYPE_MONITOR: sdata->dev->needs_free_netdev = true; sdata->dev->netdev_ops = &mac802154_monitor_ops; - wpan_dev->promiscuous_mode = true; + sdata->iface_default_filtering = IEEE802154_FILTERING_NONE; break; default: BUG(); @@ -636,6 +624,7 @@ ieee802154_if_add(struct ieee802154_local *local, const char *name, ieee802154_le64_to_be64(ndev->perm_addr, &local->hw.phy->perm_extended_addr); switch (type) { + case NL802154_IFTYPE_COORD: case NL802154_IFTYPE_NODE: ndev->type = ARPHRD_IEEE802154; if (ieee802154_is_valid_extended_unicast_addr(extended_addr)) { @@ -662,6 +651,7 @@ ieee802154_if_add(struct ieee802154_local *local, const char *name, sdata->dev = ndev; sdata->wpan_dev.wpan_phy = local->hw.phy; sdata->local = local; + INIT_LIST_HEAD(&sdata->wpan_dev.list); /* setup type-dependent data */ ret = ieee802154_setup_sdata(sdata, type); diff --git a/net/mac802154/main.c b/net/mac802154/main.c index bd7bdb1219dd..3ed31daf7b9c 100644 --- a/net/mac802154/main.c +++ b/net/mac802154/main.c @@ -95,7 +95,7 @@ ieee802154_alloc_hw(size_t priv_data_len, const struct ieee802154_ops *ops) skb_queue_head_init(&local->skb_queue); - INIT_WORK(&local->tx_work, ieee802154_xmit_worker); + INIT_WORK(&local->sync_tx_work, ieee802154_xmit_sync_worker); /* init supported flags with 802.15.4 default ranges */ phy->supported.max_minbe = 8; @@ -107,7 +107,7 @@ ieee802154_alloc_hw(size_t priv_data_len, const struct ieee802154_ops *ops) phy->supported.lbt = NL802154_SUPPORTED_BOOL_FALSE; /* always supported */ - phy->supported.iftypes = BIT(NL802154_IFTYPE_NODE); + phy->supported.iftypes = BIT(NL802154_IFTYPE_NODE) | BIT(NL802154_IFTYPE_COORD); return &local->hw; } diff --git a/net/mac802154/rx.c b/net/mac802154/rx.c index 726b47a4611b..97bb4401dd3e 100644 --- a/net/mac802154/rx.c +++ b/net/mac802154/rx.c @@ -34,6 +34,7 @@ ieee802154_subif_frame(struct ieee802154_sub_if_data *sdata, struct sk_buff *skb, const struct ieee802154_hdr *hdr) { struct wpan_dev *wpan_dev = &sdata->wpan_dev; + struct wpan_phy *wpan_phy = sdata->local->hw.phy; __le16 span, sshort; int rc; @@ -42,6 +43,17 @@ ieee802154_subif_frame(struct ieee802154_sub_if_data *sdata, span = wpan_dev->pan_id; sshort = wpan_dev->short_addr; + /* Level 3 filtering: Only beacons are accepted during scans */ + if (sdata->required_filtering == IEEE802154_FILTERING_3_SCAN && + sdata->required_filtering > wpan_phy->filtering) { + if (mac_cb(skb)->type != IEEE802154_FC_TYPE_BEACON) { + dev_dbg(&sdata->dev->dev, + "drop non-beacon frame (0x%x) during scan\n", + mac_cb(skb)->type); + goto fail; + } + } + switch (mac_cb(skb)->dest.mode) { case IEEE802154_ADDR_NONE: if (hdr->source.mode != IEEE802154_ADDR_NONE) @@ -114,8 +126,10 @@ fail: static void ieee802154_print_addr(const char *name, const struct ieee802154_addr *addr) { - if (addr->mode == IEEE802154_ADDR_NONE) + if (addr->mode == IEEE802154_ADDR_NONE) { pr_debug("%s not present\n", name); + return; + } pr_debug("%s PAN ID: %04x\n", name, le16_to_cpu(addr->pan_id)); if (addr->mode == IEEE802154_ADDR_SHORT) { @@ -194,27 +208,34 @@ __ieee802154_rx_handle_packet(struct ieee802154_local *local, int ret; struct ieee802154_sub_if_data *sdata; struct ieee802154_hdr hdr; + struct sk_buff *skb2; ret = ieee802154_parse_frame_start(skb, &hdr); if (ret) { pr_debug("got invalid frame\n"); - kfree_skb(skb); return; } list_for_each_entry_rcu(sdata, &local->interfaces, list) { - if (sdata->wpan_dev.iftype != NL802154_IFTYPE_NODE) + if (sdata->wpan_dev.iftype == NL802154_IFTYPE_MONITOR) continue; if (!ieee802154_sdata_running(sdata)) continue; - ieee802154_subif_frame(sdata, skb, &hdr); - skb = NULL; - break; - } + /* Do not deliver packets received on interfaces expecting + * AACK=1 if the address filters where disabled. + */ + if (local->hw.phy->filtering < IEEE802154_FILTERING_4_FRAME_FIELDS && + sdata->required_filtering == IEEE802154_FILTERING_4_FRAME_FIELDS) + continue; - kfree_skb(skb); + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) { + skb2->dev = sdata->dev; + ieee802154_subif_frame(sdata, skb2, &hdr); + } + } } static void @@ -253,7 +274,7 @@ void ieee802154_rx(struct ieee802154_local *local, struct sk_buff *skb) WARN_ON_ONCE(softirq_count() == 0); if (local->suspended) - goto drop; + goto free_skb; /* TODO: When a transceiver omits the checksum here, we * add an own calculated one. This is currently an ugly @@ -268,25 +289,20 @@ void ieee802154_rx(struct ieee802154_local *local, struct sk_buff *skb) ieee802154_monitors_rx(local, skb); - /* Check if transceiver doesn't validate the checksum. - * If not we validate the checksum here. - */ - if (local->hw.flags & IEEE802154_HW_RX_DROP_BAD_CKSUM) { + /* Level 1 filtering: Check the FCS by software when relevant */ + if (local->hw.phy->filtering == IEEE802154_FILTERING_NONE) { crc = crc_ccitt(0, skb->data, skb->len); - if (crc) { - rcu_read_unlock(); + if (crc) goto drop; - } } /* remove crc */ skb_trim(skb, skb->len - 2); __ieee802154_rx_handle_packet(local, skb); - rcu_read_unlock(); - - return; drop: + rcu_read_unlock(); +free_skb: kfree_skb(skb); } diff --git a/net/mac802154/trace.h b/net/mac802154/trace.h index df855c33daf2..689396d6c76a 100644 --- a/net/mac802154/trace.h +++ b/net/mac802154/trace.h @@ -264,6 +264,31 @@ TRACE_EVENT(802154_drv_set_promiscuous_mode, BOOL_TO_STR(__entry->on)) ); +TRACE_EVENT(802154_new_scan_event, + TP_PROTO(struct ieee802154_coord_desc *desc), + TP_ARGS(desc), + TP_STRUCT__entry( + __field(__le16, pan_id) + __field(__le64, addr) + __field(u8, channel) + __field(u8, page) + ), + TP_fast_assign( + __entry->page = desc->page; + __entry->channel = desc->channel; + __entry->pan_id = desc->addr.pan_id; + __entry->addr = desc->addr.extended_addr; + ), + TP_printk("panid: %u, coord_addr: 0x%llx, page: %u, channel: %u", + __le16_to_cpu(__entry->pan_id), __le64_to_cpu(__entry->addr), + __entry->page, __entry->channel) +); + +DEFINE_EVENT(802154_new_scan_event, 802154_scan_event, + TP_PROTO(struct ieee802154_coord_desc *desc), + TP_ARGS(desc) +); + #endif /* !__MAC802154_DRIVER_TRACE || TRACE_HEADER_MULTI_READ */ #undef TRACE_INCLUDE_PATH diff --git a/net/mac802154/tx.c b/net/mac802154/tx.c index c829e4a75325..9d8d43cf1e64 100644 --- a/net/mac802154/tx.c +++ b/net/mac802154/tx.c @@ -22,10 +22,10 @@ #include "ieee802154_i.h" #include "driver-ops.h" -void ieee802154_xmit_worker(struct work_struct *work) +void ieee802154_xmit_sync_worker(struct work_struct *work) { struct ieee802154_local *local = - container_of(work, struct ieee802154_local, tx_work); + container_of(work, struct ieee802154_local, sync_tx_work); struct sk_buff *skb = local->tx_skb; struct net_device *dev = skb->dev; int res; @@ -43,7 +43,9 @@ void ieee802154_xmit_worker(struct work_struct *work) err_tx: /* Restart the netif queue on each sub_if_data object. */ - ieee802154_wake_queue(&local->hw); + ieee802154_release_queue(local); + if (atomic_dec_and_test(&local->phy->ongoing_txs)) + wake_up(&local->phy->sync_txq); kfree_skb(skb); netdev_dbg(dev, "transmission failed\n"); } @@ -65,7 +67,7 @@ ieee802154_tx(struct ieee802154_local *local, struct sk_buff *skb) consume_skb(skb); skb = nskb; } else { - goto err_tx; + goto err_free_skb; } } @@ -74,32 +76,134 @@ ieee802154_tx(struct ieee802154_local *local, struct sk_buff *skb) } /* Stop the netif queue on each sub_if_data object. */ - ieee802154_stop_queue(&local->hw); + ieee802154_hold_queue(local); + atomic_inc(&local->phy->ongoing_txs); - /* async is priority, otherwise sync is fallback */ + /* Drivers should preferably implement the async callback. In some rare + * cases they only provide a sync callback which we will use as a + * fallback. + */ if (local->ops->xmit_async) { unsigned int len = skb->len; ret = drv_xmit_async(local, skb); - if (ret) { - ieee802154_wake_queue(&local->hw); - goto err_tx; - } + if (ret) + goto err_wake_netif_queue; dev->stats.tx_packets++; dev->stats.tx_bytes += len; } else { local->tx_skb = skb; - queue_work(local->workqueue, &local->tx_work); + queue_work(local->workqueue, &local->sync_tx_work); } return NETDEV_TX_OK; -err_tx: +err_wake_netif_queue: + ieee802154_release_queue(local); + if (atomic_dec_and_test(&local->phy->ongoing_txs)) + wake_up(&local->phy->sync_txq); +err_free_skb: kfree_skb(skb); return NETDEV_TX_OK; } +static int ieee802154_sync_queue(struct ieee802154_local *local) +{ + int ret; + + ieee802154_hold_queue(local); + ieee802154_disable_queue(local); + wait_event(local->phy->sync_txq, !atomic_read(&local->phy->ongoing_txs)); + ret = local->tx_result; + ieee802154_release_queue(local); + + return ret; +} + +int ieee802154_sync_and_hold_queue(struct ieee802154_local *local) +{ + int ret; + + ieee802154_hold_queue(local); + ret = ieee802154_sync_queue(local); + set_bit(WPAN_PHY_FLAG_STATE_QUEUE_STOPPED, &local->phy->flags); + + return ret; +} + +int ieee802154_mlme_op_pre(struct ieee802154_local *local) +{ + return ieee802154_sync_and_hold_queue(local); +} + +int ieee802154_mlme_tx(struct ieee802154_local *local, + struct ieee802154_sub_if_data *sdata, + struct sk_buff *skb) +{ + int ret; + + /* Avoid possible calls to ->ndo_stop() when we asynchronously perform + * MLME transmissions. + */ + rtnl_lock(); + + /* Ensure the device was not stopped, otherwise error out */ + if (!local->open_count) { + rtnl_unlock(); + return -ENETDOWN; + } + + /* Warn if the ieee802154 core thinks MLME frames can be sent while the + * net interface expects this cannot happen. + */ + if (WARN_ON_ONCE(!netif_running(sdata->dev))) { + rtnl_unlock(); + return -ENETDOWN; + } + + ieee802154_tx(local, skb); + ret = ieee802154_sync_queue(local); + + rtnl_unlock(); + + return ret; +} + +void ieee802154_mlme_op_post(struct ieee802154_local *local) +{ + ieee802154_release_queue(local); +} + +int ieee802154_mlme_tx_one(struct ieee802154_local *local, + struct ieee802154_sub_if_data *sdata, + struct sk_buff *skb) +{ + int ret; + + ieee802154_mlme_op_pre(local); + ret = ieee802154_mlme_tx(local, sdata, skb); + ieee802154_mlme_op_post(local); + + return ret; +} + +static bool ieee802154_queue_is_stopped(struct ieee802154_local *local) +{ + return test_bit(WPAN_PHY_FLAG_STATE_QUEUE_STOPPED, &local->phy->flags); +} + +static netdev_tx_t +ieee802154_hot_tx(struct ieee802154_local *local, struct sk_buff *skb) +{ + /* Warn if the net interface tries to transmit frames while the + * ieee802154 core assumes the queue is stopped. + */ + WARN_ON_ONCE(ieee802154_queue_is_stopped(local)); + + return ieee802154_tx(local, skb); +} + netdev_tx_t ieee802154_monitor_start_xmit(struct sk_buff *skb, struct net_device *dev) { @@ -107,7 +211,7 @@ ieee802154_monitor_start_xmit(struct sk_buff *skb, struct net_device *dev) skb->skb_iif = dev->ifindex; - return ieee802154_tx(sdata->local, skb); + return ieee802154_hot_tx(sdata->local, skb); } netdev_tx_t @@ -129,5 +233,5 @@ ieee802154_subif_start_xmit(struct sk_buff *skb, struct net_device *dev) skb->skb_iif = dev->ifindex; - return ieee802154_tx(sdata->local, skb); + return ieee802154_hot_tx(sdata->local, skb); } diff --git a/net/mac802154/util.c b/net/mac802154/util.c index 9f024d85563b..ebc9a8521765 100644 --- a/net/mac802154/util.c +++ b/net/mac802154/util.c @@ -13,12 +13,23 @@ /* privid for wpan_phys to determine whether they belong to us or not */ const void *const mac802154_wpan_phy_privid = &mac802154_wpan_phy_privid; -void ieee802154_wake_queue(struct ieee802154_hw *hw) +/** + * ieee802154_wake_queue - wake ieee802154 queue + * @hw: main hardware object + * + * Tranceivers usually have either one transmit framebuffer or one framebuffer + * for both transmitting and receiving. Hence, the core currently only handles + * one frame at a time for each phy, which means we had to stop the queue to + * avoid new skb to come during the transmission. The queue then needs to be + * woken up after the operation. + */ +static void ieee802154_wake_queue(struct ieee802154_hw *hw) { struct ieee802154_local *local = hw_to_local(hw); struct ieee802154_sub_if_data *sdata; rcu_read_lock(); + clear_bit(WPAN_PHY_FLAG_STATE_QUEUE_STOPPED, &local->phy->flags); list_for_each_entry_rcu(sdata, &local->interfaces, list) { if (!sdata->dev) continue; @@ -27,9 +38,18 @@ void ieee802154_wake_queue(struct ieee802154_hw *hw) } rcu_read_unlock(); } -EXPORT_SYMBOL(ieee802154_wake_queue); -void ieee802154_stop_queue(struct ieee802154_hw *hw) +/** + * ieee802154_stop_queue - stop ieee802154 queue + * @hw: main hardware object + * + * Tranceivers usually have either one transmit framebuffer or one framebuffer + * for both transmitting and receiving. Hence, the core currently only handles + * one frame at a time for each phy, which means we need to tell upper layers to + * stop giving us new skbs while we are busy with the transmitted one. The queue + * must then be stopped before transmitting. + */ +static void ieee802154_stop_queue(struct ieee802154_hw *hw) { struct ieee802154_local *local = hw_to_local(hw); struct ieee802154_sub_if_data *sdata; @@ -43,14 +63,47 @@ void ieee802154_stop_queue(struct ieee802154_hw *hw) } rcu_read_unlock(); } -EXPORT_SYMBOL(ieee802154_stop_queue); + +void ieee802154_hold_queue(struct ieee802154_local *local) +{ + unsigned long flags; + + spin_lock_irqsave(&local->phy->queue_lock, flags); + if (!atomic_fetch_inc(&local->phy->hold_txs)) + ieee802154_stop_queue(&local->hw); + spin_unlock_irqrestore(&local->phy->queue_lock, flags); +} + +void ieee802154_release_queue(struct ieee802154_local *local) +{ + unsigned long flags; + + spin_lock_irqsave(&local->phy->queue_lock, flags); + if (atomic_dec_and_test(&local->phy->hold_txs)) + ieee802154_wake_queue(&local->hw); + spin_unlock_irqrestore(&local->phy->queue_lock, flags); +} + +void ieee802154_disable_queue(struct ieee802154_local *local) +{ + struct ieee802154_sub_if_data *sdata; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!sdata->dev) + continue; + + netif_tx_disable(sdata->dev); + } + rcu_read_unlock(); +} enum hrtimer_restart ieee802154_xmit_ifs_timer(struct hrtimer *timer) { struct ieee802154_local *local = container_of(timer, struct ieee802154_local, ifs_timer); - ieee802154_wake_queue(&local->hw); + ieee802154_release_queue(local); return HRTIMER_NORESTART; } @@ -84,10 +137,12 @@ void ieee802154_xmit_complete(struct ieee802154_hw *hw, struct sk_buff *skb, hw->phy->sifs_period * NSEC_PER_USEC, HRTIMER_MODE_REL); } else { - ieee802154_wake_queue(hw); + ieee802154_release_queue(local); } dev_consume_skb_any(skb); + if (atomic_dec_and_test(&hw->phy->ongoing_txs)) + wake_up(&hw->phy->sync_txq); } EXPORT_SYMBOL(ieee802154_xmit_complete); @@ -97,8 +152,10 @@ void ieee802154_xmit_error(struct ieee802154_hw *hw, struct sk_buff *skb, struct ieee802154_local *local = hw_to_local(hw); local->tx_result = reason; - ieee802154_wake_queue(hw); + ieee802154_release_queue(local); dev_kfree_skb_any(skb); + if (atomic_dec_and_test(&hw->phy->ongoing_txs)) + wake_up(&hw->phy->sync_txq); } EXPORT_SYMBOL(ieee802154_xmit_error); diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index fc9e728b6333..3150f3f0c872 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -544,9 +544,6 @@ static int mctp_sk_init(struct sock *sk) static void mctp_sk_close(struct sock *sk, long timeout) { - struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); - - del_timer_sync(&msk->key_expiry); sk_common_release(sk); } @@ -580,7 +577,19 @@ static void mctp_sk_unhash(struct sock *sk) spin_lock_irqsave(&key->lock, fl2); __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED); } + sock_set_flag(sk, SOCK_DEAD); spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + /* Since there are no more tag allocations (we have removed all of the + * keys), stop any pending expiry events. the timer cannot be re-queued + * as the sk is no longer observable + */ + del_timer_sync(&msk->key_expiry); +} + +static void mctp_sk_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); } static struct proto mctp_proto = { @@ -619,6 +628,7 @@ static int mctp_pf_create(struct net *net, struct socket *sock, return -ENOMEM; sock_init_data(sock, sk); + sk->sk_destruct = mctp_sk_destruct; rc = 0; if (sk->sk_prot->init) diff --git a/net/mctp/device.c b/net/mctp/device.c index 99a3bda8852f..acb97b257428 100644 --- a/net/mctp/device.c +++ b/net/mctp/device.c @@ -429,12 +429,6 @@ static void mctp_unregister(struct net_device *dev) struct mctp_dev *mdev; mdev = mctp_dev_get_rtnl(dev); - if (mdev && !mctp_known(dev)) { - // Sanity check, should match what was set in mctp_register - netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d", - __func__, dev->type); - return; - } if (!mdev) return; @@ -451,14 +445,8 @@ static int mctp_register(struct net_device *dev) struct mctp_dev *mdev; /* Already registered? */ - mdev = rtnl_dereference(dev->mctp_ptr); - - if (mdev) { - if (!mctp_known(dev)) - netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d", - __func__, dev->type); + if (rtnl_dereference(dev->mctp_ptr)) return 0; - } /* only register specific types */ if (!mctp_known(dev)) diff --git a/net/mctp/route.c b/net/mctp/route.c index f9a80b82dc51..f51a05ec7162 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -147,6 +147,7 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, key->valid = true; spin_lock_init(&key->lock); refcount_set(&key->refs, 1); + sock_hold(key->sk); return key; } @@ -165,6 +166,7 @@ void mctp_key_unref(struct mctp_sk_key *key) mctp_dev_release_key(key->dev, key); spin_unlock_irqrestore(&key->lock, flags); + sock_put(key->sk); kfree(key); } @@ -177,6 +179,11 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) spin_lock_irqsave(&net->mctp.keys_lock, flags); + if (sock_flag(&msk->sk, SOCK_DEAD)) { + rc = -EINVAL; + goto out_unlock; + } + hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { if (mctp_key_match(tmp, key->local_addr, key->peer_addr, key->tag)) { @@ -198,6 +205,7 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) hlist_add_head(&key->sklist, &msk->keys); } +out_unlock: spin_unlock_irqrestore(&net->mctp.keys_lock, flags); return rc; @@ -315,8 +323,8 @@ static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) { + struct mctp_sk_key *key, *any_key = NULL; struct net *net = dev_net(skb->dev); - struct mctp_sk_key *key; struct mctp_sock *msk; struct mctp_hdr *mh; unsigned long f; @@ -361,13 +369,11 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * key for reassembly - we'll create a more specific * one for future packets if required (ie, !EOM). */ - key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); - if (key) { - msk = container_of(key->sk, + any_key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); + if (any_key) { + msk = container_of(any_key->sk, struct mctp_sock, sk); - spin_unlock_irqrestore(&key->lock, f); - mctp_key_unref(key); - key = NULL; + spin_unlock_irqrestore(&any_key->lock, f); } } @@ -419,14 +425,14 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * this function. */ rc = mctp_key_add(key, msk); - if (rc) { - kfree(key); - } else { + if (!rc) trace_mctp_key_acquire(key); - /* we don't need to release key->lock on exit */ - mctp_key_unref(key); - } + /* we don't need to release key->lock on exit, so + * clean up here and suppress the unlock via + * setting to NULL + */ + mctp_key_unref(key); key = NULL; } else { @@ -473,6 +479,8 @@ out_unlock: spin_unlock_irqrestore(&key->lock, f); mctp_key_unref(key); } + if (any_key) + mctp_key_unref(any_key); out: if (rc) kfree_skb(skb); diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index b52afe316dc4..35b5f806fdda 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -1079,9 +1079,9 @@ static void mpls_get_stats(struct mpls_dev *mdev, p = per_cpu_ptr(mdev->stats, i); do { - start = u64_stats_fetch_begin_irq(&p->syncp); + start = u64_stats_fetch_begin(&p->syncp); local = p->stats; - } while (u64_stats_fetch_retry_irq(&p->syncp, start)); + } while (u64_stats_fetch_retry(&p->syncp, start)); stats->rx_packets += local.rx_packets; stats->rx_bytes += local.rx_bytes; diff --git a/net/mptcp/Makefile b/net/mptcp/Makefile index 6e7df47c9584..a3829ce548f9 100644 --- a/net/mptcp/Makefile +++ b/net/mptcp/Makefile @@ -2,7 +2,7 @@ obj-$(CONFIG_MPTCP) += mptcp.o mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \ - mib.o pm_netlink.o sockopt.o pm_userspace.o + mib.o pm_netlink.o sockopt.o pm_userspace.o fastopen.o obj-$(CONFIG_SYN_COOKIES) += syncookies.o obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o diff --git a/net/mptcp/fastopen.c b/net/mptcp/fastopen.c new file mode 100644 index 000000000000..d237d142171c --- /dev/null +++ b/net/mptcp/fastopen.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0 +/* MPTCP Fast Open Mechanism + * + * Copyright (c) 2021-2022, Dmytro SHYTYI + */ + +#include "protocol.h" + +void mptcp_fastopen_subflow_synack_set_params(struct mptcp_subflow_context *subflow, + struct request_sock *req) +{ + struct sock *ssk = subflow->tcp_sock; + struct sock *sk = subflow->conn; + struct sk_buff *skb; + struct tcp_sock *tp; + + tp = tcp_sk(ssk); + + subflow->is_mptfo = 1; + + skb = skb_peek(&ssk->sk_receive_queue); + if (WARN_ON_ONCE(!skb)) + return; + + /* dequeue the skb from sk receive queue */ + __skb_unlink(skb, &ssk->sk_receive_queue); + skb_ext_reset(skb); + skb_orphan(skb); + + /* We copy the fastopen data, but that don't belong to the mptcp sequence + * space, need to offset it in the subflow sequence, see mptcp_subflow_get_map_offset() + */ + tp->copied_seq += skb->len; + subflow->ssn_offset += skb->len; + + /* initialize a dummy sequence number, we will update it at MPC + * completion, if needed + */ + MPTCP_SKB_CB(skb)->map_seq = -skb->len; + MPTCP_SKB_CB(skb)->end_seq = 0; + MPTCP_SKB_CB(skb)->offset = 0; + MPTCP_SKB_CB(skb)->has_rxtstamp = TCP_SKB_CB(skb)->has_rxtstamp; + + mptcp_data_lock(sk); + + mptcp_set_owner_r(skb, sk); + __skb_queue_tail(&sk->sk_receive_queue, skb); + + sk->sk_data_ready(sk); + + mptcp_data_unlock(sk); +} + +void mptcp_fastopen_gen_msk_ackseq(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow, + const struct mptcp_options_received *mp_opt) +{ + struct sock *sk = (struct sock *)msk; + struct sk_buff *skb; + + mptcp_data_lock(sk); + skb = skb_peek_tail(&sk->sk_receive_queue); + if (skb) { + WARN_ON_ONCE(MPTCP_SKB_CB(skb)->end_seq); + pr_debug("msk %p moving seq %llx -> %llx end_seq %llx -> %llx", sk, + MPTCP_SKB_CB(skb)->map_seq, MPTCP_SKB_CB(skb)->map_seq + msk->ack_seq, + MPTCP_SKB_CB(skb)->end_seq, MPTCP_SKB_CB(skb)->end_seq + msk->ack_seq); + MPTCP_SKB_CB(skb)->map_seq += msk->ack_seq; + MPTCP_SKB_CB(skb)->end_seq += msk->ack_seq; + } + + pr_debug("msk=%p ack_seq=%llx", msk, msk->ack_seq); + mptcp_data_unlock(sk); +} diff --git a/net/mptcp/options.c b/net/mptcp/options.c index 30d289044e71..5ded85e2c374 100644 --- a/net/mptcp/options.c +++ b/net/mptcp/options.c @@ -26,6 +26,7 @@ static void mptcp_parse_option(const struct sk_buff *skb, { u8 subtype = *ptr >> 4; int expected_opsize; + u16 subopt; u8 version; u8 flags; u8 i; @@ -38,11 +39,15 @@ static void mptcp_parse_option(const struct sk_buff *skb, expected_opsize = TCPOLEN_MPTCP_MPC_ACK_DATA; else expected_opsize = TCPOLEN_MPTCP_MPC_ACK; + subopt = OPTION_MPTCP_MPC_ACK; } else { - if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACK) + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACK) { expected_opsize = TCPOLEN_MPTCP_MPC_SYNACK; - else + subopt = OPTION_MPTCP_MPC_SYNACK; + } else { expected_opsize = TCPOLEN_MPTCP_MPC_SYN; + subopt = OPTION_MPTCP_MPC_SYN; + } } /* Cfr RFC 8684 Section 3.3.0: @@ -85,7 +90,7 @@ static void mptcp_parse_option(const struct sk_buff *skb, mp_opt->deny_join_id0 = !!(flags & MPTCP_CAP_DENY_JOIN_ID0); - mp_opt->suboptions |= OPTIONS_MPTCP_MPC; + mp_opt->suboptions |= subopt; if (opsize >= TCPOLEN_MPTCP_MPC_SYNACK) { mp_opt->sndr_key = get_unaligned_be64(ptr); ptr += 8; @@ -934,7 +939,7 @@ static bool check_fully_established(struct mptcp_sock *msk, struct sock *ssk, subflow->mp_join && (mp_opt->suboptions & OPTIONS_MPTCP_MPJ) && !subflow->request_join) tcp_send_ack(ssk); - goto fully_established; + goto check_notify; } /* we must process OoO packets before the first subflow is fully @@ -945,17 +950,20 @@ static bool check_fully_established(struct mptcp_sock *msk, struct sock *ssk, if (TCP_SKB_CB(skb)->seq != subflow->ssn_offset + 1) { if (subflow->mp_join) goto reset; + if (subflow->is_mptfo && mp_opt->suboptions & OPTION_MPTCP_MPC_ACK) + goto set_fully_established; return subflow->mp_capable; } - if (((mp_opt->suboptions & OPTION_MPTCP_DSS) && mp_opt->use_ack) || - ((mp_opt->suboptions & OPTION_MPTCP_ADD_ADDR) && !mp_opt->echo)) { + if (subflow->remote_key_valid && + (((mp_opt->suboptions & OPTION_MPTCP_DSS) && mp_opt->use_ack) || + ((mp_opt->suboptions & OPTION_MPTCP_ADD_ADDR) && !mp_opt->echo))) { /* subflows are fully established as soon as we get any * additional ack, including ADD_ADDR. */ subflow->fully_established = 1; WRITE_ONCE(msk->fully_established, true); - goto fully_established; + goto check_notify; } /* If the first established packet does not contain MP_CAPABLE + data @@ -974,11 +982,12 @@ static bool check_fully_established(struct mptcp_sock *msk, struct sock *ssk, if (mp_opt->deny_join_id0) WRITE_ONCE(msk->pm.remote_deny_join_id0, true); +set_fully_established: if (unlikely(!READ_ONCE(msk->pm.server_side))) pr_warn_once("bogus mpc option on established client sk"); mptcp_subflow_fully_established(subflow, mp_opt); -fully_established: +check_notify: /* if the subflow is not already linked into the conn_list, we can't * notify the PM: this subflow is still on the listener queue * and the PM possibly acquiring the subflow lock could race with diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index 45e2a48397b9..70f0ced3ca86 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -420,6 +420,31 @@ void mptcp_pm_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk) } } +/* if sk is ipv4 or ipv6_only allows only same-family local and remote addresses, + * otherwise allow any matching local/remote pair + */ +bool mptcp_pm_addr_families_match(const struct sock *sk, + const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *rem) +{ + bool mptcp_is_v4 = sk->sk_family == AF_INET; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + bool loc_is_v4 = loc->family == AF_INET || ipv6_addr_v4mapped(&loc->addr6); + bool rem_is_v4 = rem->family == AF_INET || ipv6_addr_v4mapped(&rem->addr6); + + if (mptcp_is_v4) + return loc_is_v4 && rem_is_v4; + + if (ipv6_only_sock(sk)) + return !loc_is_v4 && !rem_is_v4; + + return loc_is_v4 == rem_is_v4; +#else + return mptcp_is_v4 && loc->family == AF_INET && rem->family == AF_INET; +#endif +} + void mptcp_pm_data_reset(struct mptcp_sock *msk) { u8 pm_type = mptcp_get_pm_type(sock_net((struct sock *)msk)); diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 9813ed0fde9b..2ea7eae43bdb 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -912,10 +912,14 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, */ if (pernet->next_id == MPTCP_PM_MAX_ADDR_ID) pernet->next_id = 1; - if (pernet->addrs >= MPTCP_PM_ADDR_MAX) + if (pernet->addrs >= MPTCP_PM_ADDR_MAX) { + ret = -ERANGE; goto out; - if (test_bit(entry->addr.id, pernet->id_bitmap)) + } + if (test_bit(entry->addr.id, pernet->id_bitmap)) { + ret = -EBUSY; goto out; + } /* do not insert duplicate address, differentiate on port only * singled addresses @@ -929,8 +933,10 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, * endpoint is an implicit one and the user-space * did not provide an endpoint id */ - if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)) + if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)) { + ret = -EEXIST; goto out; + } if (entry->addr.id) goto out; @@ -1003,16 +1009,12 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, return err; msk = mptcp_sk(entry->lsk->sk); - if (!msk) { - err = -EINVAL; - goto out; - } + if (!msk) + return -EINVAL; ssock = __mptcp_nmpc_socket(msk); - if (!ssock) { - err = -EINVAL; - goto out; - } + if (!ssock) + return -EINVAL; mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); #if IS_ENABLED(CONFIG_MPTCP_IPV6) @@ -1020,22 +1022,16 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, addrlen = sizeof(struct sockaddr_in6); #endif err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen); - if (err) { - pr_warn("kernel_bind error, err=%d", err); - goto out; - } + if (err) + return err; err = kernel_listen(ssock, backlog); - if (err) { - pr_warn("kernel_listen error, err=%d", err); - goto out; - } + if (err) + return err; - return 0; + mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); -out: - sock_release(entry->lsk); - return err; + return 0; } int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) @@ -1194,7 +1190,7 @@ static int mptcp_pm_parse_pm_addr_attr(struct nlattr *tb[], if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) { if (!require_family) - return err; + return 0; NL_SET_ERR_MSG_ATTR(info->extack, attr, "missing family"); @@ -1228,7 +1224,7 @@ static int mptcp_pm_parse_pm_addr_attr(struct nlattr *tb[], if (tb[MPTCP_PM_ADDR_ATTR_PORT]) addr->port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); - return err; + return 0; } int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, @@ -1327,7 +1323,7 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) return -EINVAL; } - entry = kmalloc(sizeof(*entry), GFP_KERNEL_ACCOUNT); + entry = kzalloc(sizeof(*entry), GFP_KERNEL_ACCOUNT); if (!entry) { GENL_SET_ERR_MSG(info, "can't allocate addr"); return -ENOMEM; @@ -1337,23 +1333,22 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) if (entry->addr.port) { ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry); if (ret) { - GENL_SET_ERR_MSG(info, "create listen socket error"); - kfree(entry); - return ret; + GENL_SET_ERR_MSG_FMT(info, "create listen socket error: %d", ret); + goto out_free; } } ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); if (ret < 0) { - GENL_SET_ERR_MSG(info, "too many addresses or duplicate one"); - if (entry->lsk) - sock_release(entry->lsk); - kfree(entry); - return ret; + GENL_SET_ERR_MSG_FMT(info, "too many addresses or duplicate one: %d", ret); + goto out_free; } mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); - return 0; + +out_free: + __mptcp_pm_release_addr_entry(entry); + return ret; } int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id, @@ -2099,7 +2094,7 @@ void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id) return; nla_put_failure: - kfree_skb(skb); + nlmsg_free(skb); } void mptcp_event_addr_announced(const struct sock *ssk, @@ -2156,7 +2151,59 @@ void mptcp_event_addr_announced(const struct sock *ssk, return; nla_put_failure: - kfree_skb(skb); + nlmsg_free(skb); +} + +void mptcp_event_pm_listener(const struct sock *ssk, + enum mptcp_event_type event) +{ + const struct inet_sock *issk = inet_sk(ssk); + struct net *net = sock_net(ssk); + struct nlmsghdr *nlh; + struct sk_buff *skb; + + if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) + return; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return; + + nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, event); + if (!nlh) + goto nla_put_failure; + + if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family)) + goto nla_put_failure; + + if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport)) + goto nla_put_failure; + + switch (ssk->sk_family) { + case AF_INET: + if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr)) + goto nla_put_failure; + break; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + case AF_INET6: { + const struct ipv6_pinfo *np = inet6_sk(ssk); + + if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr)) + goto nla_put_failure; + break; + } +#endif + default: + WARN_ON_ONCE(1); + goto nla_put_failure; + } + + genlmsg_end(skb, nlh); + mptcp_nl_mcast_send(net, skb, GFP_KERNEL); + return; + +nla_put_failure: + nlmsg_free(skb); } void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, @@ -2204,6 +2251,9 @@ void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, if (mptcp_event_sub_closed(skb, msk, ssk) < 0) goto nla_put_failure; break; + case MPTCP_EVENT_LISTENER_CREATED: + case MPTCP_EVENT_LISTENER_CLOSED: + break; } genlmsg_end(skb, nlh); @@ -2211,7 +2261,7 @@ void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, return; nla_put_failure: - kfree_skb(skb); + nlmsg_free(skb); } static const struct genl_small_ops mptcp_pm_ops[] = { diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c index 9e82250cbb70..ea6ad9da7493 100644 --- a/net/mptcp/pm_userspace.c +++ b/net/mptcp/pm_userspace.c @@ -156,6 +156,7 @@ int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { GENL_SET_ERR_MSG(info, "invalid addr id or flags"); + err = -EINVAL; goto announce_err; } @@ -282,6 +283,7 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) if (addr_l.id == 0) { NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id"); + err = -EINVAL; goto create_err; } @@ -291,7 +293,14 @@ int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) goto create_err; } - sk = &msk->sk.icsk_inet.sk; + sk = (struct sock *)msk; + + if (!mptcp_pm_addr_families_match(sk, &addr_l, &addr_r)) { + GENL_SET_ERR_MSG(info, "families mismatch"); + err = -EINVAL; + goto create_err; + } + lock_sock(sk); err = __mptcp_subflow_connect(sk, &addr_l, &addr_r); @@ -395,15 +404,17 @@ int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) if (addr_l.family != addr_r.family) { GENL_SET_ERR_MSG(info, "address families do not match"); + err = -EINVAL; goto destroy_err; } if (!addr_l.port || !addr_r.port) { GENL_SET_ERR_MSG(info, "missing local or remote port"); + err = -EINVAL; goto destroy_err; } - sk = &msk->sk.icsk_inet.sk; + sk = (struct sock *)msk; lock_sock(sk); ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r); if (ssk) { diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index b6dc6e260334..8cd6cc67c2c5 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -36,15 +36,6 @@ struct mptcp6_sock { }; #endif -struct mptcp_skb_cb { - u64 map_seq; - u64 end_seq; - u32 offset; - u8 has_rxtstamp:1; -}; - -#define MPTCP_SKB_CB(__skb) ((struct mptcp_skb_cb *)&((__skb)->cb[0])) - enum { MPTCP_CMSG_TS = BIT(0), MPTCP_CMSG_INQ = BIT(1), @@ -107,7 +98,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) struct socket *ssock; int err; - err = mptcp_subflow_create_socket(sk, &ssock); + err = mptcp_subflow_create_socket(sk, sk->sk_family, &ssock); if (err) return err; @@ -200,7 +191,7 @@ static void mptcp_rfree(struct sk_buff *skb) mptcp_rmem_uncharge(sk, len); } -static void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk) +void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk) { skb_orphan(skb); skb->sk = sk; @@ -1602,7 +1593,7 @@ out: __mptcp_check_send_data_fin(sk); } -static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk) +static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk, bool first) { struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sendmsg_info info = { @@ -1611,7 +1602,6 @@ static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk) struct mptcp_data_frag *dfrag; struct sock *xmit_ssk; int len, copied = 0; - bool first = true; info.flags = 0; while ((dfrag = mptcp_send_head(sk))) { @@ -1621,11 +1611,10 @@ static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk) while (len > 0) { int ret = 0; - /* the caller already invoked the packet scheduler, - * check for a different subflow usage only after + /* check for a different subflow usage only after * spooling the first chunk of data */ - xmit_ssk = first ? ssk : mptcp_subflow_get_send(mptcp_sk(sk)); + xmit_ssk = first ? ssk : mptcp_subflow_get_send(msk); if (!xmit_ssk) goto out; if (xmit_ssk != ssk) { @@ -1673,6 +1662,8 @@ static void mptcp_set_nospace(struct sock *sk) set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags); } +static int mptcp_disconnect(struct sock *sk, int flags); + static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg, size_t len, int *copied_syn) { @@ -1683,9 +1674,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh lock_sock(ssk); msg->msg_flags |= MSG_DONTWAIT; msk->connect_flags = O_NONBLOCK; - msk->is_sendmsg = 1; + msk->fastopening = 1; ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL); - msk->is_sendmsg = 0; + msk->fastopening = 0; msg->msg_flags = saved_flags; release_sock(ssk); @@ -1699,6 +1690,8 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh */ if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR) *copied_syn = 0; + } else if (ret && ret != -EINPROGRESS) { + mptcp_disconnect(sk, 0); } return ret; @@ -1713,17 +1706,14 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) int ret = 0; long timeo; - /* we don't support FASTOPEN yet */ - if (msg->msg_flags & MSG_FASTOPEN) - return -EOPNOTSUPP; - /* silently ignore everything else */ - msg->msg_flags &= MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL; + msg->msg_flags &= MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_FASTOPEN; lock_sock(sk); ssock = __mptcp_nmpc_socket(msk); - if (unlikely(ssock && inet_sk(ssock->sk)->defer_connect)) { + if (unlikely(ssock && (inet_sk(ssock->sk)->defer_connect || + msg->msg_flags & MSG_FASTOPEN))) { int copied_syn = 0; ret = mptcp_sendmsg_fastopen(sk, ssock->sk, msg, len, &copied_syn); @@ -2275,7 +2265,7 @@ bool __mptcp_retransmit_pending_data(struct sock *sk) struct mptcp_data_frag *cur, *rtx_head; struct mptcp_sock *msk = mptcp_sk(sk); - if (__mptcp_check_fallback(mptcp_sk(sk))) + if (__mptcp_check_fallback(msk)) return false; if (tcp_rtx_and_write_queues_empty(sk)) @@ -2354,12 +2344,7 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, goto out; } - /* if we are invoked by the msk cleanup code, the subflow is - * already orphaned - */ - if (ssk->sk_socket) - sock_orphan(ssk); - + sock_orphan(ssk); subflow->disposable = 1; /* if ssk hit tcp_done(), tcp_cleanup_ulp() cleared the related ops @@ -2372,8 +2357,9 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, /* otherwise tcp will dispose of the ssk and subflow ctx */ if (ssk->sk_state == TCP_LISTEN) { tcp_set_state(ssk, TCP_CLOSE); - mptcp_subflow_queue_clean(ssk); + mptcp_subflow_queue_clean(sk, ssk); inet_csk_listen_stop(ssk); + mptcp_event_pm_listener(ssk, MPTCP_EVENT_LISTENER_CLOSED); } __tcp_close(ssk, 0); @@ -2456,7 +2442,7 @@ static bool mptcp_check_close_timeout(const struct sock *sk) static void mptcp_check_fastclose(struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow, *tmp; - struct sock *sk = &msk->sk.icsk_inet.sk; + struct sock *sk = (struct sock *)msk; if (likely(!READ_ONCE(msk->rcv_fastclose))) return; @@ -2618,7 +2604,7 @@ static void mptcp_do_fastclose(struct sock *sk) static void mptcp_worker(struct work_struct *work) { struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work); - struct sock *sk = &msk->sk.icsk_inet.sk; + struct sock *sk = (struct sock *)msk; unsigned long fail_tout; int state; @@ -2730,6 +2716,8 @@ static int mptcp_init_sock(struct sock *sk) if (ret) return ret; + set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags); + /* fetch the ca name; do it outside __mptcp_init_sock(), so that clone will * propagate the correct value */ @@ -2940,14 +2928,18 @@ cleanup: if (ssk == msk->first) subflow->fail_tout = 0; - sock_orphan(ssk); + /* detach from the parent socket, but allow data_ready to + * push incoming data into the mptcp stack, to properly ack it + */ + ssk->sk_socket = NULL; + ssk->sk_wq = NULL; unlock_sock_fast(ssk, slow); } sock_orphan(sk); sock_hold(sk); pr_debug("msk=%p state=%d", sk, sk->sk_state); - if (mptcp_sk(sk)->token) + if (msk->token) mptcp_event(MPTCP_EVENT_CLOSED, msk, NULL, GFP_KERNEL); if (sk->sk_state == TCP_CLOSE) { @@ -3001,13 +2993,21 @@ static int mptcp_disconnect(struct sock *sk, int flags) { struct mptcp_sock *msk = mptcp_sk(sk); + /* We are on the fastopen error path. We can't call straight into the + * subflows cleanup code due to lock nesting (we are already under + * msk->firstsocket lock). Do nothing and leave the cleanup to the + * caller. + */ + if (msk->fastopening) + return 0; + inet_sk_state_store(sk, TCP_CLOSE); mptcp_stop_timer(sk); sk_stop_timer(sk, &sk->sk_timer); - if (mptcp_sk(sk)->token) - mptcp_event(MPTCP_EVENT_CLOSED, mptcp_sk(sk), NULL, GFP_KERNEL); + if (msk->token) + mptcp_event(MPTCP_EVENT_CLOSED, msk, NULL, GFP_KERNEL); /* msk->subflow is still intact, the following will not free the first * subflow @@ -3049,7 +3049,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk, struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC); struct mptcp_sock *msk; - u64 ack_seq; if (!nsk) return NULL; @@ -3075,15 +3074,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk, msk->wnd_end = msk->snd_nxt + req->rsk_rcv_wnd; msk->setsockopt_seq = mptcp_sk(sk)->setsockopt_seq; - if (mp_opt->suboptions & OPTIONS_MPTCP_MPC) { - msk->can_ack = true; - msk->remote_key = mp_opt->sndr_key; - mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq); - ack_seq++; - WRITE_ONCE(msk->ack_seq, ack_seq); - atomic64_set(&msk->rcv_wnd_sent, ack_seq); - } - sock_reset_flag(nsk, SOCK_RCU_FREE); /* will be fully established after successful MPC subflow creation */ inet_sk_state_store(nsk, TCP_SYN_RECV); @@ -3218,16 +3208,10 @@ void __mptcp_check_push(struct sock *sk, struct sock *ssk) if (!mptcp_send_head(sk)) return; - if (!sock_owned_by_user(sk)) { - struct sock *xmit_ssk = mptcp_subflow_get_send(mptcp_sk(sk)); - - if (xmit_ssk == ssk) - __mptcp_subflow_push_pending(sk, ssk); - else if (xmit_ssk) - mptcp_subflow_delegate(mptcp_subflow_ctx(xmit_ssk), MPTCP_DELEGATE_SEND); - } else { + if (!sock_owned_by_user(sk)) + __mptcp_subflow_push_pending(sk, ssk, false); + else __set_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->cb_flags); - } } #define MPTCP_FLAGS_PROCESS_CTX_NEED (BIT(MPTCP_PUSH_PENDING) | \ @@ -3318,7 +3302,7 @@ void mptcp_subflow_process_delegated(struct sock *ssk) if (test_bit(MPTCP_DELEGATE_SEND, &subflow->delegated_status)) { mptcp_data_lock(sk); if (!sock_owned_by_user(sk)) - __mptcp_subflow_push_pending(sk, ssk); + __mptcp_subflow_push_pending(sk, ssk, true); else __set_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->cb_flags); mptcp_data_unlock(sk); @@ -3362,7 +3346,6 @@ void mptcp_finish_connect(struct sock *ssk) struct mptcp_subflow_context *subflow; struct mptcp_sock *msk; struct sock *sk; - u64 ack_seq; subflow = mptcp_subflow_ctx(ssk); sk = subflow->conn; @@ -3370,22 +3353,16 @@ void mptcp_finish_connect(struct sock *ssk) pr_debug("msk=%p, token=%u", sk, subflow->token); - mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq); - ack_seq++; - subflow->map_seq = ack_seq; + subflow->map_seq = subflow->iasn; subflow->map_subflow_seq = 1; /* the socket is not connected yet, no msk/subflow ops can access/race * accessing the field below */ - WRITE_ONCE(msk->remote_key, subflow->remote_key); WRITE_ONCE(msk->local_key, subflow->local_key); WRITE_ONCE(msk->write_seq, subflow->idsn + 1); WRITE_ONCE(msk->snd_nxt, msk->write_seq); - WRITE_ONCE(msk->ack_seq, ack_seq); - WRITE_ONCE(msk->can_ack, 1); WRITE_ONCE(msk->snd_una, msk->write_seq); - atomic64_set(&msk->rcv_wnd_sent, ack_seq); mptcp_pm_new_connection(msk, ssk, 0); @@ -3567,7 +3544,7 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) /* if reaching here via the fastopen/sendmsg path, the caller already * acquired the subflow socket lock, too. */ - if (msk->is_sendmsg) + if (msk->fastopening) err = __inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags, 1); else err = inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags); @@ -3683,6 +3660,8 @@ static int mptcp_listen(struct socket *sock, int backlog) if (!err) mptcp_copy_inaddrs(sock->sk, ssock->sk); + mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); + unlock: release_sock(sock->sk); return err; @@ -3707,6 +3686,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, struct mptcp_subflow_context *subflow; struct sock *newsk = newsock->sk; + set_bit(SOCK_CUSTOM_SOCKOPT, &newsock->flags); + lock_sock(newsk); /* PM/worker can now acquire the first subflow socket @@ -3920,12 +3901,6 @@ static const struct proto_ops mptcp_v6_stream_ops = { static struct proto mptcp_v6_prot; -static void mptcp_v6_destroy(struct sock *sk) -{ - mptcp_destroy(sk); - inet6_destroy_sock(sk); -} - static struct inet_protosw mptcp_v6_protosw = { .type = SOCK_STREAM, .protocol = IPPROTO_MPTCP, @@ -3941,7 +3916,6 @@ int __init mptcp_proto_v6_init(void) mptcp_v6_prot = mptcp_prot; strcpy(mptcp_v6_prot.name, "MPTCPv6"); mptcp_v6_prot.slab = NULL; - mptcp_v6_prot.destroy = mptcp_v6_destroy; mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock); err = proto_register(&mptcp_v6_prot, 1); diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 6a09ab99a12d..601469249da8 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -126,6 +126,15 @@ #define MPTCP_CONNECTED 6 #define MPTCP_RESET_SCHEDULER 7 +struct mptcp_skb_cb { + u64 map_seq; + u64 end_seq; + u32 offset; + u8 has_rxtstamp:1; +}; + +#define MPTCP_SKB_CB(__skb) ((struct mptcp_skb_cb *)&((__skb)->cb[0])) + static inline bool before64(__u64 seq1, __u64 seq2) { return (__s64)(seq1 - seq2) < 0; @@ -286,7 +295,7 @@ struct mptcp_sock { u8 recvmsg_inq:1, cork:1, nodelay:1, - is_sendmsg:1; + fastopening:1; int connect_flags; struct work_struct work; struct sk_buff *ooo_last_skb; @@ -467,17 +476,22 @@ struct mptcp_subflow_context { send_fastclose : 1, send_infinite_map : 1, rx_eof : 1, - can_ack : 1, /* only after processing the remote a key */ + remote_key_valid : 1, /* received the peer key from */ disposable : 1, /* ctx can be free at ulp release time */ stale : 1, /* unable to snd/rcv data, do not use for xmit */ local_id_valid : 1, /* local_id is correctly initialized */ - valid_csum_seen : 1; /* at least one csum validated */ + valid_csum_seen : 1, /* at least one csum validated */ + is_mptfo : 1, /* subflow is doing TFO */ + __unused : 8; enum mptcp_data_avail data_avail; u32 remote_nonce; u64 thmac; u32 local_nonce; u32 remote_token; - u8 hmac[MPTCPOPT_HMAC_LEN]; + union { + u8 hmac[MPTCPOPT_HMAC_LEN]; /* MPJ subflow only */ + u64 iasn; /* initial ack sequence number, MPC subflows only */ + }; u8 local_id; u8 remote_id; u8 reset_seen:1; @@ -603,7 +617,7 @@ unsigned int mptcp_stale_loss_cnt(const struct net *net); int mptcp_get_pm_type(const struct net *net); void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk); void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow, - struct mptcp_options_received *mp_opt); + const struct mptcp_options_received *mp_opt); bool __mptcp_retransmit_pending_data(struct sock *sk); void mptcp_check_and_set_pending(struct sock *sk); void __mptcp_push_pending(struct sock *sk, unsigned int flags); @@ -614,11 +628,12 @@ void mptcp_close_ssk(struct sock *sk, struct sock *ssk, struct mptcp_subflow_context *subflow); void __mptcp_subflow_send_ack(struct sock *ssk); void mptcp_subflow_reset(struct sock *ssk); -void mptcp_subflow_queue_clean(struct sock *ssk); +void mptcp_subflow_queue_clean(struct sock *sk, struct sock *ssk); void mptcp_sock_graft(struct sock *sk, struct socket *parent); struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk); bool __mptcp_close(struct sock *sk, long timeout); void mptcp_cancel_work(struct sock *sk); +void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk); bool mptcp_addresses_equal(const struct mptcp_addr_info *a, const struct mptcp_addr_info *b, bool use_port); @@ -626,7 +641,8 @@ bool mptcp_addresses_equal(const struct mptcp_addr_info *a, /* called with sk socket lock held */ int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, const struct mptcp_addr_info *remote); -int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock); +int mptcp_subflow_create_socket(struct sock *sk, unsigned short family, + struct socket **new_sock); void mptcp_info2sockaddr(const struct mptcp_addr_info *info, struct sockaddr_storage *addr, unsigned short family); @@ -761,6 +777,9 @@ int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, int mptcp_pm_parse_entry(struct nlattr *attr, struct genl_info *info, bool require_family, struct mptcp_pm_addr_entry *entry); +bool mptcp_pm_addr_families_match(const struct sock *sk, + const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *rem); void mptcp_pm_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk); void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk); void mptcp_pm_new_connection(struct mptcp_sock *msk, const struct sock *ssk, int server_side); @@ -824,8 +843,15 @@ void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, const struct sock *ssk, gfp_t gfp); void mptcp_event_addr_announced(const struct sock *ssk, const struct mptcp_addr_info *info); void mptcp_event_addr_removed(const struct mptcp_sock *msk, u8 id); +void mptcp_event_pm_listener(const struct sock *ssk, + enum mptcp_event_type event); bool mptcp_userspace_pm_active(const struct mptcp_sock *msk); +void mptcp_fastopen_gen_msk_ackseq(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow, + const struct mptcp_options_received *mp_opt); +void mptcp_fastopen_subflow_synack_set_params(struct mptcp_subflow_context *subflow, + struct request_sock *req); + static inline bool mptcp_pm_should_add_signal(struct mptcp_sock *msk) { return READ_ONCE(msk->pm.addr_signal) & diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index c7cb68c725b2..d4b1e6ec1b36 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -559,7 +559,10 @@ static bool mptcp_supported_sockopt(int level, int optname) case TCP_NOTSENT_LOWAT: case TCP_TX_DELAY: case TCP_INQ: + case TCP_FASTOPEN: case TCP_FASTOPEN_CONNECT: + case TCP_FASTOPEN_KEY: + case TCP_FASTOPEN_NO_COOKIE: return true; } @@ -568,9 +571,6 @@ static bool mptcp_supported_sockopt(int level, int optname) /* TCP_REPAIR, TCP_REPAIR_QUEUE, TCP_QUEUE_SEQ, TCP_REPAIR_OPTIONS, * TCP_REPAIR_WINDOW are not supported, better avoid this mess */ - /* TCP_FASTOPEN_KEY, TCP_FASTOPEN, TCP_FASTOPEN_NO_COOKIE, - * are not supported fastopen is currently unsupported - */ } return false; } @@ -740,7 +740,7 @@ static int mptcp_setsockopt_v4_set_tos(struct mptcp_sock *msk, int optname, } release_sock(sk); - return err; + return 0; } static int mptcp_setsockopt_v4(struct mptcp_sock *msk, int optname, @@ -757,29 +757,17 @@ static int mptcp_setsockopt_v4(struct mptcp_sock *msk, int optname, return -EOPNOTSUPP; } -static int mptcp_setsockopt_sol_tcp_defer(struct mptcp_sock *msk, sockptr_t optval, - unsigned int optlen) -{ - struct socket *listener; - - listener = __mptcp_nmpc_socket(msk); - if (!listener) - return 0; /* TCP_DEFER_ACCEPT does not fail */ - - return tcp_setsockopt(listener->sk, SOL_TCP, TCP_DEFER_ACCEPT, optval, optlen); -} - -static int mptcp_setsockopt_sol_tcp_fastopen_connect(struct mptcp_sock *msk, sockptr_t optval, - unsigned int optlen) +static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int optname, + sockptr_t optval, unsigned int optlen) { struct socket *sock; - /* Limit to first subflow */ + /* Limit to first subflow, before the connection establishment */ sock = __mptcp_nmpc_socket(msk); if (!sock) return -EINVAL; - return tcp_setsockopt(sock->sk, SOL_TCP, TCP_FASTOPEN_CONNECT, optval, optlen); + return tcp_setsockopt(sock->sk, level, optname, optval, optlen); } static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, @@ -809,9 +797,15 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, case TCP_NODELAY: return mptcp_setsockopt_sol_tcp_nodelay(msk, optval, optlen); case TCP_DEFER_ACCEPT: - return mptcp_setsockopt_sol_tcp_defer(msk, optval, optlen); + /* See tcp.c: TCP_DEFER_ACCEPT does not fail */ + mptcp_setsockopt_first_sf_only(msk, SOL_TCP, optname, optval, optlen); + return 0; + case TCP_FASTOPEN: case TCP_FASTOPEN_CONNECT: - return mptcp_setsockopt_sol_tcp_fastopen_connect(msk, optval, optlen); + case TCP_FASTOPEN_KEY: + case TCP_FASTOPEN_NO_COOKIE: + return mptcp_setsockopt_first_sf_only(msk, SOL_TCP, optname, + optval, optlen); } return -EOPNOTSUPP; @@ -994,7 +988,7 @@ static int mptcp_getsockopt_tcpinfo(struct mptcp_sock *msk, char __user *optval, int __user *optlen) { struct mptcp_subflow_context *subflow; - struct sock *sk = &msk->sk.icsk_inet.sk; + struct sock *sk = (struct sock *)msk; unsigned int sfcount = 0, copied = 0; struct mptcp_subflow_data sfd; char __user *infoptr; @@ -1085,8 +1079,8 @@ static void mptcp_get_sub_addrs(const struct sock *sk, struct mptcp_subflow_addr static int mptcp_getsockopt_subflow_addrs(struct mptcp_sock *msk, char __user *optval, int __user *optlen) { - struct sock *sk = &msk->sk.icsk_inet.sk; struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; unsigned int sfcount = 0, copied = 0; struct mptcp_subflow_data sfd; char __user *addrptr; @@ -1173,7 +1167,10 @@ static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname, case TCP_INFO: case TCP_CC_INFO: case TCP_DEFER_ACCEPT: + case TCP_FASTOPEN: case TCP_FASTOPEN_CONNECT: + case TCP_FASTOPEN_KEY: + case TCP_FASTOPEN_NO_COOKIE: return mptcp_getsockopt_first_sf_only(msk, SOL_TCP, optname, optval, optlen); case TCP_INQ: diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c index 02a54d59697b..ec54413fb31f 100644 --- a/net/mptcp/subflow.c +++ b/net/mptcp/subflow.c @@ -45,7 +45,6 @@ static void subflow_req_destructor(struct request_sock *req) sock_put((struct sock *)subflow_req->msk); mptcp_token_destroy_request(req); - tcp_request_sock_ops.destructor(req); } static void subflow_generate_hmac(u64 key1, u64 key2, u32 nonce1, u32 nonce2, @@ -307,7 +306,48 @@ static struct dst_entry *subflow_v4_route_req(const struct sock *sk, return NULL; } +static void subflow_prep_synack(const struct sock *sk, struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct inet_request_sock *ireq = inet_rsk(req); + + /* clear tstamp_ok, as needed depending on cookie */ + if (foc && foc->len > -1) + ireq->tstamp_ok = 0; + + if (synack_type == TCP_SYNACK_FASTOPEN) + mptcp_fastopen_subflow_synack_set_params(subflow, req); +} + +static int subflow_v4_send_synack(const struct sock *sk, struct dst_entry *dst, + struct flowi *fl, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + subflow_prep_synack(sk, req, foc, synack_type); + + return tcp_request_sock_ipv4_ops.send_synack(sk, dst, fl, req, foc, + synack_type, syn_skb); +} + #if IS_ENABLED(CONFIG_MPTCP_IPV6) +static int subflow_v6_send_synack(const struct sock *sk, struct dst_entry *dst, + struct flowi *fl, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + subflow_prep_synack(sk, req, foc, synack_type); + + return tcp_request_sock_ipv6_ops.send_synack(sk, dst, fl, req, foc, + synack_type, syn_skb); +} + static struct dst_entry *subflow_v6_route_req(const struct sock *sk, struct sk_buff *skb, struct flowi *fl, @@ -392,11 +432,33 @@ static void mptcp_set_connected(struct sock *sk) mptcp_data_unlock(sk); } +static void subflow_set_remote_key(struct mptcp_sock *msk, + struct mptcp_subflow_context *subflow, + const struct mptcp_options_received *mp_opt) +{ + /* active MPC subflow will reach here multiple times: + * at subflow_finish_connect() time and at 4th ack time + */ + if (subflow->remote_key_valid) + return; + + subflow->remote_key_valid = 1; + subflow->remote_key = mp_opt->sndr_key; + mptcp_crypto_key_sha(subflow->remote_key, NULL, &subflow->iasn); + subflow->iasn++; + + WRITE_ONCE(msk->remote_key, subflow->remote_key); + WRITE_ONCE(msk->ack_seq, subflow->iasn); + WRITE_ONCE(msk->can_ack, true); + atomic64_set(&msk->rcv_wnd_sent, subflow->iasn); +} + static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) { struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_options_received mp_opt; struct sock *parent = subflow->conn; + struct mptcp_sock *msk; subflow->icsk_af_ops->sk_rx_dst_set(sk, skb); @@ -404,6 +466,7 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) if (subflow->conn_finished) return; + msk = mptcp_sk(parent); mptcp_propagate_sndbuf(parent, sk); subflow->rel_write_seq = 1; subflow->conn_finished = 1; @@ -416,19 +479,16 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEACTIVEFALLBACK); mptcp_do_fallback(sk); - pr_fallback(mptcp_sk(subflow->conn)); + pr_fallback(msk); goto fallback; } if (mp_opt.suboptions & OPTION_MPTCP_CSUMREQD) - WRITE_ONCE(mptcp_sk(parent)->csum_enabled, true); + WRITE_ONCE(msk->csum_enabled, true); if (mp_opt.deny_join_id0) - WRITE_ONCE(mptcp_sk(parent)->pm.remote_deny_join_id0, true); + WRITE_ONCE(msk->pm.remote_deny_join_id0, true); subflow->mp_capable = 1; - subflow->can_ack = 1; - subflow->remote_key = mp_opt.sndr_key; - pr_debug("subflow=%p, remote_key=%llu", subflow, - subflow->remote_key); + subflow_set_remote_key(msk, subflow, &mp_opt); MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEACTIVEACK); mptcp_finish_connect(sk); mptcp_set_connected(parent); @@ -466,7 +526,7 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) subflow->mp_join = 1; MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_JOINSYNACKRX); - if (subflow_use_different_dport(mptcp_sk(parent), sk)) { + if (subflow_use_different_dport(msk, sk)) { pr_debug("synack inet_dport=%d %d", ntohs(inet_sk(sk)->inet_dport), ntohs(inet_sk(parent)->inet_dport)); @@ -474,7 +534,7 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) } } else if (mptcp_check_fallback(sk)) { fallback: - mptcp_rcv_space_init(mptcp_sk(parent), sk); + mptcp_rcv_space_init(msk, sk); mptcp_set_connected(parent); } return; @@ -529,7 +589,7 @@ static int subflow_v6_rebuild_header(struct sock *sk) } #endif -struct request_sock_ops mptcp_subflow_request_sock_ops; +static struct request_sock_ops mptcp_subflow_v4_request_sock_ops __ro_after_init; static struct tcp_request_sock_ops subflow_request_sock_ipv4_ops __ro_after_init; static int subflow_v4_conn_request(struct sock *sk, struct sk_buff *skb) @@ -542,7 +602,7 @@ static int subflow_v4_conn_request(struct sock *sk, struct sk_buff *skb) if (skb_rtable(skb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) goto drop; - return tcp_conn_request(&mptcp_subflow_request_sock_ops, + return tcp_conn_request(&mptcp_subflow_v4_request_sock_ops, &subflow_request_sock_ipv4_ops, sk, skb); drop: @@ -550,7 +610,14 @@ drop: return 0; } +static void subflow_v4_req_destructor(struct request_sock *req) +{ + subflow_req_destructor(req); + tcp_request_sock_ops.destructor(req); +} + #if IS_ENABLED(CONFIG_MPTCP_IPV6) +static struct request_sock_ops mptcp_subflow_v6_request_sock_ops __ro_after_init; static struct tcp_request_sock_ops subflow_request_sock_ipv6_ops __ro_after_init; static struct inet_connection_sock_af_ops subflow_v6_specific __ro_after_init; static struct inet_connection_sock_af_ops subflow_v6m_specific __ro_after_init; @@ -573,15 +640,36 @@ static int subflow_v6_conn_request(struct sock *sk, struct sk_buff *skb) return 0; } - return tcp_conn_request(&mptcp_subflow_request_sock_ops, + return tcp_conn_request(&mptcp_subflow_v6_request_sock_ops, &subflow_request_sock_ipv6_ops, sk, skb); drop: tcp_listendrop(sk); return 0; /* don't send reset */ } + +static void subflow_v6_req_destructor(struct request_sock *req) +{ + subflow_req_destructor(req); + tcp6_request_sock_ops.destructor(req); +} #endif +struct request_sock *mptcp_subflow_reqsk_alloc(const struct request_sock_ops *ops, + struct sock *sk_listener, + bool attach_listener) +{ + if (ops->family == AF_INET) + ops = &mptcp_subflow_v4_request_sock_ops; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (ops->family == AF_INET6) + ops = &mptcp_subflow_v6_request_sock_ops; +#endif + + return inet_reqsk_alloc(ops, sk_listener, attach_listener); +} +EXPORT_SYMBOL(mptcp_subflow_reqsk_alloc); + /* validate hmac received in third ACK */ static bool subflow_hmac_valid(const struct request_sock *req, const struct mptcp_options_received *mp_opt) @@ -637,14 +725,16 @@ static void subflow_drop_ctx(struct sock *ssk) } void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow, - struct mptcp_options_received *mp_opt) + const struct mptcp_options_received *mp_opt) { struct mptcp_sock *msk = mptcp_sk(subflow->conn); - subflow->remote_key = mp_opt->sndr_key; + subflow_set_remote_key(msk, subflow, mp_opt); subflow->fully_established = 1; - subflow->can_ack = 1; WRITE_ONCE(msk->fully_established, true); + + if (subflow->is_mptfo) + mptcp_fastopen_gen_msk_ackseq(msk, subflow, mp_opt); } static struct sock *subflow_syn_recv_sock(const struct sock *sk, @@ -760,7 +850,7 @@ create_child: /* with OoO packets we can reach here without ingress * mpc option */ - if (mp_opt.suboptions & OPTIONS_MPTCP_MPC) + if (mp_opt.suboptions & OPTION_MPTCP_MPC_ACK) mptcp_subflow_fully_established(ctx, &mp_opt); } else if (ctx->mp_join) { struct mptcp_sock *owner; @@ -1198,16 +1288,8 @@ static bool subflow_check_data_avail(struct sock *ssk) if (WARN_ON_ONCE(!skb)) goto no_data; - /* if msk lacks the remote key, this subflow must provide an - * MP_CAPABLE-based mapping - */ - if (unlikely(!READ_ONCE(msk->can_ack))) { - if (!subflow->mpc_map) - goto fallback; - WRITE_ONCE(msk->remote_key, subflow->remote_key); - WRITE_ONCE(msk->ack_seq, subflow->map_seq); - WRITE_ONCE(msk->can_ack, true); - } + if (unlikely(!READ_ONCE(msk->can_ack))) + goto fallback; old_ack = READ_ONCE(msk->ack_seq); ack_seq = mptcp_subflow_get_mapped_dsn(subflow); @@ -1465,7 +1547,7 @@ int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, if (!mptcp_is_fully_established(sk)) goto err_out; - err = mptcp_subflow_create_socket(sk, &sf); + err = mptcp_subflow_create_socket(sk, loc->family, &sf); if (err) goto err_out; @@ -1480,6 +1562,7 @@ int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, mptcp_pm_get_flags_and_ifindex_by_id(msk, local_id, &flags, &ifindex); + subflow->remote_key_valid = 1; subflow->remote_key = msk->remote_key; subflow->local_key = msk->local_key; subflow->token = msk->token; @@ -1577,7 +1660,9 @@ static void mptcp_subflow_ops_undo_override(struct sock *ssk) #endif ssk->sk_prot = &tcp_prot; } -int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock) + +int mptcp_subflow_create_socket(struct sock *sk, unsigned short family, + struct socket **new_sock) { struct mptcp_subflow_context *subflow; struct net *net = sock_net(sk); @@ -1590,8 +1675,7 @@ int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock) if (unlikely(!sk->sk_socket)) return -EINVAL; - err = sock_create_kern(net, sk->sk_family, SOCK_STREAM, IPPROTO_TCP, - &sf); + err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf); if (err) return err; @@ -1602,7 +1686,9 @@ int mptcp_subflow_create_socket(struct sock *sk, struct socket **new_sock) /* kernel sockets do not by default acquire net ref, but TCP timer * needs it. + * Update ns_tracker to current stack trace and refcounted tracker. */ + __netns_tracker_free(net, &sf->sk->ns_tracker, false); sf->sk->sk_net_refcnt = 1; get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL); sock_inuse_add(net, 1); @@ -1706,7 +1792,7 @@ static void subflow_state_change(struct sock *sk) } } -void mptcp_subflow_queue_clean(struct sock *listener_ssk) +void mptcp_subflow_queue_clean(struct sock *listener_sk, struct sock *listener_ssk) { struct request_sock_queue *queue = &inet_csk(listener_ssk)->icsk_accept_queue; struct mptcp_sock *msk, *next, *head = NULL; @@ -1745,18 +1831,33 @@ void mptcp_subflow_queue_clean(struct sock *listener_ssk) for (msk = head; msk; msk = next) { struct sock *sk = (struct sock *)msk; - bool slow, do_cancel_work; + bool do_cancel_work; sock_hold(sk); - slow = lock_sock_fast_nested(sk); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); next = msk->dl_next; msk->first = NULL; msk->dl_next = NULL; do_cancel_work = __mptcp_close(sk, 0); - unlock_sock_fast(sk, slow); - if (do_cancel_work) + release_sock(sk); + if (do_cancel_work) { + /* lockdep will report a false positive ABBA deadlock + * between cancel_work_sync and the listener socket. + * The involved locks belong to different sockets WRT + * the existing AB chain. + * Using a per socket key is problematic as key + * deregistration requires process context and must be + * performed at socket disposal time, in atomic + * context. + * Just tell lockdep to consider the listener socket + * released here. + */ + mutex_release(&listener_sk->sk_lock.dep_map, _RET_IP_); mptcp_cancel_work(sk); + mutex_acquire(&listener_sk->sk_lock.dep_map, + SINGLE_DEPTH_NESTING, 0, _RET_IP_); + } sock_put(sk); } @@ -1871,6 +1972,7 @@ static void subflow_ulp_clone(const struct request_sock *req, new_ctx->ssn_offset = subflow_req->ssn_offset; new_ctx->mp_join = 1; new_ctx->fully_established = 1; + new_ctx->remote_key_valid = 1; new_ctx->backup = subflow_req->backup; new_ctx->remote_id = subflow_req->remote_id; new_ctx->token = subflow_req->token; @@ -1904,7 +2006,6 @@ static struct tcp_ulp_ops subflow_ulp_ops __read_mostly = { static int subflow_ops_init(struct request_sock_ops *subflow_ops) { subflow_ops->obj_size = sizeof(struct mptcp_subflow_request_sock); - subflow_ops->slab_name = "request_sock_subflow"; subflow_ops->slab = kmem_cache_create(subflow_ops->slab_name, subflow_ops->obj_size, 0, @@ -1914,19 +2015,21 @@ static int subflow_ops_init(struct request_sock_ops *subflow_ops) if (!subflow_ops->slab) return -ENOMEM; - subflow_ops->destructor = subflow_req_destructor; - return 0; } void __init mptcp_subflow_init(void) { - mptcp_subflow_request_sock_ops = tcp_request_sock_ops; - if (subflow_ops_init(&mptcp_subflow_request_sock_ops) != 0) - panic("MPTCP: failed to init subflow request sock ops\n"); + mptcp_subflow_v4_request_sock_ops = tcp_request_sock_ops; + mptcp_subflow_v4_request_sock_ops.slab_name = "request_sock_subflow_v4"; + mptcp_subflow_v4_request_sock_ops.destructor = subflow_v4_req_destructor; + + if (subflow_ops_init(&mptcp_subflow_v4_request_sock_ops) != 0) + panic("MPTCP: failed to init subflow v4 request sock ops\n"); subflow_request_sock_ipv4_ops = tcp_request_sock_ipv4_ops; subflow_request_sock_ipv4_ops.route_req = subflow_v4_route_req; + subflow_request_sock_ipv4_ops.send_synack = subflow_v4_send_synack; subflow_specific = ipv4_specific; subflow_specific.conn_request = subflow_v4_conn_request; @@ -1938,8 +2041,23 @@ void __init mptcp_subflow_init(void) tcp_prot_override.release_cb = tcp_release_cb_override; #if IS_ENABLED(CONFIG_MPTCP_IPV6) + /* In struct mptcp_subflow_request_sock, we assume the TCP request sock + * structures for v4 and v6 have the same size. It should not changed in + * the future but better to make sure to be warned if it is no longer + * the case. + */ + BUILD_BUG_ON(sizeof(struct tcp_request_sock) != sizeof(struct tcp6_request_sock)); + + mptcp_subflow_v6_request_sock_ops = tcp6_request_sock_ops; + mptcp_subflow_v6_request_sock_ops.slab_name = "request_sock_subflow_v6"; + mptcp_subflow_v6_request_sock_ops.destructor = subflow_v6_req_destructor; + + if (subflow_ops_init(&mptcp_subflow_v6_request_sock_ops) != 0) + panic("MPTCP: failed to init subflow v6 request sock ops\n"); + subflow_request_sock_ipv6_ops = tcp_request_sock_ipv6_ops; subflow_request_sock_ipv6_ops.route_req = subflow_v6_route_req; + subflow_request_sock_ipv6_ops.send_synack = subflow_v6_send_synack; subflow_v6_specific = ipv6_specific; subflow_v6_specific.conn_request = subflow_v6_conn_request; diff --git a/net/mptcp/token.c b/net/mptcp/token.c index f52ee7b26aed..65430f314a68 100644 --- a/net/mptcp/token.c +++ b/net/mptcp/token.c @@ -287,8 +287,8 @@ EXPORT_SYMBOL_GPL(mptcp_token_get_sock); * This function returns the first mptcp connection structure found inside the * token container starting from the specified position, or NULL. * - * On successful iteration, the iterator is move to the next position and the - * the acquires a reference to the returned socket. + * On successful iteration, the iterator is moved to the next position and + * a reference to the returned socket is acquired. */ struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, long *s_num) diff --git a/net/ncsi/ncsi-cmd.c b/net/ncsi/ncsi-cmd.c index dda8b76b7798..fd2236ee9a79 100644 --- a/net/ncsi/ncsi-cmd.c +++ b/net/ncsi/ncsi-cmd.c @@ -228,7 +228,8 @@ static int ncsi_cmd_handler_oem(struct sk_buff *skb, len += max(payload, padding_bytes); cmd = skb_put_zero(skb, len); - memcpy(&cmd->mfr_id, nca->data, nca->payload); + unsafe_memcpy(&cmd->mfr_id, nca->data, nca->payload, + /* skb allocated with enough to load the payload */); ncsi_cmd_build_header(&cmd->cmd.common, nca); return 0; diff --git a/net/netfilter/Kconfig b/net/netfilter/Kconfig index 4b8d04640ff3..f71b41c7ce2f 100644 --- a/net/netfilter/Kconfig +++ b/net/netfilter/Kconfig @@ -459,6 +459,9 @@ config NF_NAT_REDIRECT config NF_NAT_MASQUERADE bool +config NF_NAT_OVS + bool + config NETFILTER_SYNPROXY tristate @@ -568,12 +571,6 @@ config NFT_TUNNEL This option adds the "tunnel" expression that you can use to set tunneling policies. -config NFT_OBJREF - tristate "Netfilter nf_tables stateful object reference module" - help - This option adds the "objref" expression that allows you to refer to - stateful objects, such as counters and quotas. - config NFT_QUEUE depends on NETFILTER_NETLINK_QUEUE tristate "Netfilter nf_tables queue module" diff --git a/net/netfilter/Makefile b/net/netfilter/Makefile index 0f060d100880..3754eb06fb41 100644 --- a/net/netfilter/Makefile +++ b/net/netfilter/Makefile @@ -59,6 +59,7 @@ obj-$(CONFIG_NF_LOG_SYSLOG) += nf_log_syslog.o obj-$(CONFIG_NF_NAT) += nf_nat.o nf_nat-$(CONFIG_NF_NAT_REDIRECT) += nf_nat_redirect.o nf_nat-$(CONFIG_NF_NAT_MASQUERADE) += nf_nat_masquerade.o +nf_nat-$(CONFIG_NF_NAT_OVS) += nf_nat_ovs.o ifeq ($(CONFIG_NF_NAT),m) nf_nat-$(CONFIG_DEBUG_INFO_BTF_MODULES) += nf_nat_bpf.o @@ -86,7 +87,8 @@ nf_tables-objs := nf_tables_core.o nf_tables_api.o nft_chain_filter.o \ nf_tables_trace.o nft_immediate.o nft_cmp.o nft_range.o \ nft_bitwise.o nft_byteorder.o nft_payload.o nft_lookup.o \ nft_dynset.o nft_meta.o nft_rt.o nft_exthdr.o nft_last.o \ - nft_counter.o nft_chain_route.o nf_tables_offload.o \ + nft_counter.o nft_objref.o nft_inner.o \ + nft_chain_route.o nf_tables_offload.o \ nft_set_hash.o nft_set_bitmap.o nft_set_rbtree.o \ nft_set_pipapo.o @@ -104,7 +106,6 @@ obj-$(CONFIG_NFT_CT) += nft_ct.o obj-$(CONFIG_NFT_FLOW_OFFLOAD) += nft_flow_offload.o obj-$(CONFIG_NFT_LIMIT) += nft_limit.o obj-$(CONFIG_NFT_NAT) += nft_nat.o -obj-$(CONFIG_NFT_OBJREF) += nft_objref.o obj-$(CONFIG_NFT_QUEUE) += nft_queue.o obj-$(CONFIG_NFT_QUOTA) += nft_quota.o obj-$(CONFIG_NFT_REJECT) += nft_reject.o diff --git a/net/netfilter/ipset/ip_set_bitmap_ip.c b/net/netfilter/ipset/ip_set_bitmap_ip.c index a8ce04a4bb72..e4fa00abde6a 100644 --- a/net/netfilter/ipset/ip_set_bitmap_ip.c +++ b/net/netfilter/ipset/ip_set_bitmap_ip.c @@ -308,8 +308,8 @@ bitmap_ip_create(struct net *net, struct ip_set *set, struct nlattr *tb[], return -IPSET_ERR_BITMAP_RANGE; pr_debug("mask_bits %u, netmask %u\n", mask_bits, netmask); - hosts = 2 << (32 - netmask - 1); - elements = 2 << (netmask - mask_bits - 1); + hosts = 2U << (32 - netmask - 1); + elements = 2UL << (netmask - mask_bits - 1); } if (elements > IPSET_BITMAP_MAX_RANGE + 1) return -IPSET_ERR_BITMAP_RANGE_SIZE; diff --git a/net/netfilter/ipset/ip_set_core.c b/net/netfilter/ipset/ip_set_core.c index e7ba5b6dd2b7..46ebee9400da 100644 --- a/net/netfilter/ipset/ip_set_core.c +++ b/net/netfilter/ipset/ip_set_core.c @@ -1698,9 +1698,10 @@ call_ad(struct net *net, struct sock *ctnl, struct sk_buff *skb, ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried); ip_set_unlock(set); retried = true; - } while (ret == -EAGAIN && - set->variant->resize && - (ret = set->variant->resize(set, retried)) == 0); + } while (ret == -ERANGE || + (ret == -EAGAIN && + set->variant->resize && + (ret = set->variant->resize(set, retried)) == 0)); if (!ret || (ret == -IPSET_ERR_EXIST && eexist)) return 0; diff --git a/net/netfilter/ipset/ip_set_hash_gen.h b/net/netfilter/ipset/ip_set_hash_gen.h index 3adc291d9ce1..7c2399541771 100644 --- a/net/netfilter/ipset/ip_set_hash_gen.h +++ b/net/netfilter/ipset/ip_set_hash_gen.h @@ -159,6 +159,17 @@ htable_size(u8 hbits) (SET_WITH_TIMEOUT(set) && \ ip_set_timeout_expired(ext_timeout(d, set))) +#if defined(IP_SET_HASH_WITH_NETMASK) || defined(IP_SET_HASH_WITH_BITMASK) +static const union nf_inet_addr onesmask = { + .all[0] = 0xffffffff, + .all[1] = 0xffffffff, + .all[2] = 0xffffffff, + .all[3] = 0xffffffff +}; + +static const union nf_inet_addr zeromask = {}; +#endif + #endif /* _IP_SET_HASH_GEN_H */ #ifndef MTYPE @@ -283,8 +294,9 @@ struct htype { u32 markmask; /* markmask value for mark mask to store */ #endif u8 bucketsize; /* max elements in an array block */ -#ifdef IP_SET_HASH_WITH_NETMASK +#if defined(IP_SET_HASH_WITH_NETMASK) || defined(IP_SET_HASH_WITH_BITMASK) u8 netmask; /* netmask value for subnets to store */ + union nf_inet_addr bitmask; /* stores bitmask */ #endif struct list_head ad; /* Resize add|del backlist */ struct mtype_elem next; /* temporary storage for uadd */ @@ -459,8 +471,8 @@ mtype_same_set(const struct ip_set *a, const struct ip_set *b) /* Resizing changes htable_bits, so we ignore it */ return x->maxelem == y->maxelem && a->timeout == b->timeout && -#ifdef IP_SET_HASH_WITH_NETMASK - x->netmask == y->netmask && +#if defined(IP_SET_HASH_WITH_NETMASK) || defined(IP_SET_HASH_WITH_BITMASK) + nf_inet_addr_cmp(&x->bitmask, &y->bitmask) && #endif #ifdef IP_SET_HASH_WITH_MARKMASK x->markmask == y->markmask && @@ -916,7 +928,7 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext, #ifdef IP_SET_HASH_WITH_MULTI if (h->bucketsize >= AHASH_MAX_TUNED) goto set_full; - else if (h->bucketsize < multi) + else if (h->bucketsize <= multi) h->bucketsize += AHASH_INIT_SIZE; #endif if (n->size >= AHASH_MAX(h)) { @@ -1264,9 +1276,21 @@ mtype_head(struct ip_set *set, struct sk_buff *skb) htonl(jhash_size(htable_bits))) || nla_put_net32(skb, IPSET_ATTR_MAXELEM, htonl(h->maxelem))) goto nla_put_failure; +#ifdef IP_SET_HASH_WITH_BITMASK + /* if netmask is set to anything other than HOST_MASK we know that the user supplied netmask + * and not bitmask. These two are mutually exclusive. */ + if (h->netmask == HOST_MASK && !nf_inet_addr_cmp(&onesmask, &h->bitmask)) { + if (set->family == NFPROTO_IPV4) { + if (nla_put_ipaddr4(skb, IPSET_ATTR_BITMASK, h->bitmask.ip)) + goto nla_put_failure; + } else if (set->family == NFPROTO_IPV6) { + if (nla_put_ipaddr6(skb, IPSET_ATTR_BITMASK, &h->bitmask.in6)) + goto nla_put_failure; + } + } +#endif #ifdef IP_SET_HASH_WITH_NETMASK - if (h->netmask != HOST_MASK && - nla_put_u8(skb, IPSET_ATTR_NETMASK, h->netmask)) + if (h->netmask != HOST_MASK && nla_put_u8(skb, IPSET_ATTR_NETMASK, h->netmask)) goto nla_put_failure; #endif #ifdef IP_SET_HASH_WITH_MARKMASK @@ -1429,8 +1453,10 @@ IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set, u32 markmask; #endif u8 hbits; -#ifdef IP_SET_HASH_WITH_NETMASK - u8 netmask; +#if defined(IP_SET_HASH_WITH_NETMASK) || defined(IP_SET_HASH_WITH_BITMASK) + int ret __attribute__((unused)) = 0; + u8 netmask = set->family == NFPROTO_IPV4 ? 32 : 128; + union nf_inet_addr bitmask = onesmask; #endif size_t hsize; struct htype *h; @@ -1468,7 +1494,6 @@ IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set, #endif #ifdef IP_SET_HASH_WITH_NETMASK - netmask = set->family == NFPROTO_IPV4 ? 32 : 128; if (tb[IPSET_ATTR_NETMASK]) { netmask = nla_get_u8(tb[IPSET_ATTR_NETMASK]); @@ -1476,6 +1501,33 @@ IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set, (set->family == NFPROTO_IPV6 && netmask > 128) || netmask == 0) return -IPSET_ERR_INVALID_NETMASK; + + /* we convert netmask to bitmask and store it */ + if (set->family == NFPROTO_IPV4) + bitmask.ip = ip_set_netmask(netmask); + else + ip6_netmask(&bitmask, netmask); + } +#endif + +#ifdef IP_SET_HASH_WITH_BITMASK + if (tb[IPSET_ATTR_BITMASK]) { + /* bitmask and netmask do the same thing, allow only one of these options */ + if (tb[IPSET_ATTR_NETMASK]) + return -IPSET_ERR_BITMASK_NETMASK_EXCL; + + if (set->family == NFPROTO_IPV4) { + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_BITMASK], &bitmask.ip); + if (ret || !bitmask.ip) + return -IPSET_ERR_INVALID_NETMASK; + } else if (set->family == NFPROTO_IPV6) { + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_BITMASK], &bitmask); + if (ret || ipv6_addr_any(&bitmask.in6)) + return -IPSET_ERR_INVALID_NETMASK; + } + + if (nf_inet_addr_cmp(&bitmask, &zeromask)) + return -IPSET_ERR_INVALID_NETMASK; } #endif @@ -1518,7 +1570,8 @@ IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set, for (i = 0; i < ahash_numof_locks(hbits); i++) spin_lock_init(&t->hregion[i].lock); h->maxelem = maxelem; -#ifdef IP_SET_HASH_WITH_NETMASK +#if defined(IP_SET_HASH_WITH_NETMASK) || defined(IP_SET_HASH_WITH_BITMASK) + h->bitmask = bitmask; h->netmask = netmask; #endif #ifdef IP_SET_HASH_WITH_MARKMASK diff --git a/net/netfilter/ipset/ip_set_hash_ip.c b/net/netfilter/ipset/ip_set_hash_ip.c index dd30c03d5a23..c9f4e3859663 100644 --- a/net/netfilter/ipset/ip_set_hash_ip.c +++ b/net/netfilter/ipset/ip_set_hash_ip.c @@ -24,7 +24,8 @@ /* 2 Comments support */ /* 3 Forceadd support */ /* 4 skbinfo support */ -#define IPSET_TYPE_REV_MAX 5 /* bucketsize, initval support */ +/* 5 bucketsize, initval support */ +#define IPSET_TYPE_REV_MAX 6 /* bitmask support */ MODULE_LICENSE("GPL"); MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>"); @@ -34,6 +35,7 @@ MODULE_ALIAS("ip_set_hash:ip"); /* Type specific function prefix */ #define HTYPE hash_ip #define IP_SET_HASH_WITH_NETMASK +#define IP_SET_HASH_WITH_BITMASK /* IPv4 variant */ @@ -86,7 +88,7 @@ hash_ip4_kadt(struct ip_set *set, const struct sk_buff *skb, __be32 ip; ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &ip); - ip &= ip_set_netmask(h->netmask); + ip &= h->bitmask.ip; if (ip == 0) return -EINVAL; @@ -98,11 +100,11 @@ static int hash_ip4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_ip4 *h = set->data; + struct hash_ip4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ip4_elem e = { 0 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip = 0, ip_to = 0, hosts; + u32 ip = 0, ip_to = 0, hosts, i = 0; int ret = 0; if (tb[IPSET_ATTR_LINENO]) @@ -119,7 +121,7 @@ hash_ip4_uadt(struct ip_set *set, struct nlattr *tb[], if (ret) return ret; - ip &= ip_set_hostmask(h->netmask); + ip &= ntohl(h->bitmask.ip); e.ip = htonl(ip); if (e.ip == 0) return -IPSET_ERR_HASH_ELEM; @@ -147,22 +149,20 @@ hash_ip4_uadt(struct ip_set *set, struct nlattr *tb[], hosts = h->netmask == 32 ? 1 : 2 << (32 - h->netmask - 1); - /* 64bit division is not allowed on 32bit */ - if (((u64)ip_to - ip + 1) >> (32 - h->netmask) > IPSET_MAX_RANGE) - return -ERANGE; - - if (retried) { + if (retried) ip = ntohl(h->next.ip); + for (; ip <= ip_to; i++) { e.ip = htonl(ip); - } - for (; ip <= ip_to;) { + if (i > IPSET_MAX_RANGE) { + hash_ip4_data_next(&h->next, &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) return ret; ip += hosts; - e.ip = htonl(ip); - if (e.ip == 0) + if (ip == 0) return 0; ret = 0; @@ -187,12 +187,6 @@ hash_ip6_data_equal(const struct hash_ip6_elem *ip1, return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6); } -static void -hash_ip6_netmask(union nf_inet_addr *ip, u8 prefix) -{ - ip6_netmask(ip, prefix); -} - static bool hash_ip6_data_list(struct sk_buff *skb, const struct hash_ip6_elem *e) { @@ -229,7 +223,7 @@ hash_ip6_kadt(struct ip_set *set, const struct sk_buff *skb, struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); - hash_ip6_netmask(&e.ip, h->netmask); + nf_inet_addr_mask_inplace(&e.ip, &h->bitmask); if (ipv6_addr_any(&e.ip.in6)) return -EINVAL; @@ -268,7 +262,7 @@ hash_ip6_uadt(struct ip_set *set, struct nlattr *tb[], if (ret) return ret; - hash_ip6_netmask(&e.ip, h->netmask); + nf_inet_addr_mask_inplace(&e.ip, &h->bitmask); if (ipv6_addr_any(&e.ip.in6)) return -IPSET_ERR_HASH_ELEM; @@ -295,6 +289,7 @@ static struct ip_set_type hash_ip_type __read_mostly = { [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, [IPSET_ATTR_NETMASK] = { .type = NLA_U8 }, + [IPSET_ATTR_BITMASK] = { .type = NLA_NESTED }, [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, }, .adt_policy = { diff --git a/net/netfilter/ipset/ip_set_hash_ipmark.c b/net/netfilter/ipset/ip_set_hash_ipmark.c index 153de3457423..a22ec1a6f6ec 100644 --- a/net/netfilter/ipset/ip_set_hash_ipmark.c +++ b/net/netfilter/ipset/ip_set_hash_ipmark.c @@ -97,11 +97,11 @@ static int hash_ipmark4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_ipmark4 *h = set->data; + struct hash_ipmark4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipmark4_elem e = { }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip, ip_to = 0; + u32 ip, ip_to = 0, i = 0; int ret; if (tb[IPSET_ATTR_LINENO]) @@ -148,13 +148,14 @@ hash_ipmark4_uadt(struct ip_set *set, struct nlattr *tb[], ip_set_mask_from_to(ip, ip_to, cidr); } - if (((u64)ip_to - ip + 1) > IPSET_MAX_RANGE) - return -ERANGE; - if (retried) ip = ntohl(h->next.ip); - for (; ip <= ip_to; ip++) { + for (; ip <= ip_to; ip++, i++) { e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_ipmark4_data_next(&h->next, &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) diff --git a/net/netfilter/ipset/ip_set_hash_ipport.c b/net/netfilter/ipset/ip_set_hash_ipport.c index 7303138e46be..e977b5a9c48d 100644 --- a/net/netfilter/ipset/ip_set_hash_ipport.c +++ b/net/netfilter/ipset/ip_set_hash_ipport.c @@ -26,7 +26,8 @@ /* 3 Comments support added */ /* 4 Forceadd support added */ /* 5 skbinfo support added */ -#define IPSET_TYPE_REV_MAX 6 /* bucketsize, initval support added */ +/* 6 bucketsize, initval support added */ +#define IPSET_TYPE_REV_MAX 7 /* bitmask support added */ MODULE_LICENSE("GPL"); MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>"); @@ -35,6 +36,8 @@ MODULE_ALIAS("ip_set_hash:ip,port"); /* Type specific function prefix */ #define HTYPE hash_ipport +#define IP_SET_HASH_WITH_NETMASK +#define IP_SET_HASH_WITH_BITMASK /* IPv4 variant */ @@ -92,12 +95,16 @@ hash_ipport4_kadt(struct ip_set *set, const struct sk_buff *skb, ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipport4_elem e = { .ip = 0 }; struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + const struct MTYPE *h = set->data; if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, &e.port, &e.proto)) return -EINVAL; ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + e.ip &= h->bitmask.ip; + if (e.ip == 0) + return -EINVAL; return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); } @@ -105,11 +112,11 @@ static int hash_ipport4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_ipport4 *h = set->data; + struct hash_ipport4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipport4_elem e = { .ip = 0 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip, ip_to = 0, p = 0, port, port_to; + u32 ip, ip_to = 0, p = 0, port, port_to, i = 0; bool with_ports = false; int ret; @@ -129,6 +136,10 @@ hash_ipport4_uadt(struct ip_set *set, struct nlattr *tb[], if (ret) return ret; + e.ip &= h->bitmask.ip; + if (e.ip == 0) + return -EINVAL; + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); if (tb[IPSET_ATTR_PROTO]) { @@ -173,17 +184,18 @@ hash_ipport4_uadt(struct ip_set *set, struct nlattr *tb[], swap(port, port_to); } - if (((u64)ip_to - ip + 1)*(port_to - port + 1) > IPSET_MAX_RANGE) - return -ERANGE; - if (retried) ip = ntohl(h->next.ip); for (; ip <= ip_to; ip++) { p = retried && ip == ntohl(h->next.ip) ? ntohs(h->next.port) : port; - for (; p <= port_to; p++) { + for (; p <= port_to; p++, i++) { e.ip = htonl(ip); e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_ipport4_data_next(&h->next, &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) @@ -253,12 +265,17 @@ hash_ipport6_kadt(struct ip_set *set, const struct sk_buff *skb, ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipport6_elem e = { .ip = { .all = { 0 } } }; struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + const struct MTYPE *h = set->data; if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, &e.port, &e.proto)) return -EINVAL; ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + nf_inet_addr_mask_inplace(&e.ip, &h->bitmask); + if (ipv6_addr_any(&e.ip.in6)) + return -EINVAL; + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); } @@ -298,6 +315,10 @@ hash_ipport6_uadt(struct ip_set *set, struct nlattr *tb[], if (ret) return ret; + nf_inet_addr_mask_inplace(&e.ip, &h->bitmask); + if (ipv6_addr_any(&e.ip.in6)) + return -EINVAL; + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); if (tb[IPSET_ATTR_PROTO]) { @@ -356,6 +377,8 @@ static struct ip_set_type hash_ipport_type __read_mostly = { [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_NETMASK] = { .type = NLA_U8 }, + [IPSET_ATTR_BITMASK] = { .type = NLA_NESTED }, }, .adt_policy = { [IPSET_ATTR_IP] = { .type = NLA_NESTED }, diff --git a/net/netfilter/ipset/ip_set_hash_ipportip.c b/net/netfilter/ipset/ip_set_hash_ipportip.c index 334fb1ad0e86..39a01934b153 100644 --- a/net/netfilter/ipset/ip_set_hash_ipportip.c +++ b/net/netfilter/ipset/ip_set_hash_ipportip.c @@ -108,11 +108,11 @@ static int hash_ipportip4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_ipportip4 *h = set->data; + struct hash_ipportip4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipportip4_elem e = { .ip = 0 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip, ip_to = 0, p = 0, port, port_to; + u32 ip, ip_to = 0, p = 0, port, port_to, i = 0; bool with_ports = false; int ret; @@ -180,17 +180,18 @@ hash_ipportip4_uadt(struct ip_set *set, struct nlattr *tb[], swap(port, port_to); } - if (((u64)ip_to - ip + 1)*(port_to - port + 1) > IPSET_MAX_RANGE) - return -ERANGE; - if (retried) ip = ntohl(h->next.ip); for (; ip <= ip_to; ip++) { p = retried && ip == ntohl(h->next.ip) ? ntohs(h->next.port) : port; - for (; p <= port_to; p++) { + for (; p <= port_to; p++, i++) { e.ip = htonl(ip); e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_ipportip4_data_next(&h->next, &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) diff --git a/net/netfilter/ipset/ip_set_hash_ipportnet.c b/net/netfilter/ipset/ip_set_hash_ipportnet.c index 7df94f437f60..5c6de605a9fb 100644 --- a/net/netfilter/ipset/ip_set_hash_ipportnet.c +++ b/net/netfilter/ipset/ip_set_hash_ipportnet.c @@ -160,12 +160,12 @@ static int hash_ipportnet4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_ipportnet4 *h = set->data; + struct hash_ipportnet4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_ipportnet4_elem e = { .cidr = HOST_MASK - 1 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); u32 ip = 0, ip_to = 0, p = 0, port, port_to; - u32 ip2_from = 0, ip2_to = 0, ip2; + u32 ip2_from = 0, ip2_to = 0, ip2, i = 0; bool with_ports = false; u8 cidr; int ret; @@ -253,9 +253,6 @@ hash_ipportnet4_uadt(struct ip_set *set, struct nlattr *tb[], swap(port, port_to); } - if (((u64)ip_to - ip + 1)*(port_to - port + 1) > IPSET_MAX_RANGE) - return -ERANGE; - ip2_to = ip2_from; if (tb[IPSET_ATTR_IP2_TO]) { ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2_TO], &ip2_to); @@ -282,9 +279,15 @@ hash_ipportnet4_uadt(struct ip_set *set, struct nlattr *tb[], for (; p <= port_to; p++) { e.port = htons(p); do { + i++; e.ip2 = htonl(ip2); ip2 = ip_set_range_to_cidr(ip2, ip2_to, &cidr); e.cidr = cidr - 1; + if (i > IPSET_MAX_RANGE) { + hash_ipportnet4_data_next(&h->next, + &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) diff --git a/net/netfilter/ipset/ip_set_hash_net.c b/net/netfilter/ipset/ip_set_hash_net.c index 1422739d9aa2..ce0a9ce5a91f 100644 --- a/net/netfilter/ipset/ip_set_hash_net.c +++ b/net/netfilter/ipset/ip_set_hash_net.c @@ -136,11 +136,11 @@ static int hash_net4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_net4 *h = set->data; + struct hash_net4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_net4_elem e = { .cidr = HOST_MASK }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip = 0, ip_to = 0, ipn, n = 0; + u32 ip = 0, ip_to = 0, i = 0; int ret; if (tb[IPSET_ATTR_LINENO]) @@ -188,19 +188,16 @@ hash_net4_uadt(struct ip_set *set, struct nlattr *tb[], if (ip + UINT_MAX == ip_to) return -IPSET_ERR_HASH_RANGE; } - ipn = ip; - do { - ipn = ip_set_range_to_cidr(ipn, ip_to, &e.cidr); - n++; - } while (ipn++ < ip_to); - - if (n > IPSET_MAX_RANGE) - return -ERANGE; if (retried) ip = ntohl(h->next.ip); do { + i++; e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_net4_data_next(&h->next, &e); + return -ERANGE; + } ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr); ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) diff --git a/net/netfilter/ipset/ip_set_hash_netiface.c b/net/netfilter/ipset/ip_set_hash_netiface.c index 9810f5bf63f5..031073286236 100644 --- a/net/netfilter/ipset/ip_set_hash_netiface.c +++ b/net/netfilter/ipset/ip_set_hash_netiface.c @@ -202,7 +202,7 @@ hash_netiface4_uadt(struct ip_set *set, struct nlattr *tb[], ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_netiface4_elem e = { .cidr = HOST_MASK, .elem = 1 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 ip = 0, ip_to = 0, ipn, n = 0; + u32 ip = 0, ip_to = 0, i = 0; int ret; if (tb[IPSET_ATTR_LINENO]) @@ -256,19 +256,16 @@ hash_netiface4_uadt(struct ip_set *set, struct nlattr *tb[], } else { ip_set_mask_from_to(ip, ip_to, e.cidr); } - ipn = ip; - do { - ipn = ip_set_range_to_cidr(ipn, ip_to, &e.cidr); - n++; - } while (ipn++ < ip_to); - - if (n > IPSET_MAX_RANGE) - return -ERANGE; if (retried) ip = ntohl(h->next.ip); do { + i++; e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_netiface4_data_next(&h->next, &e); + return -ERANGE; + } ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr); ret = adtfn(set, &e, &ext, &ext, flags); diff --git a/net/netfilter/ipset/ip_set_hash_netnet.c b/net/netfilter/ipset/ip_set_hash_netnet.c index 3d09eefe998a..8fbe649c9dd3 100644 --- a/net/netfilter/ipset/ip_set_hash_netnet.c +++ b/net/netfilter/ipset/ip_set_hash_netnet.c @@ -23,7 +23,8 @@ #define IPSET_TYPE_REV_MIN 0 /* 1 Forceadd support added */ /* 2 skbinfo support added */ -#define IPSET_TYPE_REV_MAX 3 /* bucketsize, initval support added */ +/* 3 bucketsize, initval support added */ +#define IPSET_TYPE_REV_MAX 4 /* bitmask support added */ MODULE_LICENSE("GPL"); MODULE_AUTHOR("Oliver Smith <oliver@8.c.9.b.0.7.4.0.1.0.0.2.ip6.arpa>"); @@ -33,6 +34,8 @@ MODULE_ALIAS("ip_set_hash:net,net"); /* Type specific function prefix */ #define HTYPE hash_netnet #define IP_SET_HASH_WITH_NETS +#define IP_SET_HASH_WITH_NETMASK +#define IP_SET_HASH_WITH_BITMASK #define IPSET_NET_COUNT 2 /* IPv4 variants */ @@ -153,8 +156,8 @@ hash_netnet4_kadt(struct ip_set *set, const struct sk_buff *skb, ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip[0]); ip4addrptr(skb, opt->flags & IPSET_DIM_TWO_SRC, &e.ip[1]); - e.ip[0] &= ip_set_netmask(e.cidr[0]); - e.ip[1] &= ip_set_netmask(e.cidr[1]); + e.ip[0] &= (ip_set_netmask(e.cidr[0]) & h->bitmask.ip); + e.ip[1] &= (ip_set_netmask(e.cidr[1]) & h->bitmask.ip); return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); } @@ -163,13 +166,12 @@ static int hash_netnet4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_netnet4 *h = set->data; + struct hash_netnet4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_netnet4_elem e = { }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); u32 ip = 0, ip_to = 0; - u32 ip2 = 0, ip2_from = 0, ip2_to = 0, ipn; - u64 n = 0, m = 0; + u32 ip2 = 0, ip2_from = 0, ip2_to = 0, i = 0; int ret; if (tb[IPSET_ATTR_LINENO]) @@ -213,8 +215,8 @@ hash_netnet4_uadt(struct ip_set *set, struct nlattr *tb[], if (adt == IPSET_TEST || !(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_IP2_TO])) { - e.ip[0] = htonl(ip & ip_set_hostmask(e.cidr[0])); - e.ip[1] = htonl(ip2_from & ip_set_hostmask(e.cidr[1])); + e.ip[0] = htonl(ip & ntohl(h->bitmask.ip) & ip_set_hostmask(e.cidr[0])); + e.ip[1] = htonl(ip2_from & ntohl(h->bitmask.ip) & ip_set_hostmask(e.cidr[1])); ret = adtfn(set, &e, &ext, &ext, flags); return ip_set_enomatch(ret, flags, adt, set) ? -ret : ip_set_eexist(ret, flags) ? 0 : ret; @@ -245,19 +247,6 @@ hash_netnet4_uadt(struct ip_set *set, struct nlattr *tb[], } else { ip_set_mask_from_to(ip2_from, ip2_to, e.cidr[1]); } - ipn = ip; - do { - ipn = ip_set_range_to_cidr(ipn, ip_to, &e.cidr[0]); - n++; - } while (ipn++ < ip_to); - ipn = ip2_from; - do { - ipn = ip_set_range_to_cidr(ipn, ip2_to, &e.cidr[1]); - m++; - } while (ipn++ < ip2_to); - - if (n*m > IPSET_MAX_RANGE) - return -ERANGE; if (retried) { ip = ntohl(h->next.ip[0]); @@ -270,7 +259,12 @@ hash_netnet4_uadt(struct ip_set *set, struct nlattr *tb[], e.ip[0] = htonl(ip); ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr[0]); do { + i++; e.ip[1] = htonl(ip2); + if (i > IPSET_MAX_RANGE) { + hash_netnet4_data_next(&h->next, &e); + return -ERANGE; + } ip2 = ip_set_range_to_cidr(ip2, ip2_to, &e.cidr[1]); ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) @@ -404,6 +398,11 @@ hash_netnet6_kadt(struct ip_set *set, const struct sk_buff *skb, ip6_netmask(&e.ip[0], e.cidr[0]); ip6_netmask(&e.ip[1], e.cidr[1]); + nf_inet_addr_mask_inplace(&e.ip[0], &h->bitmask); + nf_inet_addr_mask_inplace(&e.ip[1], &h->bitmask); + if (e.cidr[0] == HOST_MASK && ipv6_addr_any(&e.ip[0].in6)) + return -EINVAL; + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); } @@ -414,6 +413,7 @@ hash_netnet6_uadt(struct ip_set *set, struct nlattr *tb[], ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_netnet6_elem e = { }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + const struct hash_netnet6 *h = set->data; int ret; if (tb[IPSET_ATTR_LINENO]) @@ -453,6 +453,11 @@ hash_netnet6_uadt(struct ip_set *set, struct nlattr *tb[], ip6_netmask(&e.ip[0], e.cidr[0]); ip6_netmask(&e.ip[1], e.cidr[1]); + nf_inet_addr_mask_inplace(&e.ip[0], &h->bitmask); + nf_inet_addr_mask_inplace(&e.ip[1], &h->bitmask); + if (e.cidr[0] == HOST_MASK && ipv6_addr_any(&e.ip[0].in6)) + return -IPSET_ERR_HASH_ELEM; + if (tb[IPSET_ATTR_CADT_FLAGS]) { u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); @@ -484,6 +489,8 @@ static struct ip_set_type hash_netnet_type __read_mostly = { [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_NETMASK] = { .type = NLA_U8 }, + [IPSET_ATTR_BITMASK] = { .type = NLA_NESTED }, }, .adt_policy = { [IPSET_ATTR_IP] = { .type = NLA_NESTED }, diff --git a/net/netfilter/ipset/ip_set_hash_netport.c b/net/netfilter/ipset/ip_set_hash_netport.c index 09cf72eb37f8..d1a0628df4ef 100644 --- a/net/netfilter/ipset/ip_set_hash_netport.c +++ b/net/netfilter/ipset/ip_set_hash_netport.c @@ -154,12 +154,11 @@ static int hash_netport4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_netport4 *h = set->data; + struct hash_netport4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_netport4_elem e = { .cidr = HOST_MASK - 1 }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); - u32 port, port_to, p = 0, ip = 0, ip_to = 0, ipn; - u64 n = 0; + u32 port, port_to, p = 0, ip = 0, ip_to = 0, i = 0; bool with_ports = false; u8 cidr; int ret; @@ -236,14 +235,6 @@ hash_netport4_uadt(struct ip_set *set, struct nlattr *tb[], } else { ip_set_mask_from_to(ip, ip_to, e.cidr + 1); } - ipn = ip; - do { - ipn = ip_set_range_to_cidr(ipn, ip_to, &cidr); - n++; - } while (ipn++ < ip_to); - - if (n*(port_to - port + 1) > IPSET_MAX_RANGE) - return -ERANGE; if (retried) { ip = ntohl(h->next.ip); @@ -255,8 +246,12 @@ hash_netport4_uadt(struct ip_set *set, struct nlattr *tb[], e.ip = htonl(ip); ip = ip_set_range_to_cidr(ip, ip_to, &cidr); e.cidr = cidr - 1; - for (; p <= port_to; p++) { + for (; p <= port_to; p++, i++) { e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_netport4_data_next(&h->next, &e); + return -ERANGE; + } ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) return ret; diff --git a/net/netfilter/ipset/ip_set_hash_netportnet.c b/net/netfilter/ipset/ip_set_hash_netportnet.c index 19bcdb3141f6..005a7ce87217 100644 --- a/net/netfilter/ipset/ip_set_hash_netportnet.c +++ b/net/netfilter/ipset/ip_set_hash_netportnet.c @@ -173,17 +173,26 @@ hash_netportnet4_kadt(struct ip_set *set, const struct sk_buff *skb, return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); } +static u32 +hash_netportnet4_range_to_cidr(u32 from, u32 to, u8 *cidr) +{ + if (from == 0 && to == UINT_MAX) { + *cidr = 0; + return to; + } + return ip_set_range_to_cidr(from, to, cidr); +} + static int hash_netportnet4_uadt(struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) { - const struct hash_netportnet4 *h = set->data; + struct hash_netportnet4 *h = set->data; ipset_adtfn adtfn = set->variant->adt[adt]; struct hash_netportnet4_elem e = { }; struct ip_set_ext ext = IP_SET_INIT_UEXT(set); u32 ip = 0, ip_to = 0, p = 0, port, port_to; - u32 ip2_from = 0, ip2_to = 0, ip2, ipn; - u64 n = 0, m = 0; + u32 ip2_from = 0, ip2_to = 0, ip2, i = 0; bool with_ports = false; int ret; @@ -285,19 +294,6 @@ hash_netportnet4_uadt(struct ip_set *set, struct nlattr *tb[], } else { ip_set_mask_from_to(ip2_from, ip2_to, e.cidr[1]); } - ipn = ip; - do { - ipn = ip_set_range_to_cidr(ipn, ip_to, &e.cidr[0]); - n++; - } while (ipn++ < ip_to); - ipn = ip2_from; - do { - ipn = ip_set_range_to_cidr(ipn, ip2_to, &e.cidr[1]); - m++; - } while (ipn++ < ip2_to); - - if (n*m*(port_to - port + 1) > IPSET_MAX_RANGE) - return -ERANGE; if (retried) { ip = ntohl(h->next.ip[0]); @@ -310,13 +306,19 @@ hash_netportnet4_uadt(struct ip_set *set, struct nlattr *tb[], do { e.ip[0] = htonl(ip); - ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr[0]); + ip = hash_netportnet4_range_to_cidr(ip, ip_to, &e.cidr[0]); for (; p <= port_to; p++) { e.port = htons(p); do { + i++; e.ip[1] = htonl(ip2); - ip2 = ip_set_range_to_cidr(ip2, ip2_to, - &e.cidr[1]); + if (i > IPSET_MAX_RANGE) { + hash_netportnet4_data_next(&h->next, + &e); + return -ERANGE; + } + ip2 = hash_netportnet4_range_to_cidr(ip2, + ip2_to, &e.cidr[1]); ret = adtfn(set, &e, &ext, &ext, flags); if (ret && !ip_set_eexist(ret, flags)) return ret; diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c index 5a67f7966574..e162636525cf 100644 --- a/net/netfilter/ipset/ip_set_list_set.c +++ b/net/netfilter/ipset/ip_set_list_set.c @@ -427,7 +427,7 @@ list_set_destroy(struct ip_set *set) struct set_elem *e, *n; if (SET_WITH_TIMEOUT(set)) - del_timer_sync(&map->gc); + timer_shutdown_sync(&map->gc); list_for_each_entry_safe(e, n, &map->members, list) { list_del(&e->list); diff --git a/net/netfilter/ipvs/ip_vs_core.c b/net/netfilter/ipvs/ip_vs_core.c index 51ad557a525b..2fcc26507d69 100644 --- a/net/netfilter/ipvs/ip_vs_core.c +++ b/net/netfilter/ipvs/ip_vs_core.c @@ -132,21 +132,21 @@ ip_vs_in_stats(struct ip_vs_conn *cp, struct sk_buff *skb) s = this_cpu_ptr(dest->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.inpkts++; - s->cnt.inbytes += skb->len; + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); u64_stats_update_end(&s->syncp); svc = rcu_dereference(dest->svc); s = this_cpu_ptr(svc->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.inpkts++; - s->cnt.inbytes += skb->len; + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); u64_stats_update_end(&s->syncp); - s = this_cpu_ptr(ipvs->tot_stats.cpustats); + s = this_cpu_ptr(ipvs->tot_stats->s.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.inpkts++; - s->cnt.inbytes += skb->len; + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); u64_stats_update_end(&s->syncp); local_bh_enable(); @@ -168,21 +168,21 @@ ip_vs_out_stats(struct ip_vs_conn *cp, struct sk_buff *skb) s = this_cpu_ptr(dest->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.outpkts++; - s->cnt.outbytes += skb->len; + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); u64_stats_update_end(&s->syncp); svc = rcu_dereference(dest->svc); s = this_cpu_ptr(svc->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.outpkts++; - s->cnt.outbytes += skb->len; + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); u64_stats_update_end(&s->syncp); - s = this_cpu_ptr(ipvs->tot_stats.cpustats); + s = this_cpu_ptr(ipvs->tot_stats->s.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.outpkts++; - s->cnt.outbytes += skb->len; + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); u64_stats_update_end(&s->syncp); local_bh_enable(); @@ -200,17 +200,17 @@ ip_vs_conn_stats(struct ip_vs_conn *cp, struct ip_vs_service *svc) s = this_cpu_ptr(cp->dest->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.conns++; + u64_stats_inc(&s->cnt.conns); u64_stats_update_end(&s->syncp); s = this_cpu_ptr(svc->stats.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.conns++; + u64_stats_inc(&s->cnt.conns); u64_stats_update_end(&s->syncp); - s = this_cpu_ptr(ipvs->tot_stats.cpustats); + s = this_cpu_ptr(ipvs->tot_stats->s.cpustats); u64_stats_update_begin(&s->syncp); - s->cnt.conns++; + u64_stats_inc(&s->cnt.conns); u64_stats_update_end(&s->syncp); local_bh_enable(); @@ -2448,6 +2448,10 @@ static void __exit ip_vs_cleanup(void) ip_vs_conn_cleanup(); ip_vs_protocol_cleanup(); ip_vs_control_cleanup(); + /* common rcu_barrier() used by: + * - ip_vs_control_cleanup() + */ + rcu_barrier(); pr_info("ipvs unloaded.\n"); } diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c index 988222fff9f0..2a5ed71c82c3 100644 --- a/net/netfilter/ipvs/ip_vs_ctl.c +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -49,8 +49,7 @@ MODULE_ALIAS_GENL_FAMILY(IPVS_GENL_NAME); -/* semaphore for IPVS sockopts. And, [gs]etsockopt may sleep. */ -static DEFINE_MUTEX(__ip_vs_mutex); +DEFINE_MUTEX(__ip_vs_mutex); /* Serialize configuration with sockopt/netlink */ /* sysctl variables */ @@ -241,6 +240,47 @@ static void defense_work_handler(struct work_struct *work) } #endif +static void est_reload_work_handler(struct work_struct *work) +{ + struct netns_ipvs *ipvs = + container_of(work, struct netns_ipvs, est_reload_work.work); + int genid_done = atomic_read(&ipvs->est_genid_done); + unsigned long delay = HZ / 10; /* repeat startups after failure */ + bool repeat = false; + int genid; + int id; + + mutex_lock(&ipvs->est_mutex); + genid = atomic_read(&ipvs->est_genid); + for (id = 0; id < ipvs->est_kt_count; id++) { + struct ip_vs_est_kt_data *kd = ipvs->est_kt_arr[id]; + + /* netns clean up started, abort delayed work */ + if (!ipvs->enable) + goto unlock; + if (!kd) + continue; + /* New config ? Stop kthread tasks */ + if (genid != genid_done) + ip_vs_est_kthread_stop(kd); + if (!kd->task && !ip_vs_est_stopped(ipvs)) { + /* Do not start kthreads above 0 in calc phase */ + if ((!id || !ipvs->est_calc_phase) && + ip_vs_est_kthread_start(ipvs, kd) < 0) + repeat = true; + } + } + + atomic_set(&ipvs->est_genid_done, genid); + + if (repeat) + queue_delayed_work(system_long_wq, &ipvs->est_reload_work, + delay); + +unlock: + mutex_unlock(&ipvs->est_mutex); +} + int ip_vs_use_count_inc(void) { @@ -471,7 +511,7 @@ __ip_vs_bind_svc(struct ip_vs_dest *dest, struct ip_vs_service *svc) static void ip_vs_service_free(struct ip_vs_service *svc) { - free_percpu(svc->stats.cpustats); + ip_vs_stats_release(&svc->stats); kfree(svc); } @@ -483,17 +523,14 @@ static void ip_vs_service_rcu_free(struct rcu_head *head) ip_vs_service_free(svc); } -static void __ip_vs_svc_put(struct ip_vs_service *svc, bool do_delay) +static void __ip_vs_svc_put(struct ip_vs_service *svc) { if (atomic_dec_and_test(&svc->refcnt)) { IP_VS_DBG_BUF(3, "Removing service %u/%s:%u\n", svc->fwmark, IP_VS_DBG_ADDR(svc->af, &svc->addr), ntohs(svc->port)); - if (do_delay) - call_rcu(&svc->rcu_head, ip_vs_service_rcu_free); - else - ip_vs_service_free(svc); + call_rcu(&svc->rcu_head, ip_vs_service_rcu_free); } } @@ -780,14 +817,22 @@ out: return dest; } +static void ip_vs_dest_rcu_free(struct rcu_head *head) +{ + struct ip_vs_dest *dest; + + dest = container_of(head, struct ip_vs_dest, rcu_head); + ip_vs_stats_release(&dest->stats); + ip_vs_dest_put_and_free(dest); +} + static void ip_vs_dest_free(struct ip_vs_dest *dest) { struct ip_vs_service *svc = rcu_dereference_protected(dest->svc, 1); __ip_vs_dst_cache_reset(dest); - __ip_vs_svc_put(svc, false); - free_percpu(dest->stats.cpustats); - ip_vs_dest_put_and_free(dest); + __ip_vs_svc_put(svc); + call_rcu(&dest->rcu_head, ip_vs_dest_rcu_free); } /* @@ -811,12 +856,22 @@ static void ip_vs_trash_cleanup(struct netns_ipvs *ipvs) } } +static void ip_vs_stats_rcu_free(struct rcu_head *head) +{ + struct ip_vs_stats_rcu *rs = container_of(head, + struct ip_vs_stats_rcu, + rcu_head); + + ip_vs_stats_release(&rs->s); + kfree(rs); +} + static void ip_vs_copy_stats(struct ip_vs_kstats *dst, struct ip_vs_stats *src) { #define IP_VS_SHOW_STATS_COUNTER(c) dst->c = src->kstats.c - src->kstats0.c - spin_lock_bh(&src->lock); + spin_lock(&src->lock); IP_VS_SHOW_STATS_COUNTER(conns); IP_VS_SHOW_STATS_COUNTER(inpkts); @@ -826,7 +881,7 @@ ip_vs_copy_stats(struct ip_vs_kstats *dst, struct ip_vs_stats *src) ip_vs_read_estimator(dst, src); - spin_unlock_bh(&src->lock); + spin_unlock(&src->lock); } static void @@ -847,7 +902,7 @@ ip_vs_export_stats_user(struct ip_vs_stats_user *dst, struct ip_vs_kstats *src) static void ip_vs_zero_stats(struct ip_vs_stats *stats) { - spin_lock_bh(&stats->lock); + spin_lock(&stats->lock); /* get current counters as zero point, rates are zeroed */ @@ -861,7 +916,48 @@ ip_vs_zero_stats(struct ip_vs_stats *stats) ip_vs_zero_estimator(stats); - spin_unlock_bh(&stats->lock); + spin_unlock(&stats->lock); +} + +/* Allocate fields after kzalloc */ +int ip_vs_stats_init_alloc(struct ip_vs_stats *s) +{ + int i; + + spin_lock_init(&s->lock); + s->cpustats = alloc_percpu(struct ip_vs_cpu_stats); + if (!s->cpustats) + return -ENOMEM; + + for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *cs = per_cpu_ptr(s->cpustats, i); + + u64_stats_init(&cs->syncp); + } + return 0; +} + +struct ip_vs_stats *ip_vs_stats_alloc(void) +{ + struct ip_vs_stats *s = kzalloc(sizeof(*s), GFP_KERNEL); + + if (s && ip_vs_stats_init_alloc(s) >= 0) + return s; + kfree(s); + return NULL; +} + +void ip_vs_stats_release(struct ip_vs_stats *stats) +{ + free_percpu(stats->cpustats); +} + +void ip_vs_stats_free(struct ip_vs_stats *stats) +{ + if (stats) { + ip_vs_stats_release(stats); + kfree(stats); + } } /* @@ -923,7 +1019,7 @@ __ip_vs_update_dest(struct ip_vs_service *svc, struct ip_vs_dest *dest, if (old_svc != svc) { ip_vs_zero_stats(&dest->stats); __ip_vs_bind_svc(dest, svc); - __ip_vs_svc_put(old_svc, true); + __ip_vs_svc_put(old_svc); } } @@ -942,7 +1038,6 @@ __ip_vs_update_dest(struct ip_vs_service *svc, struct ip_vs_dest *dest, spin_unlock_bh(&dest->dst_lock); if (add) { - ip_vs_start_estimator(svc->ipvs, &dest->stats); list_add_rcu(&dest->n_list, &svc->destinations); svc->num_dests++; sched = rcu_dereference_protected(svc->scheduler, 1); @@ -963,14 +1058,13 @@ static int ip_vs_new_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) { struct ip_vs_dest *dest; - unsigned int atype, i; + unsigned int atype; + int ret; EnterFunction(2); #ifdef CONFIG_IP_VS_IPV6 if (udest->af == AF_INET6) { - int ret; - atype = ipv6_addr_type(&udest->addr.in6); if ((!(atype & IPV6_ADDR_UNICAST) || atype & IPV6_ADDR_LINKLOCAL) && @@ -992,15 +1086,13 @@ ip_vs_new_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) if (dest == NULL) return -ENOMEM; - dest->stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); - if (!dest->stats.cpustats) + ret = ip_vs_stats_init_alloc(&dest->stats); + if (ret < 0) goto err_alloc; - for_each_possible_cpu(i) { - struct ip_vs_cpu_stats *ip_vs_dest_stats; - ip_vs_dest_stats = per_cpu_ptr(dest->stats.cpustats, i); - u64_stats_init(&ip_vs_dest_stats->syncp); - } + ret = ip_vs_start_estimator(svc->ipvs, &dest->stats); + if (ret < 0) + goto err_stats; dest->af = udest->af; dest->protocol = svc->protocol; @@ -1017,15 +1109,17 @@ ip_vs_new_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) INIT_HLIST_NODE(&dest->d_list); spin_lock_init(&dest->dst_lock); - spin_lock_init(&dest->stats.lock); __ip_vs_update_dest(svc, dest, udest, 1); LeaveFunction(2); return 0; +err_stats: + ip_vs_stats_release(&dest->stats); + err_alloc: kfree(dest); - return -ENOMEM; + return ret; } @@ -1087,14 +1181,18 @@ ip_vs_add_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) IP_VS_DBG_ADDR(svc->af, &dest->vaddr), ntohs(dest->vport)); + ret = ip_vs_start_estimator(svc->ipvs, &dest->stats); + if (ret < 0) + goto err; __ip_vs_update_dest(svc, dest, udest, 1); - ret = 0; } else { /* * Allocate and initialize the dest structure */ ret = ip_vs_new_dest(svc, udest); } + +err: LeaveFunction(2); return ret; @@ -1284,7 +1382,7 @@ static int ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, struct ip_vs_service **svc_p) { - int ret = 0, i; + int ret = 0; struct ip_vs_scheduler *sched = NULL; struct ip_vs_pe *pe = NULL; struct ip_vs_service *svc = NULL; @@ -1344,18 +1442,9 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, ret = -ENOMEM; goto out_err; } - svc->stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); - if (!svc->stats.cpustats) { - ret = -ENOMEM; + ret = ip_vs_stats_init_alloc(&svc->stats); + if (ret < 0) goto out_err; - } - - for_each_possible_cpu(i) { - struct ip_vs_cpu_stats *ip_vs_stats; - ip_vs_stats = per_cpu_ptr(svc->stats.cpustats, i); - u64_stats_init(&ip_vs_stats->syncp); - } - /* I'm the first user of the service */ atomic_set(&svc->refcnt, 0); @@ -1372,7 +1461,6 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, INIT_LIST_HEAD(&svc->destinations); spin_lock_init(&svc->sched_lock); - spin_lock_init(&svc->stats.lock); /* Bind the scheduler */ if (sched) { @@ -1382,6 +1470,10 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, sched = NULL; } + ret = ip_vs_start_estimator(ipvs, &svc->stats); + if (ret < 0) + goto out_err; + /* Bind the ct retriever */ RCU_INIT_POINTER(svc->pe, pe); pe = NULL; @@ -1394,8 +1486,6 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, if (svc->pe && svc->pe->conn_out) atomic_inc(&ipvs->conn_out_counter); - ip_vs_start_estimator(ipvs, &svc->stats); - /* Count only IPv4 services for old get/setsockopt interface */ if (svc->af == AF_INET) ipvs->num_services++; @@ -1406,8 +1496,15 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, ip_vs_svc_hash(svc); *svc_p = svc; - /* Now there is a service - full throttle */ - ipvs->enable = 1; + + if (!ipvs->enable) { + /* Now there is a service - full throttle */ + ipvs->enable = 1; + + /* Start estimation for first time */ + ip_vs_est_reload_start(ipvs); + } + return 0; @@ -1571,7 +1668,7 @@ static void __ip_vs_del_service(struct ip_vs_service *svc, bool cleanup) /* * Free the service if nobody refers to it */ - __ip_vs_svc_put(svc, true); + __ip_vs_svc_put(svc); /* decrease the module use count */ ip_vs_use_count_dec(); @@ -1761,7 +1858,7 @@ static int ip_vs_zero_all(struct netns_ipvs *ipvs) } } - ip_vs_zero_stats(&ipvs->tot_stats); + ip_vs_zero_stats(&ipvs->tot_stats->s); return 0; } @@ -1843,6 +1940,148 @@ proc_do_sync_ports(struct ctl_table *table, int write, return rc; } +static int ipvs_proc_est_cpumask_set(struct ctl_table *table, void *buffer) +{ + struct netns_ipvs *ipvs = table->extra2; + cpumask_var_t *valp = table->data; + cpumask_var_t newmask; + int ret; + + if (!zalloc_cpumask_var(&newmask, GFP_KERNEL)) + return -ENOMEM; + + ret = cpulist_parse(buffer, newmask); + if (ret) + goto out; + + mutex_lock(&ipvs->est_mutex); + + if (!ipvs->est_cpulist_valid) { + if (!zalloc_cpumask_var(valp, GFP_KERNEL)) { + ret = -ENOMEM; + goto unlock; + } + ipvs->est_cpulist_valid = 1; + } + cpumask_and(newmask, newmask, ¤t->cpus_mask); + cpumask_copy(*valp, newmask); + /* est_max_threads may depend on cpulist size */ + ipvs->est_max_threads = ip_vs_est_max_threads(ipvs); + ipvs->est_calc_phase = 1; + ip_vs_est_reload_start(ipvs); + +unlock: + mutex_unlock(&ipvs->est_mutex); + +out: + free_cpumask_var(newmask); + return ret; +} + +static int ipvs_proc_est_cpumask_get(struct ctl_table *table, void *buffer, + size_t size) +{ + struct netns_ipvs *ipvs = table->extra2; + cpumask_var_t *valp = table->data; + struct cpumask *mask; + int ret; + + mutex_lock(&ipvs->est_mutex); + + if (ipvs->est_cpulist_valid) + mask = *valp; + else + mask = (struct cpumask *)housekeeping_cpumask(HK_TYPE_KTHREAD); + ret = scnprintf(buffer, size, "%*pbl\n", cpumask_pr_args(mask)); + + mutex_unlock(&ipvs->est_mutex); + + return ret; +} + +static int ipvs_proc_est_cpulist(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int ret; + + /* Ignore both read and write(append) if *ppos not 0 */ + if (*ppos || !*lenp) { + *lenp = 0; + return 0; + } + if (write) { + /* proc_sys_call_handler() appends terminator */ + ret = ipvs_proc_est_cpumask_set(table, buffer); + if (ret >= 0) + *ppos += *lenp; + } else { + /* proc_sys_call_handler() allocates 1 byte for terminator */ + ret = ipvs_proc_est_cpumask_get(table, buffer, *lenp + 1); + if (ret >= 0) { + *lenp = ret; + *ppos += *lenp; + ret = 0; + } + } + return ret; +} + +static int ipvs_proc_est_nice(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct netns_ipvs *ipvs = table->extra2; + int *valp = table->data; + int val = *valp; + int ret; + + struct ctl_table tmp_table = { + .data = &val, + .maxlen = sizeof(int), + .mode = table->mode, + }; + + ret = proc_dointvec(&tmp_table, write, buffer, lenp, ppos); + if (write && ret >= 0) { + if (val < MIN_NICE || val > MAX_NICE) { + ret = -EINVAL; + } else { + mutex_lock(&ipvs->est_mutex); + if (*valp != val) { + *valp = val; + ip_vs_est_reload_start(ipvs); + } + mutex_unlock(&ipvs->est_mutex); + } + } + return ret; +} + +static int ipvs_proc_run_estimation(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct netns_ipvs *ipvs = table->extra2; + int *valp = table->data; + int val = *valp; + int ret; + + struct ctl_table tmp_table = { + .data = &val, + .maxlen = sizeof(int), + .mode = table->mode, + }; + + ret = proc_dointvec(&tmp_table, write, buffer, lenp, ppos); + if (write && ret >= 0) { + mutex_lock(&ipvs->est_mutex); + if (*valp != val) { + *valp = val; + ip_vs_est_reload_start(ipvs); + } + mutex_unlock(&ipvs->est_mutex); + } + return ret; +} + /* * IPVS sysctl table (under the /proc/sys/net/ipv4/vs/) * Do not change order or insert new entries without @@ -2017,7 +2256,19 @@ static struct ctl_table vs_vars[] = { .procname = "run_estimation", .maxlen = sizeof(int), .mode = 0644, - .proc_handler = proc_dointvec, + .proc_handler = ipvs_proc_run_estimation, + }, + { + .procname = "est_cpulist", + .maxlen = NR_CPUS, /* unused */ + .mode = 0644, + .proc_handler = ipvs_proc_est_cpulist, + }, + { + .procname = "est_nice", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = ipvs_proc_est_nice, }, #ifdef CONFIG_IP_VS_DEBUG { @@ -2255,7 +2506,7 @@ static int ip_vs_stats_show(struct seq_file *seq, void *v) seq_puts(seq, " Conns Packets Packets Bytes Bytes\n"); - ip_vs_copy_stats(&show, &net_ipvs(net)->tot_stats); + ip_vs_copy_stats(&show, &net_ipvs(net)->tot_stats->s); seq_printf(seq, "%8LX %8LX %8LX %16LX %16LX\n\n", (unsigned long long)show.conns, (unsigned long long)show.inpkts, @@ -2279,7 +2530,7 @@ static int ip_vs_stats_show(struct seq_file *seq, void *v) static int ip_vs_stats_percpu_show(struct seq_file *seq, void *v) { struct net *net = seq_file_single_net(seq); - struct ip_vs_stats *tot_stats = &net_ipvs(net)->tot_stats; + struct ip_vs_stats *tot_stats = &net_ipvs(net)->tot_stats->s; struct ip_vs_cpu_stats __percpu *cpustats = tot_stats->cpustats; struct ip_vs_kstats kstats; int i; @@ -2296,13 +2547,13 @@ static int ip_vs_stats_percpu_show(struct seq_file *seq, void *v) u64 conns, inpkts, outpkts, inbytes, outbytes; do { - start = u64_stats_fetch_begin_irq(&u->syncp); - conns = u->cnt.conns; - inpkts = u->cnt.inpkts; - outpkts = u->cnt.outpkts; - inbytes = u->cnt.inbytes; - outbytes = u->cnt.outbytes; - } while (u64_stats_fetch_retry_irq(&u->syncp, start)); + start = u64_stats_fetch_begin(&u->syncp); + conns = u64_stats_read(&u->cnt.conns); + inpkts = u64_stats_read(&u->cnt.inpkts); + outpkts = u64_stats_read(&u->cnt.outpkts); + inbytes = u64_stats_read(&u->cnt.inbytes); + outbytes = u64_stats_read(&u->cnt.outbytes); + } while (u64_stats_fetch_retry(&u->syncp, start)); seq_printf(seq, "%3X %8LX %8LX %8LX %16LX %16LX\n", i, (u64)conns, (u64)inpkts, @@ -2590,6 +2841,11 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, sockptr_t ptr, unsigned int len) break; case IP_VS_SO_SET_DELDEST: ret = ip_vs_del_dest(svc, &udest); + break; + default: + WARN_ON_ONCE(1); + ret = -EINVAL; + break; } out_unlock: @@ -4027,13 +4283,17 @@ static void ip_vs_genl_unregister(void) static int __net_init ip_vs_control_net_init_sysctl(struct netns_ipvs *ipvs) { struct net *net = ipvs->net; - int idx; struct ctl_table *tbl; + int idx, ret; atomic_set(&ipvs->dropentry, 0); spin_lock_init(&ipvs->dropentry_lock); spin_lock_init(&ipvs->droppacket_lock); spin_lock_init(&ipvs->securetcp_lock); + INIT_DELAYED_WORK(&ipvs->defense_work, defense_work_handler); + INIT_DELAYED_WORK(&ipvs->expire_nodest_conn_work, + expire_nodest_conn_handler); + ipvs->est_stopped = 0; if (!net_eq(net, &init_net)) { tbl = kmemdup(vs_vars, sizeof(vs_vars), GFP_KERNEL); @@ -4094,31 +4354,44 @@ static int __net_init ip_vs_control_net_init_sysctl(struct netns_ipvs *ipvs) tbl[idx++].data = &ipvs->sysctl_schedule_icmp; tbl[idx++].data = &ipvs->sysctl_ignore_tunneled; ipvs->sysctl_run_estimation = 1; + tbl[idx].extra2 = ipvs; tbl[idx++].data = &ipvs->sysctl_run_estimation; + + ipvs->est_cpulist_valid = 0; + tbl[idx].extra2 = ipvs; + tbl[idx++].data = &ipvs->sysctl_est_cpulist; + + ipvs->sysctl_est_nice = IPVS_EST_NICE; + tbl[idx].extra2 = ipvs; + tbl[idx++].data = &ipvs->sysctl_est_nice; + #ifdef CONFIG_IP_VS_DEBUG /* Global sysctls must be ro in non-init netns */ if (!net_eq(net, &init_net)) tbl[idx++].mode = 0444; #endif + ret = -ENOMEM; ipvs->sysctl_hdr = register_net_sysctl(net, "net/ipv4/vs", tbl); - if (ipvs->sysctl_hdr == NULL) { - if (!net_eq(net, &init_net)) - kfree(tbl); - return -ENOMEM; - } - ip_vs_start_estimator(ipvs, &ipvs->tot_stats); + if (!ipvs->sysctl_hdr) + goto err; ipvs->sysctl_tbl = tbl; + + ret = ip_vs_start_estimator(ipvs, &ipvs->tot_stats->s); + if (ret < 0) + goto err; + /* Schedule defense work */ - INIT_DELAYED_WORK(&ipvs->defense_work, defense_work_handler); queue_delayed_work(system_long_wq, &ipvs->defense_work, DEFENSE_TIMER_PERIOD); - /* Init delayed work for expiring no dest conn */ - INIT_DELAYED_WORK(&ipvs->expire_nodest_conn_work, - expire_nodest_conn_handler); - return 0; + +err: + unregister_net_sysctl_table(ipvs->sysctl_hdr); + if (!net_eq(net, &init_net)) + kfree(tbl); + return ret; } static void __net_exit ip_vs_control_net_cleanup_sysctl(struct netns_ipvs *ipvs) @@ -4129,7 +4402,10 @@ static void __net_exit ip_vs_control_net_cleanup_sysctl(struct netns_ipvs *ipvs) cancel_delayed_work_sync(&ipvs->defense_work); cancel_work_sync(&ipvs->defense_work.work); unregister_net_sysctl_table(ipvs->sysctl_hdr); - ip_vs_stop_estimator(ipvs, &ipvs->tot_stats); + ip_vs_stop_estimator(ipvs, &ipvs->tot_stats->s); + + if (ipvs->est_cpulist_valid) + free_cpumask_var(ipvs->sysctl_est_cpulist); if (!net_eq(net, &init_net)) kfree(ipvs->sysctl_tbl); @@ -4151,7 +4427,8 @@ static struct notifier_block ip_vs_dst_notifier = { int __net_init ip_vs_control_net_init(struct netns_ipvs *ipvs) { - int i, idx; + int ret = -ENOMEM; + int idx; /* Initialize rs_table */ for (idx = 0; idx < IP_VS_RTAB_SIZE; idx++) @@ -4164,18 +4441,14 @@ int __net_init ip_vs_control_net_init(struct netns_ipvs *ipvs) atomic_set(&ipvs->nullsvc_counter, 0); atomic_set(&ipvs->conn_out_counter, 0); - /* procfs stats */ - ipvs->tot_stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); - if (!ipvs->tot_stats.cpustats) - return -ENOMEM; - - for_each_possible_cpu(i) { - struct ip_vs_cpu_stats *ipvs_tot_stats; - ipvs_tot_stats = per_cpu_ptr(ipvs->tot_stats.cpustats, i); - u64_stats_init(&ipvs_tot_stats->syncp); - } + INIT_DELAYED_WORK(&ipvs->est_reload_work, est_reload_work_handler); - spin_lock_init(&ipvs->tot_stats.lock); + /* procfs stats */ + ipvs->tot_stats = kzalloc(sizeof(*ipvs->tot_stats), GFP_KERNEL); + if (!ipvs->tot_stats) + goto out; + if (ip_vs_stats_init_alloc(&ipvs->tot_stats->s) < 0) + goto err_tot_stats; #ifdef CONFIG_PROC_FS if (!proc_create_net("ip_vs", 0, ipvs->net->proc_net, @@ -4190,7 +4463,8 @@ int __net_init ip_vs_control_net_init(struct netns_ipvs *ipvs) goto err_percpu; #endif - if (ip_vs_control_net_init_sysctl(ipvs)) + ret = ip_vs_control_net_init_sysctl(ipvs); + if (ret < 0) goto err; return 0; @@ -4207,20 +4481,26 @@ err_stats: err_vs: #endif - free_percpu(ipvs->tot_stats.cpustats); - return -ENOMEM; + ip_vs_stats_release(&ipvs->tot_stats->s); + +err_tot_stats: + kfree(ipvs->tot_stats); + +out: + return ret; } void __net_exit ip_vs_control_net_cleanup(struct netns_ipvs *ipvs) { ip_vs_trash_cleanup(ipvs); ip_vs_control_net_cleanup_sysctl(ipvs); + cancel_delayed_work_sync(&ipvs->est_reload_work); #ifdef CONFIG_PROC_FS remove_proc_entry("ip_vs_stats_percpu", ipvs->net->proc_net); remove_proc_entry("ip_vs_stats", ipvs->net->proc_net); remove_proc_entry("ip_vs", ipvs->net->proc_net); #endif - free_percpu(ipvs->tot_stats.cpustats); + call_rcu(&ipvs->tot_stats->rcu_head, ip_vs_stats_rcu_free); } int __init ip_vs_register_nl_ioctl(void) @@ -4280,5 +4560,6 @@ void ip_vs_control_cleanup(void) { EnterFunction(2); unregister_netdevice_notifier(&ip_vs_dst_notifier); + /* relying on common rcu_barrier() in ip_vs_cleanup() */ LeaveFunction(2); } diff --git a/net/netfilter/ipvs/ip_vs_est.c b/net/netfilter/ipvs/ip_vs_est.c index 9a1a7af6a186..ce2a1549b304 100644 --- a/net/netfilter/ipvs/ip_vs_est.c +++ b/net/netfilter/ipvs/ip_vs_est.c @@ -30,9 +30,6 @@ long interval, it is easy to implement a user level daemon which periodically reads those statistical counters and measure rate. - Currently, the measurement is activated by slow timer handler. Hope - this measurement will not introduce too much load. - We measure rate during the last 8 seconds every 2 seconds: avgrate = avgrate*(1-W) + rate*W @@ -47,68 +44,79 @@ to 32-bit values for conns, packets, bps, cps and pps. * A lot of code is taken from net/core/gen_estimator.c - */ - -/* - * Make a summary from each cpu + KEY POINTS: + - cpustats counters are updated per-cpu in SoftIRQ context with BH disabled + - kthreads read the cpustats to update the estimators (svcs, dests, total) + - the states of estimators can be read (get stats) or modified (zero stats) + from processes + + KTHREADS: + - estimators are added initially to est_temp_list and later kthread 0 + distributes them to one or many kthreads for estimation + - kthread contexts are created and attached to array + - the kthread tasks are started when first service is added, before that + the total stats are not estimated + - when configuration (cpulist/nice) is changed, the tasks are restarted + by work (est_reload_work) + - kthread tasks are stopped while the cpulist is empty + - the kthread context holds lists with estimators (chains) which are + processed every 2 seconds + - as estimators can be added dynamically and in bursts, we try to spread + them to multiple chains which are estimated at different time + - on start, kthread 0 enters calculation phase to determine the chain limits + and the limit of estimators per kthread + - est_add_ktid: ktid where to add new ests, can point to empty slot where + we should add kt data */ -static void ip_vs_read_cpu_stats(struct ip_vs_kstats *sum, - struct ip_vs_cpu_stats __percpu *stats) -{ - int i; - bool add = false; - - for_each_possible_cpu(i) { - struct ip_vs_cpu_stats *s = per_cpu_ptr(stats, i); - unsigned int start; - u64 conns, inpkts, outpkts, inbytes, outbytes; - if (add) { - do { - start = u64_stats_fetch_begin(&s->syncp); - conns = s->cnt.conns; - inpkts = s->cnt.inpkts; - outpkts = s->cnt.outpkts; - inbytes = s->cnt.inbytes; - outbytes = s->cnt.outbytes; - } while (u64_stats_fetch_retry(&s->syncp, start)); - sum->conns += conns; - sum->inpkts += inpkts; - sum->outpkts += outpkts; - sum->inbytes += inbytes; - sum->outbytes += outbytes; - } else { - add = true; - do { - start = u64_stats_fetch_begin(&s->syncp); - sum->conns = s->cnt.conns; - sum->inpkts = s->cnt.inpkts; - sum->outpkts = s->cnt.outpkts; - sum->inbytes = s->cnt.inbytes; - sum->outbytes = s->cnt.outbytes; - } while (u64_stats_fetch_retry(&s->syncp, start)); - } - } -} +static struct lock_class_key __ipvs_est_key; +static void ip_vs_est_calc_phase(struct netns_ipvs *ipvs); +static void ip_vs_est_drain_temp_list(struct netns_ipvs *ipvs); -static void estimation_timer(struct timer_list *t) +static void ip_vs_chain_estimation(struct hlist_head *chain) { struct ip_vs_estimator *e; + struct ip_vs_cpu_stats *c; struct ip_vs_stats *s; u64 rate; - struct netns_ipvs *ipvs = from_timer(ipvs, t, est_timer); - if (!sysctl_run_estimation(ipvs)) - goto skip; + hlist_for_each_entry_rcu(e, chain, list) { + u64 conns, inpkts, outpkts, inbytes, outbytes; + u64 kconns = 0, kinpkts = 0, koutpkts = 0; + u64 kinbytes = 0, koutbytes = 0; + unsigned int start; + int i; + + if (kthread_should_stop()) + break; - spin_lock(&ipvs->est_lock); - list_for_each_entry(e, &ipvs->est_list, list) { s = container_of(e, struct ip_vs_stats, est); + for_each_possible_cpu(i) { + c = per_cpu_ptr(s->cpustats, i); + do { + start = u64_stats_fetch_begin(&c->syncp); + conns = u64_stats_read(&c->cnt.conns); + inpkts = u64_stats_read(&c->cnt.inpkts); + outpkts = u64_stats_read(&c->cnt.outpkts); + inbytes = u64_stats_read(&c->cnt.inbytes); + outbytes = u64_stats_read(&c->cnt.outbytes); + } while (u64_stats_fetch_retry(&c->syncp, start)); + kconns += conns; + kinpkts += inpkts; + koutpkts += outpkts; + kinbytes += inbytes; + koutbytes += outbytes; + } spin_lock(&s->lock); - ip_vs_read_cpu_stats(&s->kstats, s->cpustats); + + s->kstats.conns = kconns; + s->kstats.inpkts = kinpkts; + s->kstats.outpkts = koutpkts; + s->kstats.inbytes = kinbytes; + s->kstats.outbytes = koutbytes; /* scaled by 2^10, but divided 2 seconds */ rate = (s->kstats.conns - e->last_conns) << 9; @@ -133,30 +141,758 @@ static void estimation_timer(struct timer_list *t) e->outbps += ((s64)rate - (s64)e->outbps) >> 2; spin_unlock(&s->lock); } - spin_unlock(&ipvs->est_lock); +} -skip: - mod_timer(&ipvs->est_timer, jiffies + 2*HZ); +static void ip_vs_tick_estimation(struct ip_vs_est_kt_data *kd, int row) +{ + struct ip_vs_est_tick_data *td; + int cid; + + rcu_read_lock(); + td = rcu_dereference(kd->ticks[row]); + if (!td) + goto out; + for_each_set_bit(cid, td->present, IPVS_EST_TICK_CHAINS) { + if (kthread_should_stop()) + break; + ip_vs_chain_estimation(&td->chains[cid]); + cond_resched_rcu(); + td = rcu_dereference(kd->ticks[row]); + if (!td) + break; + } + +out: + rcu_read_unlock(); } -void ip_vs_start_estimator(struct netns_ipvs *ipvs, struct ip_vs_stats *stats) +static int ip_vs_estimation_kthread(void *data) { - struct ip_vs_estimator *est = &stats->est; + struct ip_vs_est_kt_data *kd = data; + struct netns_ipvs *ipvs = kd->ipvs; + int row = kd->est_row; + unsigned long now; + int id = kd->id; + long gap; + + if (id > 0) { + if (!ipvs->est_chain_max) + return 0; + } else { + if (!ipvs->est_chain_max) { + ipvs->est_calc_phase = 1; + /* commit est_calc_phase before reading est_genid */ + smp_mb(); + } + + /* kthread 0 will handle the calc phase */ + if (ipvs->est_calc_phase) + ip_vs_est_calc_phase(ipvs); + } + + while (1) { + if (!id && !hlist_empty(&ipvs->est_temp_list)) + ip_vs_est_drain_temp_list(ipvs); + set_current_state(TASK_IDLE); + if (kthread_should_stop()) + break; + + /* before estimation, check if we should sleep */ + now = jiffies; + gap = kd->est_timer - now; + if (gap > 0) { + if (gap > IPVS_EST_TICK) { + kd->est_timer = now - IPVS_EST_TICK; + gap = IPVS_EST_TICK; + } + schedule_timeout(gap); + } else { + __set_current_state(TASK_RUNNING); + if (gap < -8 * IPVS_EST_TICK) + kd->est_timer = now; + } + + if (kd->tick_len[row]) + ip_vs_tick_estimation(kd, row); + + row++; + if (row >= IPVS_EST_NTICKS) + row = 0; + WRITE_ONCE(kd->est_row, row); + kd->est_timer += IPVS_EST_TICK; + } + __set_current_state(TASK_RUNNING); + + return 0; +} + +/* Schedule stop/start for kthread tasks */ +void ip_vs_est_reload_start(struct netns_ipvs *ipvs) +{ + /* Ignore reloads before first service is added */ + if (!ipvs->enable) + return; + ip_vs_est_stopped_recalc(ipvs); + /* Bump the kthread configuration genid */ + atomic_inc(&ipvs->est_genid); + queue_delayed_work(system_long_wq, &ipvs->est_reload_work, 0); +} + +/* Start kthread task with current configuration */ +int ip_vs_est_kthread_start(struct netns_ipvs *ipvs, + struct ip_vs_est_kt_data *kd) +{ + unsigned long now; + int ret = 0; + long gap; + + lockdep_assert_held(&ipvs->est_mutex); + + if (kd->task) + goto out; + now = jiffies; + gap = kd->est_timer - now; + /* Sync est_timer if task is starting later */ + if (abs(gap) > 4 * IPVS_EST_TICK) + kd->est_timer = now; + kd->task = kthread_create(ip_vs_estimation_kthread, kd, "ipvs-e:%d:%d", + ipvs->gen, kd->id); + if (IS_ERR(kd->task)) { + ret = PTR_ERR(kd->task); + kd->task = NULL; + goto out; + } + + set_user_nice(kd->task, sysctl_est_nice(ipvs)); + set_cpus_allowed_ptr(kd->task, sysctl_est_cpulist(ipvs)); + + pr_info("starting estimator thread %d...\n", kd->id); + wake_up_process(kd->task); + +out: + return ret; +} + +void ip_vs_est_kthread_stop(struct ip_vs_est_kt_data *kd) +{ + if (kd->task) { + pr_info("stopping estimator thread %d...\n", kd->id); + kthread_stop(kd->task); + kd->task = NULL; + } +} + +/* Apply parameters to kthread */ +static void ip_vs_est_set_params(struct netns_ipvs *ipvs, + struct ip_vs_est_kt_data *kd) +{ + kd->chain_max = ipvs->est_chain_max; + /* We are using single chain on RCU preemption */ + if (IPVS_EST_TICK_CHAINS == 1) + kd->chain_max *= IPVS_EST_CHAIN_FACTOR; + kd->tick_max = IPVS_EST_TICK_CHAINS * kd->chain_max; + kd->est_max_count = IPVS_EST_NTICKS * kd->tick_max; +} + +/* Create and start estimation kthread in a free or new array slot */ +static int ip_vs_est_add_kthread(struct netns_ipvs *ipvs) +{ + struct ip_vs_est_kt_data *kd = NULL; + int id = ipvs->est_kt_count; + int ret = -ENOMEM; + void *arr = NULL; + int i; + + if ((unsigned long)ipvs->est_kt_count >= ipvs->est_max_threads && + ipvs->enable && ipvs->est_max_threads) + return -EINVAL; + + mutex_lock(&ipvs->est_mutex); + + for (i = 0; i < id; i++) { + if (!ipvs->est_kt_arr[i]) + break; + } + if (i >= id) { + arr = krealloc_array(ipvs->est_kt_arr, id + 1, + sizeof(struct ip_vs_est_kt_data *), + GFP_KERNEL); + if (!arr) + goto out; + ipvs->est_kt_arr = arr; + } else { + id = i; + } - INIT_LIST_HEAD(&est->list); + kd = kzalloc(sizeof(*kd), GFP_KERNEL); + if (!kd) + goto out; + kd->ipvs = ipvs; + bitmap_fill(kd->avail, IPVS_EST_NTICKS); + kd->est_timer = jiffies; + kd->id = id; + ip_vs_est_set_params(ipvs, kd); + + /* Pre-allocate stats used in calc phase */ + if (!id && !kd->calc_stats) { + kd->calc_stats = ip_vs_stats_alloc(); + if (!kd->calc_stats) + goto out; + } + + /* Start kthread tasks only when services are present */ + if (ipvs->enable && !ip_vs_est_stopped(ipvs)) { + ret = ip_vs_est_kthread_start(ipvs, kd); + if (ret < 0) + goto out; + } + + if (arr) + ipvs->est_kt_count++; + ipvs->est_kt_arr[id] = kd; + kd = NULL; + /* Use most recent kthread for new ests */ + ipvs->est_add_ktid = id; + ret = 0; + +out: + mutex_unlock(&ipvs->est_mutex); + if (kd) { + ip_vs_stats_free(kd->calc_stats); + kfree(kd); + } - spin_lock_bh(&ipvs->est_lock); - list_add(&est->list, &ipvs->est_list); - spin_unlock_bh(&ipvs->est_lock); + return ret; } +/* Select ktid where to add new ests: available, unused or new slot */ +static void ip_vs_est_update_ktid(struct netns_ipvs *ipvs) +{ + int ktid, best = ipvs->est_kt_count; + struct ip_vs_est_kt_data *kd; + + for (ktid = 0; ktid < ipvs->est_kt_count; ktid++) { + kd = ipvs->est_kt_arr[ktid]; + if (kd) { + if (kd->est_count < kd->est_max_count) { + best = ktid; + break; + } + } else if (ktid < best) { + best = ktid; + } + } + ipvs->est_add_ktid = best; +} + +/* Add estimator to current kthread (est_add_ktid) */ +static int ip_vs_enqueue_estimator(struct netns_ipvs *ipvs, + struct ip_vs_estimator *est) +{ + struct ip_vs_est_kt_data *kd = NULL; + struct ip_vs_est_tick_data *td; + int ktid, row, crow, cid, ret; + int delay = est->ktrow; + + BUILD_BUG_ON_MSG(IPVS_EST_TICK_CHAINS > 127, + "Too many chains for ktcid"); + + if (ipvs->est_add_ktid < ipvs->est_kt_count) { + kd = ipvs->est_kt_arr[ipvs->est_add_ktid]; + if (kd) + goto add_est; + } + + ret = ip_vs_est_add_kthread(ipvs); + if (ret < 0) + goto out; + kd = ipvs->est_kt_arr[ipvs->est_add_ktid]; + +add_est: + ktid = kd->id; + /* For small number of estimators prefer to use few ticks, + * otherwise try to add into the last estimated row. + * est_row and add_row point after the row we should use + */ + if (kd->est_count >= 2 * kd->tick_max || delay < IPVS_EST_NTICKS - 1) + crow = READ_ONCE(kd->est_row); + else + crow = kd->add_row; + crow += delay; + if (crow >= IPVS_EST_NTICKS) + crow -= IPVS_EST_NTICKS; + /* Assume initial delay ? */ + if (delay >= IPVS_EST_NTICKS - 1) { + /* Preserve initial delay or decrease it if no space in tick */ + row = crow; + if (crow < IPVS_EST_NTICKS - 1) { + crow++; + row = find_last_bit(kd->avail, crow); + } + if (row >= crow) + row = find_last_bit(kd->avail, IPVS_EST_NTICKS); + } else { + /* Preserve delay or increase it if no space in tick */ + row = IPVS_EST_NTICKS; + if (crow > 0) + row = find_next_bit(kd->avail, IPVS_EST_NTICKS, crow); + if (row >= IPVS_EST_NTICKS) + row = find_first_bit(kd->avail, IPVS_EST_NTICKS); + } + + td = rcu_dereference_protected(kd->ticks[row], 1); + if (!td) { + td = kzalloc(sizeof(*td), GFP_KERNEL); + if (!td) { + ret = -ENOMEM; + goto out; + } + rcu_assign_pointer(kd->ticks[row], td); + } + + cid = find_first_zero_bit(td->full, IPVS_EST_TICK_CHAINS); + + kd->est_count++; + kd->tick_len[row]++; + if (!td->chain_len[cid]) + __set_bit(cid, td->present); + td->chain_len[cid]++; + est->ktid = ktid; + est->ktrow = row; + est->ktcid = cid; + hlist_add_head_rcu(&est->list, &td->chains[cid]); + + if (td->chain_len[cid] >= kd->chain_max) { + __set_bit(cid, td->full); + if (kd->tick_len[row] >= kd->tick_max) + __clear_bit(row, kd->avail); + } + + /* Update est_add_ktid to point to first available/empty kt slot */ + if (kd->est_count == kd->est_max_count) + ip_vs_est_update_ktid(ipvs); + + ret = 0; + +out: + return ret; +} + +/* Start estimation for stats */ +int ip_vs_start_estimator(struct netns_ipvs *ipvs, struct ip_vs_stats *stats) +{ + struct ip_vs_estimator *est = &stats->est; + int ret; + + if (!ipvs->est_max_threads && ipvs->enable) + ipvs->est_max_threads = ip_vs_est_max_threads(ipvs); + + est->ktid = -1; + est->ktrow = IPVS_EST_NTICKS - 1; /* Initial delay */ + + /* We prefer this code to be short, kthread 0 will requeue the + * estimator to available chain. If tasks are disabled, we + * will not allocate much memory, just for kt 0. + */ + ret = 0; + if (!ipvs->est_kt_count || !ipvs->est_kt_arr[0]) + ret = ip_vs_est_add_kthread(ipvs); + if (ret >= 0) + hlist_add_head(&est->list, &ipvs->est_temp_list); + else + INIT_HLIST_NODE(&est->list); + return ret; +} + +static void ip_vs_est_kthread_destroy(struct ip_vs_est_kt_data *kd) +{ + if (kd) { + if (kd->task) { + pr_info("stop unused estimator thread %d...\n", kd->id); + kthread_stop(kd->task); + } + ip_vs_stats_free(kd->calc_stats); + kfree(kd); + } +} + +/* Unlink estimator from chain */ void ip_vs_stop_estimator(struct netns_ipvs *ipvs, struct ip_vs_stats *stats) { struct ip_vs_estimator *est = &stats->est; + struct ip_vs_est_tick_data *td; + struct ip_vs_est_kt_data *kd; + int ktid = est->ktid; + int row = est->ktrow; + int cid = est->ktcid; + + /* Failed to add to chain ? */ + if (hlist_unhashed(&est->list)) + return; + + /* On return, estimator can be freed, dequeue it now */ + + /* In est_temp_list ? */ + if (ktid < 0) { + hlist_del(&est->list); + goto end_kt0; + } + + hlist_del_rcu(&est->list); + kd = ipvs->est_kt_arr[ktid]; + td = rcu_dereference_protected(kd->ticks[row], 1); + __clear_bit(cid, td->full); + td->chain_len[cid]--; + if (!td->chain_len[cid]) + __clear_bit(cid, td->present); + kd->tick_len[row]--; + __set_bit(row, kd->avail); + if (!kd->tick_len[row]) { + RCU_INIT_POINTER(kd->ticks[row], NULL); + kfree_rcu(td); + } + kd->est_count--; + if (kd->est_count) { + /* This kt slot can become available just now, prefer it */ + if (ktid < ipvs->est_add_ktid) + ipvs->est_add_ktid = ktid; + return; + } - spin_lock_bh(&ipvs->est_lock); - list_del(&est->list); - spin_unlock_bh(&ipvs->est_lock); + if (ktid > 0) { + mutex_lock(&ipvs->est_mutex); + ip_vs_est_kthread_destroy(kd); + ipvs->est_kt_arr[ktid] = NULL; + if (ktid == ipvs->est_kt_count - 1) { + ipvs->est_kt_count--; + while (ipvs->est_kt_count > 1 && + !ipvs->est_kt_arr[ipvs->est_kt_count - 1]) + ipvs->est_kt_count--; + } + mutex_unlock(&ipvs->est_mutex); + + /* This slot is now empty, prefer another available kt slot */ + if (ktid == ipvs->est_add_ktid) + ip_vs_est_update_ktid(ipvs); + } + +end_kt0: + /* kt 0 is freed after all other kthreads and chains are empty */ + if (ipvs->est_kt_count == 1 && hlist_empty(&ipvs->est_temp_list)) { + kd = ipvs->est_kt_arr[0]; + if (!kd || !kd->est_count) { + mutex_lock(&ipvs->est_mutex); + if (kd) { + ip_vs_est_kthread_destroy(kd); + ipvs->est_kt_arr[0] = NULL; + } + ipvs->est_kt_count--; + mutex_unlock(&ipvs->est_mutex); + ipvs->est_add_ktid = 0; + } + } +} + +/* Register all ests from est_temp_list to kthreads */ +static void ip_vs_est_drain_temp_list(struct netns_ipvs *ipvs) +{ + struct ip_vs_estimator *est; + + while (1) { + int max = 16; + + mutex_lock(&__ip_vs_mutex); + + while (max-- > 0) { + est = hlist_entry_safe(ipvs->est_temp_list.first, + struct ip_vs_estimator, list); + if (est) { + if (kthread_should_stop()) + goto unlock; + hlist_del_init(&est->list); + if (ip_vs_enqueue_estimator(ipvs, est) >= 0) + continue; + est->ktid = -1; + hlist_add_head(&est->list, + &ipvs->est_temp_list); + /* Abort, some entries will not be estimated + * until next attempt + */ + } + goto unlock; + } + mutex_unlock(&__ip_vs_mutex); + cond_resched(); + } + +unlock: + mutex_unlock(&__ip_vs_mutex); +} + +/* Calculate limits for all kthreads */ +static int ip_vs_est_calc_limits(struct netns_ipvs *ipvs, int *chain_max) +{ + DECLARE_WAIT_QUEUE_HEAD_ONSTACK(wq); + struct ip_vs_est_kt_data *kd; + struct hlist_head chain; + struct ip_vs_stats *s; + int cache_factor = 4; + int i, loops, ntest; + s32 min_est = 0; + ktime_t t1, t2; + int max = 8; + int ret = 1; + s64 diff; + u64 val; + + INIT_HLIST_HEAD(&chain); + mutex_lock(&__ip_vs_mutex); + kd = ipvs->est_kt_arr[0]; + mutex_unlock(&__ip_vs_mutex); + s = kd ? kd->calc_stats : NULL; + if (!s) + goto out; + hlist_add_head(&s->est.list, &chain); + + loops = 1; + /* Get best result from many tests */ + for (ntest = 0; ntest < 12; ntest++) { + if (!(ntest & 3)) { + /* Wait for cpufreq frequency transition */ + wait_event_idle_timeout(wq, kthread_should_stop(), + HZ / 50); + if (!ipvs->enable || kthread_should_stop()) + goto stop; + } + + local_bh_disable(); + rcu_read_lock(); + + /* Put stats in cache */ + ip_vs_chain_estimation(&chain); + + t1 = ktime_get(); + for (i = loops * cache_factor; i > 0; i--) + ip_vs_chain_estimation(&chain); + t2 = ktime_get(); + + rcu_read_unlock(); + local_bh_enable(); + + if (!ipvs->enable || kthread_should_stop()) + goto stop; + cond_resched(); + + diff = ktime_to_ns(ktime_sub(t2, t1)); + if (diff <= 1 * NSEC_PER_USEC) { + /* Do more loops on low time resolution */ + loops *= 2; + continue; + } + if (diff >= NSEC_PER_SEC) + continue; + val = diff; + do_div(val, loops); + if (!min_est || val < min_est) { + min_est = val; + /* goal: 95usec per chain */ + val = 95 * NSEC_PER_USEC; + if (val >= min_est) { + do_div(val, min_est); + max = (int)val; + } else { + max = 1; + } + } + } + +out: + if (s) + hlist_del_init(&s->est.list); + *chain_max = max; + return ret; + +stop: + ret = 0; + goto out; +} + +/* Calculate the parameters and apply them in context of kt #0 + * ECP: est_calc_phase + * ECM: est_chain_max + * ECP ECM Insert Chain enable Description + * --------------------------------------------------------------------------- + * 0 0 est_temp_list 0 create kt #0 context + * 0 0 est_temp_list 0->1 service added, start kthread #0 task + * 0->1 0 est_temp_list 1 kt task #0 started, enters calc phase + * 1 0 est_temp_list 1 kt #0: determine est_chain_max, + * stop tasks, move ests to est_temp_list + * and free kd for kthreads 1..last + * 1->0 0->N kt chains 1 ests can go to kthreads + * 0 N kt chains 1 drain est_temp_list, create new kthread + * contexts, start tasks, estimate + */ +static void ip_vs_est_calc_phase(struct netns_ipvs *ipvs) +{ + int genid = atomic_read(&ipvs->est_genid); + struct ip_vs_est_tick_data *td; + struct ip_vs_est_kt_data *kd; + struct ip_vs_estimator *est; + struct ip_vs_stats *stats; + int id, row, cid, delay; + bool last, last_td; + int chain_max; + int step; + + if (!ip_vs_est_calc_limits(ipvs, &chain_max)) + return; + + mutex_lock(&__ip_vs_mutex); + + /* Stop all other tasks, so that we can immediately move the + * estimators to est_temp_list without RCU grace period + */ + mutex_lock(&ipvs->est_mutex); + for (id = 1; id < ipvs->est_kt_count; id++) { + /* netns clean up started, abort */ + if (!ipvs->enable) + goto unlock2; + kd = ipvs->est_kt_arr[id]; + if (!kd) + continue; + ip_vs_est_kthread_stop(kd); + } + mutex_unlock(&ipvs->est_mutex); + + /* Move all estimators to est_temp_list but carefully, + * all estimators and kthread data can be released while + * we reschedule. Even for kthread 0. + */ + step = 0; + + /* Order entries in est_temp_list in ascending delay, so now + * walk delay(desc), id(desc), cid(asc) + */ + delay = IPVS_EST_NTICKS; + +next_delay: + delay--; + if (delay < 0) + goto end_dequeue; + +last_kt: + /* Destroy contexts backwards */ + id = ipvs->est_kt_count; + +next_kt: + if (!ipvs->enable || kthread_should_stop()) + goto unlock; + id--; + if (id < 0) + goto next_delay; + kd = ipvs->est_kt_arr[id]; + if (!kd) + goto next_kt; + /* kt 0 can exist with empty chains */ + if (!id && kd->est_count <= 1) + goto next_delay; + + row = kd->est_row + delay; + if (row >= IPVS_EST_NTICKS) + row -= IPVS_EST_NTICKS; + td = rcu_dereference_protected(kd->ticks[row], 1); + if (!td) + goto next_kt; + + cid = 0; + +walk_chain: + if (kthread_should_stop()) + goto unlock; + step++; + if (!(step & 63)) { + /* Give chance estimators to be added (to est_temp_list) + * and deleted (releasing kthread contexts) + */ + mutex_unlock(&__ip_vs_mutex); + cond_resched(); + mutex_lock(&__ip_vs_mutex); + + /* Current kt released ? */ + if (id >= ipvs->est_kt_count) + goto last_kt; + if (kd != ipvs->est_kt_arr[id]) + goto next_kt; + /* Current td released ? */ + if (td != rcu_dereference_protected(kd->ticks[row], 1)) + goto next_kt; + /* No fatal changes on the current kd and td */ + } + est = hlist_entry_safe(td->chains[cid].first, struct ip_vs_estimator, + list); + if (!est) { + cid++; + if (cid >= IPVS_EST_TICK_CHAINS) + goto next_kt; + goto walk_chain; + } + /* We can cheat and increase est_count to protect kt 0 context + * from release but we prefer to keep the last estimator + */ + last = kd->est_count <= 1; + /* Do not free kt #0 data */ + if (!id && last) + goto next_delay; + last_td = kd->tick_len[row] <= 1; + stats = container_of(est, struct ip_vs_stats, est); + ip_vs_stop_estimator(ipvs, stats); + /* Tasks are stopped, move without RCU grace period */ + est->ktid = -1; + est->ktrow = row - kd->est_row; + if (est->ktrow < 0) + est->ktrow += IPVS_EST_NTICKS; + hlist_add_head(&est->list, &ipvs->est_temp_list); + /* kd freed ? */ + if (last) + goto next_kt; + /* td freed ? */ + if (last_td) + goto next_kt; + goto walk_chain; + +end_dequeue: + /* All estimators removed while calculating ? */ + if (!ipvs->est_kt_count) + goto unlock; + kd = ipvs->est_kt_arr[0]; + if (!kd) + goto unlock; + kd->add_row = kd->est_row; + ipvs->est_chain_max = chain_max; + ip_vs_est_set_params(ipvs, kd); + + pr_info("using max %d ests per chain, %d per kthread\n", + kd->chain_max, kd->est_max_count); + + /* Try to keep tot_stats in kt0, enqueue it early */ + if (ipvs->tot_stats && !hlist_unhashed(&ipvs->tot_stats->s.est.list) && + ipvs->tot_stats->s.est.ktid == -1) { + hlist_del(&ipvs->tot_stats->s.est.list); + hlist_add_head(&ipvs->tot_stats->s.est.list, + &ipvs->est_temp_list); + } + + mutex_lock(&ipvs->est_mutex); + + /* We completed the calc phase, new calc phase not requested */ + if (genid == atomic_read(&ipvs->est_genid)) + ipvs->est_calc_phase = 0; + +unlock2: + mutex_unlock(&ipvs->est_mutex); + +unlock: + mutex_unlock(&__ip_vs_mutex); } void ip_vs_zero_estimator(struct ip_vs_stats *stats) @@ -191,14 +927,25 @@ void ip_vs_read_estimator(struct ip_vs_kstats *dst, struct ip_vs_stats *stats) int __net_init ip_vs_estimator_net_init(struct netns_ipvs *ipvs) { - INIT_LIST_HEAD(&ipvs->est_list); - spin_lock_init(&ipvs->est_lock); - timer_setup(&ipvs->est_timer, estimation_timer, 0); - mod_timer(&ipvs->est_timer, jiffies + 2 * HZ); + INIT_HLIST_HEAD(&ipvs->est_temp_list); + ipvs->est_kt_arr = NULL; + ipvs->est_max_threads = 0; + ipvs->est_calc_phase = 0; + ipvs->est_chain_max = 0; + ipvs->est_kt_count = 0; + ipvs->est_add_ktid = 0; + atomic_set(&ipvs->est_genid, 0); + atomic_set(&ipvs->est_genid_done, 0); + __mutex_init(&ipvs->est_mutex, "ipvs->est_mutex", &__ipvs_est_key); return 0; } void __net_exit ip_vs_estimator_net_cleanup(struct netns_ipvs *ipvs) { - del_timer_sync(&ipvs->est_timer); + int i; + + for (i = 0; i < ipvs->est_kt_count; i++) + ip_vs_est_kthread_destroy(ipvs->est_kt_arr[i]); + kfree(ipvs->est_kt_arr); + mutex_destroy(&ipvs->est_mutex); } diff --git a/net/netfilter/ipvs/ip_vs_lblc.c b/net/netfilter/ipvs/ip_vs_lblc.c index 7ac7473e3804..1b87214d385e 100644 --- a/net/netfilter/ipvs/ip_vs_lblc.c +++ b/net/netfilter/ipvs/ip_vs_lblc.c @@ -384,7 +384,7 @@ static void ip_vs_lblc_done_svc(struct ip_vs_service *svc) struct ip_vs_lblc_table *tbl = svc->sched_data; /* remove periodic timer */ - del_timer_sync(&tbl->periodic_timer); + timer_shutdown_sync(&tbl->periodic_timer); /* got to clean up table entries here */ ip_vs_lblc_flush(svc); diff --git a/net/netfilter/ipvs/ip_vs_lblcr.c b/net/netfilter/ipvs/ip_vs_lblcr.c index 77c323c36a88..ad8f5fea6d3a 100644 --- a/net/netfilter/ipvs/ip_vs_lblcr.c +++ b/net/netfilter/ipvs/ip_vs_lblcr.c @@ -547,7 +547,7 @@ static void ip_vs_lblcr_done_svc(struct ip_vs_service *svc) struct ip_vs_lblcr_table *tbl = svc->sched_data; /* remove periodic timer */ - del_timer_sync(&tbl->periodic_timer); + timer_shutdown_sync(&tbl->periodic_timer); /* got to clean up table entries here */ ip_vs_lblcr_flush(svc); diff --git a/net/netfilter/ipvs/ip_vs_sync.c b/net/netfilter/ipvs/ip_vs_sync.c index a56fd0b5a430..4963fec815da 100644 --- a/net/netfilter/ipvs/ip_vs_sync.c +++ b/net/netfilter/ipvs/ip_vs_sync.c @@ -1617,7 +1617,7 @@ ip_vs_receive(struct socket *sock, char *buffer, const size_t buflen) EnterFunction(7); /* Receive a packet */ - iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, buflen); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, buflen); len = sock_recvmsg(sock, &msg, MSG_DONTWAIT); if (len < 0) return len; diff --git a/net/netfilter/ipvs/ip_vs_twos.c b/net/netfilter/ipvs/ip_vs_twos.c index f2579fc9c75b..3308e4cc740a 100644 --- a/net/netfilter/ipvs/ip_vs_twos.c +++ b/net/netfilter/ipvs/ip_vs_twos.c @@ -71,8 +71,8 @@ static struct ip_vs_dest *ip_vs_twos_schedule(struct ip_vs_service *svc, * from 0 to total_weight */ total_weight += 1; - rweight1 = prandom_u32_max(total_weight); - rweight2 = prandom_u32_max(total_weight); + rweight1 = get_random_u32_below(total_weight); + rweight2 = get_random_u32_below(total_weight); /* Pick two weighted servers */ list_for_each_entry_rcu(dest, &svc->destinations, n_list) { diff --git a/net/netfilter/nf_conntrack_bpf.c b/net/netfilter/nf_conntrack_bpf.c index 8639e7efd0e2..24002bc61e07 100644 --- a/net/netfilter/nf_conntrack_bpf.c +++ b/net/netfilter/nf_conntrack_bpf.c @@ -191,19 +191,16 @@ BTF_ID(struct, nf_conn___init) /* Check writes into `struct nf_conn` */ static int _nf_conntrack_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, - u32 *next_btf_id, - enum bpf_type_flag *flag) + const struct bpf_reg_state *reg, + int off, int size, enum bpf_access_type atype, + u32 *next_btf_id, enum bpf_type_flag *flag) { - const struct btf_type *ncit; - const struct btf_type *nct; + const struct btf_type *ncit, *nct, *t; size_t end; - ncit = btf_type_by_id(btf, btf_nf_conn_ids[1]); - nct = btf_type_by_id(btf, btf_nf_conn_ids[0]); - + ncit = btf_type_by_id(reg->btf, btf_nf_conn_ids[1]); + nct = btf_type_by_id(reg->btf, btf_nf_conn_ids[0]); + t = btf_type_by_id(reg->btf, reg->btf_id); if (t != nct && t != ncit) { bpf_log(log, "only read is supported\n"); return -EACCES; diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index f97bda06d2a9..496c4920505b 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -211,28 +211,24 @@ static u32 hash_conntrack_raw(const struct nf_conntrack_tuple *tuple, unsigned int zoneid, const struct net *net) { - struct { - struct nf_conntrack_man src; - union nf_inet_addr dst_addr; - unsigned int zone; - u32 net_mix; - u16 dport; - u16 proto; - } __aligned(SIPHASH_ALIGNMENT) combined; + u64 a, b, c, d; get_random_once(&nf_conntrack_hash_rnd, sizeof(nf_conntrack_hash_rnd)); - memset(&combined, 0, sizeof(combined)); + /* The direction must be ignored, handle usable tuplehash members manually */ + a = (u64)tuple->src.u3.all[0] << 32 | tuple->src.u3.all[3]; + b = (u64)tuple->dst.u3.all[0] << 32 | tuple->dst.u3.all[3]; - /* The direction must be ignored, so handle usable members manually. */ - combined.src = tuple->src; - combined.dst_addr = tuple->dst.u3; - combined.zone = zoneid; - combined.net_mix = net_hash_mix(net); - combined.dport = (__force __u16)tuple->dst.u.all; - combined.proto = tuple->dst.protonum; + c = (__force u64)tuple->src.u.all << 32 | (__force u64)tuple->dst.u.all << 16; + c |= tuple->dst.protonum; - return (u32)siphash(&combined, sizeof(combined), &nf_conntrack_hash_rnd); + d = (u64)zoneid << 32 | net_hash_mix(net); + + /* IPv4: u3.all[1,2,3] == 0 */ + c ^= (u64)tuple->src.u3.all[1] << 32 | tuple->src.u3.all[2]; + d += (u64)tuple->dst.u3.all[1] << 32 | tuple->dst.u3.all[2]; + + return (u32)siphash_4u64(a, b, c, d, &nf_conntrack_hash_rnd); } static u32 scale_hash(u32 hash) @@ -891,7 +887,7 @@ nf_conntrack_hash_check_insert(struct nf_conn *ct) zone = nf_ct_zone(ct); if (!nf_ct_ext_valid_pre(ct->ext)) { - NF_CT_STAT_INC(net, insert_failed); + NF_CT_STAT_INC_ATOMIC(net, insert_failed); return -ETIMEDOUT; } @@ -906,7 +902,7 @@ nf_conntrack_hash_check_insert(struct nf_conn *ct) nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_REPLY)); } while (nf_conntrack_double_lock(net, hash, reply_hash, sequence)); - max_chainlen = MIN_CHAINLEN + prandom_u32_max(MAX_CHAINLEN); + max_chainlen = MIN_CHAINLEN + get_random_u32_below(MAX_CHAINLEN); /* See if there's one in the list already, including reverse */ hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[hash], hnnode) { @@ -938,7 +934,7 @@ nf_conntrack_hash_check_insert(struct nf_conn *ct) if (!nf_ct_ext_valid_post(ct->ext)) { nf_ct_kill(ct); - NF_CT_STAT_INC(net, drop); + NF_CT_STAT_INC_ATOMIC(net, drop); return -ETIMEDOUT; } @@ -1227,7 +1223,7 @@ __nf_conntrack_confirm(struct sk_buff *skb) goto dying; } - max_chainlen = MIN_CHAINLEN + prandom_u32_max(MAX_CHAINLEN); + max_chainlen = MIN_CHAINLEN + get_random_u32_below(MAX_CHAINLEN); /* See if there's one in the list already, including reverse: NAT could have grabbed it without realizing, since we're not in the hash. If there is, we lost race. */ @@ -1275,7 +1271,7 @@ chaintoolong: */ if (!nf_ct_ext_valid_post(ct->ext)) { nf_ct_kill(ct); - NF_CT_STAT_INC(net, drop); + NF_CT_STAT_INC_ATOMIC(net, drop); return NF_DROP; } @@ -1781,7 +1777,7 @@ init_conntrack(struct net *net, struct nf_conn *tmpl, } #ifdef CONFIG_NF_CONNTRACK_MARK - ct->mark = exp->master->mark; + ct->mark = READ_ONCE(exp->master->mark); #endif #ifdef CONFIG_NF_CONNTRACK_SECMARK ct->secmark = exp->master->secmark; diff --git a/net/netfilter/nf_conntrack_helper.c b/net/netfilter/nf_conntrack_helper.c index ff737a76052e..48ea6d0264b5 100644 --- a/net/netfilter/nf_conntrack_helper.c +++ b/net/netfilter/nf_conntrack_helper.c @@ -26,7 +26,9 @@ #include <net/netfilter/nf_conntrack_extend.h> #include <net/netfilter/nf_conntrack_helper.h> #include <net/netfilter/nf_conntrack_l4proto.h> +#include <net/netfilter/nf_conntrack_seqadj.h> #include <net/netfilter/nf_log.h> +#include <net/ip.h> static DEFINE_MUTEX(nf_ct_helper_mutex); struct hlist_head *nf_ct_helper_hash __read_mostly; @@ -240,6 +242,104 @@ int __nf_ct_try_assign_helper(struct nf_conn *ct, struct nf_conn *tmpl, } EXPORT_SYMBOL_GPL(__nf_ct_try_assign_helper); +/* 'skb' should already be pulled to nh_ofs. */ +int nf_ct_helper(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, u16 proto) +{ + const struct nf_conntrack_helper *helper; + const struct nf_conn_help *help; + unsigned int protoff; + int err; + + if (ctinfo == IP_CT_RELATED_REPLY) + return NF_ACCEPT; + + help = nfct_help(ct); + if (!help) + return NF_ACCEPT; + + helper = rcu_dereference(help->helper); + if (!helper) + return NF_ACCEPT; + + if (helper->tuple.src.l3num != NFPROTO_UNSPEC && + helper->tuple.src.l3num != proto) + return NF_ACCEPT; + + switch (proto) { + case NFPROTO_IPV4: + protoff = ip_hdrlen(skb); + proto = ip_hdr(skb)->protocol; + break; + case NFPROTO_IPV6: { + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + __be16 frag_off; + int ofs; + + ofs = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, + &frag_off); + if (ofs < 0 || (frag_off & htons(~0x7)) != 0) { + pr_debug("proto header not found\n"); + return NF_ACCEPT; + } + protoff = ofs; + proto = nexthdr; + break; + } + default: + WARN_ONCE(1, "helper invoked on non-IP family!"); + return NF_DROP; + } + + if (helper->tuple.dst.protonum != proto) + return NF_ACCEPT; + + err = helper->help(skb, protoff, ct, ctinfo); + if (err != NF_ACCEPT) + return err; + + /* Adjust seqs after helper. This is needed due to some helpers (e.g., + * FTP with NAT) adusting the TCP payload size when mangling IP + * addresses and/or port numbers in the text-based control connection. + */ + if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && + !nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) + return NF_DROP; + return NF_ACCEPT; +} +EXPORT_SYMBOL_GPL(nf_ct_helper); + +int nf_ct_add_helper(struct nf_conn *ct, const char *name, u8 family, + u8 proto, bool nat, struct nf_conntrack_helper **hp) +{ + struct nf_conntrack_helper *helper; + struct nf_conn_help *help; + int ret = 0; + + helper = nf_conntrack_helper_try_module_get(name, family, proto); + if (!helper) + return -EINVAL; + + help = nf_ct_helper_ext_add(ct, GFP_KERNEL); + if (!help) { + nf_conntrack_helper_put(helper); + return -ENOMEM; + } +#if IS_ENABLED(CONFIG_NF_NAT) + if (nat) { + ret = nf_nat_helper_try_module_get(name, family, proto); + if (ret) { + nf_conntrack_helper_put(helper); + return ret; + } + } +#endif + rcu_assign_pointer(help->helper, helper); + *hp = helper; + return ret; +} +EXPORT_SYMBOL_GPL(nf_ct_add_helper); + /* appropriate ct lock protecting must be taken by caller */ static int unhelp(struct nf_conn *ct, void *me) { diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c index 7562b215b932..1286ae7d4609 100644 --- a/net/netfilter/nf_conntrack_netlink.c +++ b/net/netfilter/nf_conntrack_netlink.c @@ -330,7 +330,12 @@ nla_put_failure: #ifdef CONFIG_NF_CONNTRACK_MARK static int ctnetlink_dump_mark(struct sk_buff *skb, const struct nf_conn *ct) { - if (nla_put_be32(skb, CTA_MARK, htonl(ct->mark))) + u32 mark = READ_ONCE(ct->mark); + + if (!mark) + return 0; + + if (nla_put_be32(skb, CTA_MARK, htonl(mark))) goto nla_put_failure; return 0; @@ -826,8 +831,8 @@ ctnetlink_conntrack_event(unsigned int events, const struct nf_ct_event *item) } #ifdef CONFIG_NF_CONNTRACK_MARK - if ((events & (1 << IPCT_MARK) || ct->mark) - && ctnetlink_dump_mark(skb, ct) < 0) + if (events & (1 << IPCT_MARK) && + ctnetlink_dump_mark(skb, ct) < 0) goto nla_put_failure; #endif nlmsg_end(skb, nlh); @@ -1154,7 +1159,7 @@ static int ctnetlink_filter_match(struct nf_conn *ct, void *data) } #ifdef CONFIG_NF_CONNTRACK_MARK - if ((ct->mark & filter->mark.mask) != filter->mark.val) + if ((READ_ONCE(ct->mark) & filter->mark.mask) != filter->mark.val) goto ignore_entry; #endif status = (u32)READ_ONCE(ct->status); @@ -2002,9 +2007,9 @@ static void ctnetlink_change_mark(struct nf_conn *ct, mask = ~ntohl(nla_get_be32(cda[CTA_MARK_MASK])); mark = ntohl(nla_get_be32(cda[CTA_MARK])); - newmark = (ct->mark & mask) ^ mark; - if (newmark != ct->mark) - ct->mark = newmark; + newmark = (READ_ONCE(ct->mark) & mask) ^ mark; + if (newmark != READ_ONCE(ct->mark)) + WRITE_ONCE(ct->mark, newmark); } #endif @@ -2730,7 +2735,7 @@ static int __ctnetlink_glue_build(struct sk_buff *skb, struct nf_conn *ct) goto nla_put_failure; #ifdef CONFIG_NF_CONNTRACK_MARK - if (ct->mark && ctnetlink_dump_mark(skb, ct) < 0) + if (ctnetlink_dump_mark(skb, ct) < 0) goto nla_put_failure; #endif if (ctnetlink_dump_labels(skb, ct) < 0) diff --git a/net/netfilter/nf_conntrack_proto.c b/net/netfilter/nf_conntrack_proto.c index 895b09cbd7cf..ccef340be575 100644 --- a/net/netfilter/nf_conntrack_proto.c +++ b/net/netfilter/nf_conntrack_proto.c @@ -121,17 +121,64 @@ const struct nf_conntrack_l4proto *nf_ct_l4proto_find(u8 l4proto) }; EXPORT_SYMBOL_GPL(nf_ct_l4proto_find); -unsigned int nf_confirm(struct sk_buff *skb, unsigned int protoff, - struct nf_conn *ct, enum ip_conntrack_info ctinfo) +static bool in_vrf_postrouting(const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (state->hook == NF_INET_POST_ROUTING && + netif_is_l3_master(state->out)) + return true; +#endif + return false; +} + +unsigned int nf_confirm(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) { const struct nf_conn_help *help; + enum ip_conntrack_info ctinfo; + unsigned int protoff; + struct nf_conn *ct; + bool seqadj_needed; + __be16 frag_off; + int start; + u8 pnum; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || in_vrf_postrouting(state)) + return NF_ACCEPT; help = nfct_help(ct); + + seqadj_needed = test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && !nf_is_loopback_packet(skb); + if (!help && !seqadj_needed) + return nf_conntrack_confirm(skb); + + /* helper->help() do not expect ICMP packets */ + if (ctinfo == IP_CT_RELATED_REPLY) + return nf_conntrack_confirm(skb); + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + protoff = skb_network_offset(skb) + ip_hdrlen(skb); + break; + case NFPROTO_IPV6: + pnum = ipv6_hdr(skb)->nexthdr; + start = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, &frag_off); + if (start < 0 || (frag_off & htons(~0x7)) != 0) + return nf_conntrack_confirm(skb); + + protoff = start; + break; + default: + return nf_conntrack_confirm(skb); + } + if (help) { const struct nf_conntrack_helper *helper; int ret; - /* rcu_read_lock()ed by nf_hook_thresh */ + /* rcu_read_lock()ed by nf_hook */ helper = rcu_dereference(help->helper); if (helper) { ret = helper->help(skb, @@ -142,12 +189,10 @@ unsigned int nf_confirm(struct sk_buff *skb, unsigned int protoff, } } - if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && - !nf_is_loopback_packet(skb)) { - if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { - NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); - return NF_DROP; - } + if (seqadj_needed && + !nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { + NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); + return NF_DROP; } /* We've seen it coming out the other side: confirm it */ @@ -155,35 +200,6 @@ unsigned int nf_confirm(struct sk_buff *skb, unsigned int protoff, } EXPORT_SYMBOL_GPL(nf_confirm); -static bool in_vrf_postrouting(const struct nf_hook_state *state) -{ -#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) - if (state->hook == NF_INET_POST_ROUTING && - netif_is_l3_master(state->out)) - return true; -#endif - return false; -} - -static unsigned int ipv4_confirm(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) -{ - enum ip_conntrack_info ctinfo; - struct nf_conn *ct; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return nf_conntrack_confirm(skb); - - if (in_vrf_postrouting(state)) - return NF_ACCEPT; - - return nf_confirm(skb, - skb_network_offset(skb) + ip_hdrlen(skb), - ct, ctinfo); -} - static unsigned int ipv4_conntrack_in(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) @@ -230,13 +246,13 @@ static const struct nf_hook_ops ipv4_conntrack_ops[] = { .priority = NF_IP_PRI_CONNTRACK, }, { - .hook = ipv4_confirm, + .hook = nf_confirm, .pf = NFPROTO_IPV4, .hooknum = NF_INET_POST_ROUTING, .priority = NF_IP_PRI_CONNTRACK_CONFIRM, }, { - .hook = ipv4_confirm, + .hook = nf_confirm, .pf = NFPROTO_IPV4, .hooknum = NF_INET_LOCAL_IN, .priority = NF_IP_PRI_CONNTRACK_CONFIRM, @@ -373,33 +389,6 @@ static struct nf_sockopt_ops so_getorigdst6 = { .owner = THIS_MODULE, }; -static unsigned int ipv6_confirm(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - unsigned char pnum = ipv6_hdr(skb)->nexthdr; - __be16 frag_off; - int protoff; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return nf_conntrack_confirm(skb); - - if (in_vrf_postrouting(state)) - return NF_ACCEPT; - - protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, - &frag_off); - if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { - pr_debug("proto header not found\n"); - return nf_conntrack_confirm(skb); - } - - return nf_confirm(skb, protoff, ct, ctinfo); -} - static unsigned int ipv6_conntrack_in(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) @@ -428,13 +417,13 @@ static const struct nf_hook_ops ipv6_conntrack_ops[] = { .priority = NF_IP6_PRI_CONNTRACK, }, { - .hook = ipv6_confirm, + .hook = nf_confirm, .pf = NFPROTO_IPV6, .hooknum = NF_INET_POST_ROUTING, .priority = NF_IP6_PRI_LAST, }, { - .hook = ipv6_confirm, + .hook = nf_confirm, .pf = NFPROTO_IPV6, .hooknum = NF_INET_LOCAL_IN, .priority = NF_IP6_PRI_LAST - 1, diff --git a/net/netfilter/nf_conntrack_proto_icmpv6.c b/net/netfilter/nf_conntrack_proto_icmpv6.c index 61e3b05cf02c..1020d67600a9 100644 --- a/net/netfilter/nf_conntrack_proto_icmpv6.c +++ b/net/netfilter/nf_conntrack_proto_icmpv6.c @@ -129,6 +129,56 @@ static void icmpv6_error_log(const struct sk_buff *skb, nf_l4proto_log_invalid(skb, state, IPPROTO_ICMPV6, "%s", msg); } +static noinline_for_stack int +nf_conntrack_icmpv6_redirect(struct nf_conn *tmpl, struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + u8 hl = ipv6_hdr(skb)->hop_limit; + union nf_inet_addr outer_daddr; + union { + struct nd_opt_hdr nd_opt; + struct rd_msg rd_msg; + } tmp; + const struct nd_opt_hdr *nd_opt; + const struct rd_msg *rd_msg; + + rd_msg = skb_header_pointer(skb, dataoff, sizeof(*rd_msg), &tmp.rd_msg); + if (!rd_msg) { + icmpv6_error_log(skb, state, "short redirect"); + return -NF_ACCEPT; + } + + if (rd_msg->icmph.icmp6_code != 0) + return NF_ACCEPT; + + if (hl != 255 || !(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) { + icmpv6_error_log(skb, state, "invalid saddr or hoplimit for redirect"); + return -NF_ACCEPT; + } + + dataoff += sizeof(*rd_msg); + + /* warning: rd_msg no longer usable after this call */ + nd_opt = skb_header_pointer(skb, dataoff, sizeof(*nd_opt), &tmp.nd_opt); + if (!nd_opt || nd_opt->nd_opt_len == 0) { + icmpv6_error_log(skb, state, "redirect without options"); + return -NF_ACCEPT; + } + + /* We could call ndisc_parse_options(), but it would need + * skb_linearize() and a bit more work. + */ + if (nd_opt->nd_opt_type != ND_OPT_REDIRECT_HDR) + return NF_ACCEPT; + + memcpy(&outer_daddr.ip6, &ipv6_hdr(skb)->daddr, + sizeof(outer_daddr.ip6)); + dataoff += 8; + return nf_conntrack_inet_error(tmpl, skb, dataoff, state, + IPPROTO_ICMPV6, &outer_daddr); +} + int nf_conntrack_icmpv6_error(struct nf_conn *tmpl, struct sk_buff *skb, unsigned int dataoff, @@ -159,6 +209,9 @@ int nf_conntrack_icmpv6_error(struct nf_conn *tmpl, return NF_ACCEPT; } + if (icmp6h->icmp6_type == NDISC_REDIRECT) + return nf_conntrack_icmpv6_redirect(tmpl, skb, dataoff, state); + /* is not error message ? */ if (icmp6h->icmp6_type >= 128) return NF_ACCEPT; diff --git a/net/netfilter/nf_conntrack_proto_sctp.c b/net/netfilter/nf_conntrack_proto_sctp.c index 5a936334b517..011d414038ea 100644 --- a/net/netfilter/nf_conntrack_proto_sctp.c +++ b/net/netfilter/nf_conntrack_proto_sctp.c @@ -27,22 +27,16 @@ #include <net/netfilter/nf_conntrack_ecache.h> #include <net/netfilter/nf_conntrack_timeout.h> -/* FIXME: Examine ipfilter's timeouts and conntrack transitions more - closely. They're more complex. --RR - - And so for me for SCTP :D -Kiran */ - static const char *const sctp_conntrack_names[] = { - "NONE", - "CLOSED", - "COOKIE_WAIT", - "COOKIE_ECHOED", - "ESTABLISHED", - "SHUTDOWN_SENT", - "SHUTDOWN_RECD", - "SHUTDOWN_ACK_SENT", - "HEARTBEAT_SENT", - "HEARTBEAT_ACKED", + [SCTP_CONNTRACK_NONE] = "NONE", + [SCTP_CONNTRACK_CLOSED] = "CLOSED", + [SCTP_CONNTRACK_COOKIE_WAIT] = "COOKIE_WAIT", + [SCTP_CONNTRACK_COOKIE_ECHOED] = "COOKIE_ECHOED", + [SCTP_CONNTRACK_ESTABLISHED] = "ESTABLISHED", + [SCTP_CONNTRACK_SHUTDOWN_SENT] = "SHUTDOWN_SENT", + [SCTP_CONNTRACK_SHUTDOWN_RECD] = "SHUTDOWN_RECD", + [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT] = "SHUTDOWN_ACK_SENT", + [SCTP_CONNTRACK_HEARTBEAT_SENT] = "HEARTBEAT_SENT", }; #define SECS * HZ @@ -54,12 +48,11 @@ static const unsigned int sctp_timeouts[SCTP_CONNTRACK_MAX] = { [SCTP_CONNTRACK_CLOSED] = 10 SECS, [SCTP_CONNTRACK_COOKIE_WAIT] = 3 SECS, [SCTP_CONNTRACK_COOKIE_ECHOED] = 3 SECS, - [SCTP_CONNTRACK_ESTABLISHED] = 5 DAYS, + [SCTP_CONNTRACK_ESTABLISHED] = 210 SECS, [SCTP_CONNTRACK_SHUTDOWN_SENT] = 300 SECS / 1000, [SCTP_CONNTRACK_SHUTDOWN_RECD] = 300 SECS / 1000, [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT] = 3 SECS, [SCTP_CONNTRACK_HEARTBEAT_SENT] = 30 SECS, - [SCTP_CONNTRACK_HEARTBEAT_ACKED] = 210 SECS, }; #define SCTP_FLAG_HEARTBEAT_VTAG_FAILED 1 @@ -73,7 +66,6 @@ static const unsigned int sctp_timeouts[SCTP_CONNTRACK_MAX] = { #define sSR SCTP_CONNTRACK_SHUTDOWN_RECD #define sSA SCTP_CONNTRACK_SHUTDOWN_ACK_SENT #define sHS SCTP_CONNTRACK_HEARTBEAT_SENT -#define sHA SCTP_CONNTRACK_HEARTBEAT_ACKED #define sIV SCTP_CONNTRACK_MAX /* @@ -90,15 +82,12 @@ COOKIE WAIT - We have seen an INIT chunk in the original direction, or als COOKIE ECHOED - We have seen a COOKIE_ECHO chunk in the original direction. ESTABLISHED - We have seen a COOKIE_ACK in the reply direction. SHUTDOWN_SENT - We have seen a SHUTDOWN chunk in the original direction. -SHUTDOWN_RECD - We have seen a SHUTDOWN chunk in the reply directoin. +SHUTDOWN_RECD - We have seen a SHUTDOWN chunk in the reply direction. SHUTDOWN_ACK_SENT - We have seen a SHUTDOWN_ACK chunk in the direction opposite to that of the SHUTDOWN chunk. CLOSED - We have seen a SHUTDOWN_COMPLETE chunk in the direction of the SHUTDOWN chunk. Connection is closed. HEARTBEAT_SENT - We have seen a HEARTBEAT in a new flow. -HEARTBEAT_ACKED - We have seen a HEARTBEAT-ACK in the direction opposite to - that of the HEARTBEAT chunk. Secondary connection is - established. */ /* TODO @@ -115,33 +104,33 @@ cookie echoed to closed. static const u8 sctp_conntracks[2][11][SCTP_CONNTRACK_MAX] = { { /* ORIGINAL */ -/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS, sHA */ -/* init */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCW, sHA}, -/* init_ack */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL, sHA}, -/* abort */ {sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, -/* shutdown */ {sCL, sCL, sCW, sCE, sSS, sSS, sSR, sSA, sCL, sSS}, -/* shutdown_ack */ {sSA, sCL, sCW, sCE, sES, sSA, sSA, sSA, sSA, sHA}, -/* error */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL, sHA},/* Can't have Stale cookie*/ -/* cookie_echo */ {sCL, sCL, sCE, sCE, sES, sSS, sSR, sSA, sCL, sHA},/* 5.2.4 - Big TODO */ -/* cookie_ack */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL, sHA},/* Can't come in orig dir */ -/* shutdown_comp*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sCL, sCL, sHA}, -/* heartbeat */ {sHS, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS, sHA}, -/* heartbeat_ack*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS, sHA} +/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */ +/* init */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCW}, +/* init_ack */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL}, +/* abort */ {sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, +/* shutdown */ {sCL, sCL, sCW, sCE, sSS, sSS, sSR, sSA, sCL}, +/* shutdown_ack */ {sSA, sCL, sCW, sCE, sES, sSA, sSA, sSA, sSA}, +/* error */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL},/* Can't have Stale cookie*/ +/* cookie_echo */ {sCL, sCL, sCE, sCE, sES, sSS, sSR, sSA, sCL},/* 5.2.4 - Big TODO */ +/* cookie_ack */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL},/* Can't come in orig dir */ +/* shutdown_comp*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sCL, sCL}, +/* heartbeat */ {sHS, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, +/* heartbeat_ack*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, }, { /* REPLY */ -/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS, sHA */ -/* init */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV, sHA},/* INIT in sCL Big TODO */ -/* init_ack */ {sIV, sCW, sCW, sCE, sES, sSS, sSR, sSA, sIV, sHA}, -/* abort */ {sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sIV, sCL}, -/* shutdown */ {sIV, sCL, sCW, sCE, sSR, sSS, sSR, sSA, sIV, sSR}, -/* shutdown_ack */ {sIV, sCL, sCW, sCE, sES, sSA, sSA, sSA, sIV, sHA}, -/* error */ {sIV, sCL, sCW, sCL, sES, sSS, sSR, sSA, sIV, sHA}, -/* cookie_echo */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV, sHA},/* Can't come in reply dir */ -/* cookie_ack */ {sIV, sCL, sCW, sES, sES, sSS, sSR, sSA, sIV, sHA}, -/* shutdown_comp*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sCL, sIV, sHA}, -/* heartbeat */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS, sHA}, -/* heartbeat_ack*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHA, sHA} +/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */ +/* init */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV},/* INIT in sCL Big TODO */ +/* init_ack */ {sIV, sCW, sCW, sCE, sES, sSS, sSR, sSA, sIV}, +/* abort */ {sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sIV}, +/* shutdown */ {sIV, sCL, sCW, sCE, sSR, sSS, sSR, sSA, sIV}, +/* shutdown_ack */ {sIV, sCL, sCW, sCE, sES, sSA, sSA, sSA, sIV}, +/* error */ {sIV, sCL, sCW, sCL, sES, sSS, sSR, sSA, sIV}, +/* cookie_echo */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV},/* Can't come in reply dir */ +/* cookie_ack */ {sIV, sCL, sCW, sES, sES, sSS, sSR, sSA, sIV}, +/* shutdown_comp*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sCL, sIV}, +/* heartbeat */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, +/* heartbeat_ack*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sES}, } }; @@ -153,6 +142,7 @@ static void sctp_print_conntrack(struct seq_file *s, struct nf_conn *ct) } #endif +/* do_basic_checks ensures sch->length > 0, do not use before */ #define for_each_sctp_chunk(skb, sch, _sch, offset, dataoff, count) \ for ((offset) = (dataoff) + sizeof(struct sctphdr), (count) = 0; \ (offset) < (skb)->len && \ @@ -412,22 +402,29 @@ int nf_conntrack_sctp_packet(struct nf_conn *ct, for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) { /* Special cases of Verification tag check (Sec 8.5.1) */ if (sch->type == SCTP_CID_INIT) { - /* Sec 8.5.1 (A) */ + /* (A) vtag MUST be zero */ if (sh->vtag != 0) goto out_unlock; } else if (sch->type == SCTP_CID_ABORT) { - /* Sec 8.5.1 (B) */ - if (sh->vtag != ct->proto.sctp.vtag[dir] && - sh->vtag != ct->proto.sctp.vtag[!dir]) + /* (B) vtag MUST match own vtag if T flag is unset OR + * MUST match peer's vtag if T flag is set + */ + if ((!(sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[dir]) || + ((sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[!dir])) goto out_unlock; } else if (sch->type == SCTP_CID_SHUTDOWN_COMPLETE) { - /* Sec 8.5.1 (C) */ - if (sh->vtag != ct->proto.sctp.vtag[dir] && - sh->vtag != ct->proto.sctp.vtag[!dir] && - sch->flags & SCTP_CHUNK_FLAG_T) + /* (C) vtag MUST match own vtag if T flag is unset OR + * MUST match peer's vtag if T flag is set + */ + if ((!(sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[dir]) || + ((sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[!dir])) goto out_unlock; } else if (sch->type == SCTP_CID_COOKIE_ECHO) { - /* Sec 8.5.1 (D) */ + /* (D) vtag must be same as init_vtag as found in INIT_ACK */ if (sh->vtag != ct->proto.sctp.vtag[dir]) goto out_unlock; } else if (sch->type == SCTP_CID_HEARTBEAT) { @@ -501,8 +498,12 @@ int nf_conntrack_sctp_packet(struct nf_conn *ct, } ct->proto.sctp.state = new_state; - if (old_state != new_state) + if (old_state != new_state) { nf_conntrack_event_cache(IPCT_PROTOINFO, ct); + if (new_state == SCTP_CONNTRACK_ESTABLISHED && + !test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } } spin_unlock_bh(&ct->lock); @@ -516,14 +517,6 @@ int nf_conntrack_sctp_packet(struct nf_conn *ct, nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[new_state]); - if (old_state == SCTP_CONNTRACK_COOKIE_ECHOED && - dir == IP_CT_DIR_REPLY && - new_state == SCTP_CONNTRACK_ESTABLISHED) { - pr_debug("Setting assured bit\n"); - set_bit(IPS_ASSURED_BIT, &ct->status); - nf_conntrack_event_cache(IPCT_ASSURED, ct); - } - return NF_ACCEPT; out_unlock: diff --git a/net/netfilter/nf_conntrack_proto_tcp.c b/net/netfilter/nf_conntrack_proto_tcp.c index 656631083177..3ac1af6f59fc 100644 --- a/net/netfilter/nf_conntrack_proto_tcp.c +++ b/net/netfilter/nf_conntrack_proto_tcp.c @@ -1068,6 +1068,13 @@ int nf_conntrack_tcp_packet(struct nf_conn *ct, ct->proto.tcp.last_flags |= IP_CT_EXP_CHALLENGE_ACK; } + + /* possible challenge ack reply to syn */ + if (old_state == TCP_CONNTRACK_SYN_SENT && + index == TCP_ACK_SET && + dir == IP_CT_DIR_REPLY) + ct->proto.tcp.last_ack = ntohl(th->ack_seq); + spin_unlock_bh(&ct->lock); nf_ct_l4proto_log_invalid(skb, ct, state, "packet (index %d) in dir %d ignored, state %s", @@ -1193,6 +1200,14 @@ int nf_conntrack_tcp_packet(struct nf_conn *ct, * segments we ignored. */ goto in_window; } + + /* Reset in response to a challenge-ack we let through earlier */ + if (old_state == TCP_CONNTRACK_SYN_SENT && + ct->proto.tcp.last_index == TCP_ACK_SET && + ct->proto.tcp.last_dir == IP_CT_DIR_REPLY && + ntohl(th->seq) == ct->proto.tcp.last_ack) + goto in_window; + break; default: /* Keep compilers happy. */ diff --git a/net/netfilter/nf_conntrack_standalone.c b/net/netfilter/nf_conntrack_standalone.c index 4ffe84c5a82c..460294bd4b60 100644 --- a/net/netfilter/nf_conntrack_standalone.c +++ b/net/netfilter/nf_conntrack_standalone.c @@ -366,7 +366,7 @@ static int ct_seq_show(struct seq_file *s, void *v) goto release; #if defined(CONFIG_NF_CONNTRACK_MARK) - seq_printf(s, "mark=%u ", ct->mark); + seq_printf(s, "mark=%u ", READ_ONCE(ct->mark)); #endif ct_show_secctx(s, ct); @@ -601,7 +601,6 @@ enum nf_ct_sysctl_index { NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_RECD, NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT, NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_SENT, - NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_ACKED, #endif #ifdef CONFIG_NF_CT_PROTO_DCCP NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST, @@ -886,12 +885,6 @@ static struct ctl_table nf_ct_sysctl_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_ACKED] = { - .procname = "nf_conntrack_sctp_timeout_heartbeat_acked", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, #endif #ifdef CONFIG_NF_CT_PROTO_DCCP [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST] = { @@ -1035,7 +1028,6 @@ static void nf_conntrack_standalone_init_sctp_sysctl(struct net *net, XASSIGN(SHUTDOWN_RECD, sn); XASSIGN(SHUTDOWN_ACK_SENT, sn); XASSIGN(HEARTBEAT_SENT, sn); - XASSIGN(HEARTBEAT_ACKED, sn); #undef XASSIGN #endif } diff --git a/net/netfilter/nf_flow_table_ip.c b/net/netfilter/nf_flow_table_ip.c index b350fe9d00b0..19efba1e51ef 100644 --- a/net/netfilter/nf_flow_table_ip.c +++ b/net/netfilter/nf_flow_table_ip.c @@ -421,6 +421,10 @@ nf_flow_offload_ip_hook(void *priv, struct sk_buff *skb, if (ret == NF_DROP) flow_offload_teardown(flow); break; + default: + WARN_ON_ONCE(1); + ret = NF_DROP; + break; } return ret; @@ -682,6 +686,10 @@ nf_flow_offload_ipv6_hook(void *priv, struct sk_buff *skb, if (ret == NF_DROP) flow_offload_teardown(flow); break; + default: + WARN_ON_ONCE(1); + ret = NF_DROP; + break; } return ret; diff --git a/net/netfilter/nf_flow_table_offload.c b/net/netfilter/nf_flow_table_offload.c index b04645ced89b..4d9b99abe37d 100644 --- a/net/netfilter/nf_flow_table_offload.c +++ b/net/netfilter/nf_flow_table_offload.c @@ -383,12 +383,12 @@ static void flow_offload_ipv6_mangle(struct nf_flow_rule *flow_rule, const __be32 *addr, const __be32 *mask) { struct flow_action_entry *entry; - int i, j; + int i; - for (i = 0, j = 0; i < sizeof(struct in6_addr) / sizeof(u32); i += sizeof(u32), j++) { + for (i = 0; i < sizeof(struct in6_addr) / sizeof(u32); i++) { entry = flow_action_entry_next(flow_rule); flow_offload_mangle(entry, FLOW_ACT_MANGLE_HDR_TYPE_IP6, - offset + i, &addr[j], mask); + offset + i * sizeof(u32), &addr[i], mask); } } @@ -997,13 +997,13 @@ static void flow_offload_queue_work(struct flow_offload_work *offload) struct net *net = read_pnet(&offload->flowtable->net); if (offload->cmd == FLOW_CLS_REPLACE) { - NF_FLOW_TABLE_STAT_INC(net, count_wq_add); + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_add); queue_work(nf_flow_offload_add_wq, &offload->work); } else if (offload->cmd == FLOW_CLS_DESTROY) { - NF_FLOW_TABLE_STAT_INC(net, count_wq_del); + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_del); queue_work(nf_flow_offload_del_wq, &offload->work); } else { - NF_FLOW_TABLE_STAT_INC(net, count_wq_stats); + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_stats); queue_work(nf_flow_offload_stats_wq, &offload->work); } } @@ -1098,6 +1098,7 @@ static int nf_flow_table_block_setup(struct nf_flowtable *flowtable, struct flow_block_cb *block_cb, *next; int err = 0; + down_write(&flowtable->flow_block_lock); switch (cmd) { case FLOW_BLOCK_BIND: list_splice(&bo->cb_list, &flowtable->flow_block.cb_list); @@ -1112,6 +1113,7 @@ static int nf_flow_table_block_setup(struct nf_flowtable *flowtable, WARN_ON_ONCE(1); err = -EOPNOTSUPP; } + up_write(&flowtable->flow_block_lock); return err; } @@ -1168,7 +1170,9 @@ static int nf_flow_table_offload_cmd(struct flow_block_offload *bo, nf_flow_table_block_offload_init(bo, dev_net(dev), cmd, flowtable, extack); + down_write(&flowtable->flow_block_lock); err = dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_FT, bo); + up_write(&flowtable->flow_block_lock); if (err < 0) return err; diff --git a/net/netfilter/nf_nat_helper.c b/net/netfilter/nf_nat_helper.c index a95a25196943..bf591e6af005 100644 --- a/net/netfilter/nf_nat_helper.c +++ b/net/netfilter/nf_nat_helper.c @@ -223,7 +223,7 @@ u16 nf_nat_exp_find_port(struct nf_conntrack_expect *exp, u16 port) if (res != -EBUSY || (--attempts_left < 0)) break; - port = min + prandom_u32_max(range); + port = min + get_random_u32_below(range); } return 0; diff --git a/net/netfilter/nf_nat_ovs.c b/net/netfilter/nf_nat_ovs.c new file mode 100644 index 000000000000..551abd2da614 --- /dev/null +++ b/net/netfilter/nf_nat_ovs.c @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Support nat functions for openvswitch and used by OVS and TC conntrack. */ + +#include <net/netfilter/nf_nat.h> + +/* Modelled after nf_nat_ipv[46]_fn(). + * range is only used for new, uninitialized NAT state. + * Returns either NF_ACCEPT or NF_DROP. + */ +static int nf_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, int *action, + const struct nf_nat_range2 *range, + enum nf_nat_manip_type maniptype) +{ + __be16 proto = skb_protocol(skb, true); + int hooknum, err = NF_ACCEPT; + + /* See HOOK2MANIP(). */ + if (maniptype == NF_NAT_MANIP_SRC) + hooknum = NF_INET_LOCAL_IN; /* Source NAT */ + else + hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */ + + switch (ctinfo) { + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + if (proto == htons(ETH_P_IP) && + ip_hdr(skb)->protocol == IPPROTO_ICMP) { + if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, + hooknum)) + err = NF_DROP; + goto out; + } else if (IS_ENABLED(CONFIG_IPV6) && proto == htons(ETH_P_IPV6)) { + __be16 frag_off; + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + int hdrlen = ipv6_skip_exthdr(skb, + sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + + if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { + if (!nf_nat_icmpv6_reply_translation(skb, ct, + ctinfo, + hooknum, + hdrlen)) + err = NF_DROP; + goto out; + } + } + /* Non-ICMP, fall thru to initialize if needed. */ + fallthrough; + case IP_CT_NEW: + /* Seen it before? This can happen for loopback, retrans, + * or local packets. + */ + if (!nf_nat_initialized(ct, maniptype)) { + /* Initialize according to the NAT action. */ + err = (range && range->flags & NF_NAT_RANGE_MAP_IPS) + /* Action is set up to establish a new + * mapping. + */ + ? nf_nat_setup_info(ct, range, maniptype) + : nf_nat_alloc_null_binding(ct, hooknum); + if (err != NF_ACCEPT) + goto out; + } + break; + + case IP_CT_ESTABLISHED: + case IP_CT_ESTABLISHED_REPLY: + break; + + default: + err = NF_DROP; + goto out; + } + + err = nf_nat_packet(ct, ctinfo, hooknum, skb); + if (err == NF_ACCEPT) + *action |= BIT(maniptype); +out: + return err; +} + +int nf_ct_nat(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, int *action, + const struct nf_nat_range2 *range, bool commit) +{ + enum nf_nat_manip_type maniptype; + int err, ct_action = *action; + + *action = 0; + + /* Add NAT extension if not confirmed yet. */ + if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct)) + return NF_DROP; /* Can't NAT. */ + + if (ctinfo != IP_CT_NEW && (ct->status & IPS_NAT_MASK) && + (ctinfo != IP_CT_RELATED || commit)) { + /* NAT an established or related connection like before. */ + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) + /* This is the REPLY direction for a connection + * for which NAT was applied in the forward + * direction. Do the reverse NAT. + */ + maniptype = ct->status & IPS_SRC_NAT + ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC; + else + maniptype = ct->status & IPS_SRC_NAT + ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST; + } else if (ct_action & BIT(NF_NAT_MANIP_SRC)) { + maniptype = NF_NAT_MANIP_SRC; + } else if (ct_action & BIT(NF_NAT_MANIP_DST)) { + maniptype = NF_NAT_MANIP_DST; + } else { + return NF_ACCEPT; + } + + err = nf_ct_nat_execute(skb, ct, ctinfo, action, range, maniptype); + if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) { + if (ct->status & IPS_SRC_NAT) { + if (maniptype == NF_NAT_MANIP_SRC) + maniptype = NF_NAT_MANIP_DST; + else + maniptype = NF_NAT_MANIP_SRC; + + err = nf_ct_nat_execute(skb, ct, ctinfo, action, range, + maniptype); + } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { + err = nf_ct_nat_execute(skb, ct, ctinfo, action, NULL, + NF_NAT_MANIP_SRC); + } + } + return err; +} +EXPORT_SYMBOL_GPL(nf_ct_nat); diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index e7152d599d73..8c09e4d12ac1 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -465,8 +465,9 @@ static int nft_delrule_by_chain(struct nft_ctx *ctx) return 0; } -static int nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, - struct nft_set *set) +static int __nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, + struct nft_set *set, + const struct nft_set_desc *desc) { struct nft_trans *trans; @@ -474,17 +475,28 @@ static int nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, if (trans == NULL) return -ENOMEM; - if (msg_type == NFT_MSG_NEWSET && ctx->nla[NFTA_SET_ID] != NULL) { + if (msg_type == NFT_MSG_NEWSET && ctx->nla[NFTA_SET_ID] && !desc) { nft_trans_set_id(trans) = ntohl(nla_get_be32(ctx->nla[NFTA_SET_ID])); nft_activate_next(ctx->net, set); } nft_trans_set(trans) = set; + if (desc) { + nft_trans_set_update(trans) = true; + nft_trans_set_gc_int(trans) = desc->gc_int; + nft_trans_set_timeout(trans) = desc->timeout; + } nft_trans_commit_list_add_tail(ctx->net, trans); return 0; } +static int nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, + struct nft_set *set) +{ + return __nft_trans_set_add(ctx, msg_type, set, NULL); +} + static int nft_delset(const struct nft_ctx *ctx, struct nft_set *set) { int err; @@ -1534,10 +1546,10 @@ static int nft_dump_stats(struct sk_buff *skb, struct nft_stats __percpu *stats) for_each_possible_cpu(cpu) { cpu_stats = per_cpu_ptr(stats, cpu); do { - seq = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + seq = u64_stats_fetch_begin(&cpu_stats->syncp); pkts = cpu_stats->pkts; bytes = cpu_stats->bytes; - } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, seq)); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, seq)); total.pkts += pkts; total.bytes += bytes; } @@ -2759,7 +2771,7 @@ static const struct nla_policy nft_expr_policy[NFTA_EXPR_MAX + 1] = { }; static int nf_tables_fill_expr_info(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { if (nla_put_string(skb, NFTA_EXPR_NAME, expr->ops->type->name)) goto nla_put_failure; @@ -2769,7 +2781,7 @@ static int nf_tables_fill_expr_info(struct sk_buff *skb, NFTA_EXPR_DATA); if (data == NULL) goto nla_put_failure; - if (expr->ops->dump(skb, expr) < 0) + if (expr->ops->dump(skb, expr, reset) < 0) goto nla_put_failure; nla_nest_end(skb, data); } @@ -2781,14 +2793,14 @@ nla_put_failure: }; int nft_expr_dump(struct sk_buff *skb, unsigned int attr, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { struct nlattr *nest; nest = nla_nest_start_noflag(skb, attr); if (!nest) goto nla_put_failure; - if (nf_tables_fill_expr_info(skb, expr) < 0) + if (nf_tables_fill_expr_info(skb, expr, reset) < 0) goto nla_put_failure; nla_nest_end(skb, nest); return 0; @@ -2857,6 +2869,43 @@ err1: return err; } +int nft_expr_inner_parse(const struct nft_ctx *ctx, const struct nlattr *nla, + struct nft_expr_info *info) +{ + struct nlattr *tb[NFTA_EXPR_MAX + 1]; + const struct nft_expr_type *type; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_EXPR_MAX, nla, + nft_expr_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_EXPR_DATA]) + return -EINVAL; + + type = __nft_expr_type_get(ctx->family, tb[NFTA_EXPR_NAME]); + if (!type) + return -ENOENT; + + if (!type->inner_ops) + return -EOPNOTSUPP; + + err = nla_parse_nested_deprecated(info->tb, type->maxattr, + tb[NFTA_EXPR_DATA], + type->policy, NULL); + if (err < 0) + goto err_nla_parse; + + info->attr = nla; + info->ops = type->inner_ops; + + return 0; + +err_nla_parse: + return err; +} + static int nf_tables_newexpr(const struct nft_ctx *ctx, const struct nft_expr_info *expr_info, struct nft_expr *expr) @@ -2997,7 +3046,8 @@ static int nf_tables_fill_rule_info(struct sk_buff *skb, struct net *net, u32 flags, int family, const struct nft_table *table, const struct nft_chain *chain, - const struct nft_rule *rule, u64 handle) + const struct nft_rule *rule, u64 handle, + bool reset) { struct nlmsghdr *nlh; const struct nft_expr *expr, *next; @@ -3030,7 +3080,7 @@ static int nf_tables_fill_rule_info(struct sk_buff *skb, struct net *net, if (list == NULL) goto nla_put_failure; nft_rule_for_each_expr(expr, next, rule) { - if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr) < 0) + if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr, reset) < 0) goto nla_put_failure; } nla_nest_end(skb, list); @@ -3081,7 +3131,7 @@ static void nf_tables_rule_notify(const struct nft_ctx *ctx, err = nf_tables_fill_rule_info(skb, ctx->net, ctx->portid, ctx->seq, event, flags, ctx->family, ctx->table, - ctx->chain, rule, handle); + ctx->chain, rule, handle, false); if (err < 0) { kfree_skb(skb); goto err; @@ -3102,7 +3152,8 @@ static int __nf_tables_dump_rules(struct sk_buff *skb, unsigned int *idx, struct netlink_callback *cb, const struct nft_table *table, - const struct nft_chain *chain) + const struct nft_chain *chain, + bool reset) { struct net *net = sock_net(skb->sk); const struct nft_rule *rule, *prule; @@ -3129,7 +3180,7 @@ static int __nf_tables_dump_rules(struct sk_buff *skb, NFT_MSG_NEWRULE, NLM_F_MULTI | NLM_F_APPEND, table->family, - table, chain, rule, handle) < 0) + table, chain, rule, handle, reset) < 0) return 1; nl_dump_check_consistent(cb, nlmsg_hdr(skb)); @@ -3152,6 +3203,10 @@ static int nf_tables_dump_rules(struct sk_buff *skb, struct net *net = sock_net(skb->sk); int family = nfmsg->nfgen_family; struct nftables_pernet *nft_net; + bool reset = false; + + if (NFNL_MSG_TYPE(cb->nlh->nlmsg_type) == NFT_MSG_GETRULE_RESET) + reset = true; rcu_read_lock(); nft_net = nft_pernet(net); @@ -3176,14 +3231,15 @@ static int nf_tables_dump_rules(struct sk_buff *skb, if (!nft_is_active(net, chain)) continue; __nf_tables_dump_rules(skb, &idx, - cb, table, chain); + cb, table, chain, reset); break; } goto done; } list_for_each_entry_rcu(chain, &table->chains, list) { - if (__nf_tables_dump_rules(skb, &idx, cb, table, chain)) + if (__nf_tables_dump_rules(skb, &idx, + cb, table, chain, reset)) goto done; } @@ -3254,6 +3310,7 @@ static int nf_tables_getrule(struct sk_buff *skb, const struct nfnl_info *info, struct net *net = info->net; struct nft_table *table; struct sk_buff *skb2; + bool reset = false; int err; if (info->nlh->nlmsg_flags & NLM_F_DUMP) { @@ -3290,9 +3347,12 @@ static int nf_tables_getrule(struct sk_buff *skb, const struct nfnl_info *info, if (!skb2) return -ENOMEM; + if (NFNL_MSG_TYPE(info->nlh->nlmsg_type) == NFT_MSG_GETRULE_RESET) + reset = true; + err = nf_tables_fill_rule_info(skb2, net, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, NFT_MSG_NEWRULE, 0, - family, table, chain, rule, 0); + family, table, chain, rule, 0, reset); if (err < 0) goto err_fill_rule_info; @@ -3732,8 +3792,7 @@ static bool nft_set_ops_candidate(const struct nft_set_type *type, u32 flags) static const struct nft_set_ops * nft_select_set_ops(const struct nft_ctx *ctx, const struct nlattr * const nla[], - const struct nft_set_desc *desc, - enum nft_set_policies policy) + const struct nft_set_desc *desc) { struct nftables_pernet *nft_net = nft_pernet(ctx->net); const struct nft_set_ops *ops, *bops; @@ -3762,7 +3821,7 @@ nft_select_set_ops(const struct nft_ctx *ctx, if (!ops->estimate(desc, flags, &est)) continue; - switch (policy) { + switch (desc->policy) { case NFT_SET_POL_PERFORMANCE: if (est.lookup < best.lookup) break; @@ -3997,8 +4056,10 @@ static int nf_tables_fill_set_concat(struct sk_buff *skb, static int nf_tables_fill_set(struct sk_buff *skb, const struct nft_ctx *ctx, const struct nft_set *set, u16 event, u16 flags) { - struct nlmsghdr *nlh; + u64 timeout = READ_ONCE(set->timeout); + u32 gc_int = READ_ONCE(set->gc_int); u32 portid = ctx->portid; + struct nlmsghdr *nlh; struct nlattr *nest; u32 seq = ctx->seq; int i; @@ -4034,13 +4095,13 @@ static int nf_tables_fill_set(struct sk_buff *skb, const struct nft_ctx *ctx, nla_put_be32(skb, NFTA_SET_OBJ_TYPE, htonl(set->objtype))) goto nla_put_failure; - if (set->timeout && + if (timeout && nla_put_be64(skb, NFTA_SET_TIMEOUT, - nf_jiffies64_to_msecs(set->timeout), + nf_jiffies64_to_msecs(timeout), NFTA_SET_PAD)) goto nla_put_failure; - if (set->gc_int && - nla_put_be32(skb, NFTA_SET_GC_INTERVAL, htonl(set->gc_int))) + if (gc_int && + nla_put_be32(skb, NFTA_SET_GC_INTERVAL, htonl(gc_int))) goto nla_put_failure; if (set->policy != NFT_SET_POL_PERFORMANCE) { @@ -4067,7 +4128,7 @@ static int nf_tables_fill_set(struct sk_buff *skb, const struct nft_ctx *ctx, if (set->num_exprs == 1) { nest = nla_nest_start_noflag(skb, NFTA_SET_EXPR); - if (nf_tables_fill_expr_info(skb, set->exprs[0]) < 0) + if (nf_tables_fill_expr_info(skb, set->exprs[0], false) < 0) goto nla_put_failure; nla_nest_end(skb, nest); @@ -4078,7 +4139,7 @@ static int nf_tables_fill_set(struct sk_buff *skb, const struct nft_ctx *ctx, for (i = 0; i < set->num_exprs; i++) { if (nft_expr_dump(skb, NFTA_LIST_ELEM, - set->exprs[i]) < 0) + set->exprs[i], false) < 0) goto nla_put_failure; } nla_nest_end(skb, nest); @@ -4341,15 +4402,94 @@ static int nf_tables_set_desc_parse(struct nft_set_desc *desc, return err; } +static int nft_set_expr_alloc(struct nft_ctx *ctx, struct nft_set *set, + const struct nlattr * const *nla, + struct nft_expr **exprs, int *num_exprs, + u32 flags) +{ + struct nft_expr *expr; + int err, i; + + if (nla[NFTA_SET_EXPR]) { + expr = nft_set_elem_expr_alloc(ctx, set, nla[NFTA_SET_EXPR]); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_expr_alloc; + } + exprs[0] = expr; + (*num_exprs)++; + } else if (nla[NFTA_SET_EXPRESSIONS]) { + struct nlattr *tmp; + int left; + + if (!(flags & NFT_SET_EXPR)) { + err = -EINVAL; + goto err_set_expr_alloc; + } + i = 0; + nla_for_each_nested(tmp, nla[NFTA_SET_EXPRESSIONS], left) { + if (i == NFT_SET_EXPR_MAX) { + err = -E2BIG; + goto err_set_expr_alloc; + } + if (nla_type(tmp) != NFTA_LIST_ELEM) { + err = -EINVAL; + goto err_set_expr_alloc; + } + expr = nft_set_elem_expr_alloc(ctx, set, tmp); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_expr_alloc; + } + exprs[i++] = expr; + (*num_exprs)++; + } + } + + return 0; + +err_set_expr_alloc: + for (i = 0; i < *num_exprs; i++) + nft_expr_destroy(ctx, exprs[i]); + + return err; +} + +static bool nft_set_is_same(const struct nft_set *set, + const struct nft_set_desc *desc, + struct nft_expr *exprs[], u32 num_exprs, u32 flags) +{ + int i; + + if (set->ktype != desc->ktype || + set->dtype != desc->dtype || + set->flags != flags || + set->klen != desc->klen || + set->dlen != desc->dlen || + set->field_count != desc->field_count || + set->num_exprs != num_exprs) + return false; + + for (i = 0; i < desc->field_count; i++) { + if (set->field_len[i] != desc->field_len[i]) + return false; + } + + for (i = 0; i < num_exprs; i++) { + if (set->exprs[i]->ops != exprs[i]->ops) + return false; + } + + return true; +} + static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, const struct nlattr * const nla[]) { - u32 ktype, dtype, flags, policy, gc_int, objtype; struct netlink_ext_ack *extack = info->extack; u8 genmask = nft_genmask_next(info->net); u8 family = info->nfmsg->nfgen_family; const struct nft_set_ops *ops; - struct nft_expr *expr = NULL; struct net *net = info->net; struct nft_set_desc desc; struct nft_table *table; @@ -4357,10 +4497,11 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, struct nft_set *set; struct nft_ctx ctx; size_t alloc_size; - u64 timeout; + int num_exprs = 0; char *name; int err, i; u16 udlen; + u32 flags; u64 size; if (nla[NFTA_SET_TABLE] == NULL || @@ -4371,10 +4512,10 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, memset(&desc, 0, sizeof(desc)); - ktype = NFT_DATA_VALUE; + desc.ktype = NFT_DATA_VALUE; if (nla[NFTA_SET_KEY_TYPE] != NULL) { - ktype = ntohl(nla_get_be32(nla[NFTA_SET_KEY_TYPE])); - if ((ktype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK) + desc.ktype = ntohl(nla_get_be32(nla[NFTA_SET_KEY_TYPE])); + if ((desc.ktype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK) return -EINVAL; } @@ -4399,17 +4540,17 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, return -EOPNOTSUPP; } - dtype = 0; + desc.dtype = 0; if (nla[NFTA_SET_DATA_TYPE] != NULL) { if (!(flags & NFT_SET_MAP)) return -EINVAL; - dtype = ntohl(nla_get_be32(nla[NFTA_SET_DATA_TYPE])); - if ((dtype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK && - dtype != NFT_DATA_VERDICT) + desc.dtype = ntohl(nla_get_be32(nla[NFTA_SET_DATA_TYPE])); + if ((desc.dtype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK && + desc.dtype != NFT_DATA_VERDICT) return -EINVAL; - if (dtype != NFT_DATA_VERDICT) { + if (desc.dtype != NFT_DATA_VERDICT) { if (nla[NFTA_SET_DATA_LEN] == NULL) return -EINVAL; desc.dlen = ntohl(nla_get_be32(nla[NFTA_SET_DATA_LEN])); @@ -4424,34 +4565,34 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, if (!(flags & NFT_SET_OBJECT)) return -EINVAL; - objtype = ntohl(nla_get_be32(nla[NFTA_SET_OBJ_TYPE])); - if (objtype == NFT_OBJECT_UNSPEC || - objtype > NFT_OBJECT_MAX) + desc.objtype = ntohl(nla_get_be32(nla[NFTA_SET_OBJ_TYPE])); + if (desc.objtype == NFT_OBJECT_UNSPEC || + desc.objtype > NFT_OBJECT_MAX) return -EOPNOTSUPP; } else if (flags & NFT_SET_OBJECT) return -EINVAL; else - objtype = NFT_OBJECT_UNSPEC; + desc.objtype = NFT_OBJECT_UNSPEC; - timeout = 0; + desc.timeout = 0; if (nla[NFTA_SET_TIMEOUT] != NULL) { if (!(flags & NFT_SET_TIMEOUT)) return -EINVAL; - err = nf_msecs_to_jiffies64(nla[NFTA_SET_TIMEOUT], &timeout); + err = nf_msecs_to_jiffies64(nla[NFTA_SET_TIMEOUT], &desc.timeout); if (err) return err; } - gc_int = 0; + desc.gc_int = 0; if (nla[NFTA_SET_GC_INTERVAL] != NULL) { if (!(flags & NFT_SET_TIMEOUT)) return -EINVAL; - gc_int = ntohl(nla_get_be32(nla[NFTA_SET_GC_INTERVAL])); + desc.gc_int = ntohl(nla_get_be32(nla[NFTA_SET_GC_INTERVAL])); } - policy = NFT_SET_POL_PERFORMANCE; + desc.policy = NFT_SET_POL_PERFORMANCE; if (nla[NFTA_SET_POLICY] != NULL) - policy = ntohl(nla_get_be32(nla[NFTA_SET_POLICY])); + desc.policy = ntohl(nla_get_be32(nla[NFTA_SET_POLICY])); if (nla[NFTA_SET_DESC] != NULL) { err = nf_tables_set_desc_parse(&desc, nla[NFTA_SET_DESC]); @@ -4483,6 +4624,8 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, return PTR_ERR(set); } } else { + struct nft_expr *exprs[NFT_SET_EXPR_MAX] = {}; + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { NL_SET_BAD_ATTR(extack, nla[NFTA_SET_NAME]); return -EEXIST; @@ -4490,13 +4633,29 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, if (info->nlh->nlmsg_flags & NLM_F_REPLACE) return -EOPNOTSUPP; - return 0; + err = nft_set_expr_alloc(&ctx, set, nla, exprs, &num_exprs, flags); + if (err < 0) + return err; + + err = 0; + if (!nft_set_is_same(set, &desc, exprs, num_exprs, flags)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_NAME]); + err = -EEXIST; + } + + for (i = 0; i < num_exprs; i++) + nft_expr_destroy(&ctx, exprs[i]); + + if (err < 0) + return err; + + return __nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set, &desc); } if (!(info->nlh->nlmsg_flags & NLM_F_CREATE)) return -ENOENT; - ops = nft_select_set_ops(&ctx, nla, &desc, policy); + ops = nft_select_set_ops(&ctx, nla, &desc); if (IS_ERR(ops)) return PTR_ERR(ops); @@ -4536,18 +4695,18 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, set->table = table; write_pnet(&set->net, net); set->ops = ops; - set->ktype = ktype; + set->ktype = desc.ktype; set->klen = desc.klen; - set->dtype = dtype; - set->objtype = objtype; + set->dtype = desc.dtype; + set->objtype = desc.objtype; set->dlen = desc.dlen; set->flags = flags; set->size = desc.size; - set->policy = policy; + set->policy = desc.policy; set->udlen = udlen; set->udata = udata; - set->timeout = timeout; - set->gc_int = gc_int; + set->timeout = desc.timeout; + set->gc_int = desc.gc_int; set->field_count = desc.field_count; for (i = 0; i < desc.field_count; i++) @@ -4557,43 +4716,11 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, if (err < 0) goto err_set_init; - if (nla[NFTA_SET_EXPR]) { - expr = nft_set_elem_expr_alloc(&ctx, set, nla[NFTA_SET_EXPR]); - if (IS_ERR(expr)) { - err = PTR_ERR(expr); - goto err_set_expr_alloc; - } - set->exprs[0] = expr; - set->num_exprs++; - } else if (nla[NFTA_SET_EXPRESSIONS]) { - struct nft_expr *expr; - struct nlattr *tmp; - int left; - - if (!(flags & NFT_SET_EXPR)) { - err = -EINVAL; - goto err_set_expr_alloc; - } - i = 0; - nla_for_each_nested(tmp, nla[NFTA_SET_EXPRESSIONS], left) { - if (i == NFT_SET_EXPR_MAX) { - err = -E2BIG; - goto err_set_expr_alloc; - } - if (nla_type(tmp) != NFTA_LIST_ELEM) { - err = -EINVAL; - goto err_set_expr_alloc; - } - expr = nft_set_elem_expr_alloc(&ctx, set, tmp); - if (IS_ERR(expr)) { - err = PTR_ERR(expr); - goto err_set_expr_alloc; - } - set->exprs[i++] = expr; - set->num_exprs++; - } - } + err = nft_set_expr_alloc(&ctx, set, nla, set->exprs, &num_exprs, flags); + if (err < 0) + goto err_set_destroy; + set->num_exprs = num_exprs; set->handle = nf_tables_alloc_handle(table); err = nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set); @@ -4607,7 +4734,7 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, err_set_expr_alloc: for (i = 0; i < set->num_exprs; i++) nft_expr_destroy(&ctx, set->exprs[i]); - +err_set_destroy: ops->destroy(set); err_set_init: kfree(set->name); @@ -4909,7 +5036,7 @@ static int nft_set_elem_expr_dump(struct sk_buff *skb, if (num_exprs == 1) { expr = nft_setelem_expr_at(elem_expr, 0); - if (nft_expr_dump(skb, NFTA_SET_ELEM_EXPR, expr) < 0) + if (nft_expr_dump(skb, NFTA_SET_ELEM_EXPR, expr, false) < 0) return -1; return 0; @@ -4920,7 +5047,7 @@ static int nft_set_elem_expr_dump(struct sk_buff *skb, nft_setelem_expr_foreach(expr, elem_expr, size) { expr = nft_setelem_expr_at(elem_expr, size); - if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr) < 0) + if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr, false) < 0) goto nla_put_failure; } nla_nest_end(skb, nest); @@ -5958,8 +6085,9 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, &timeout); if (err) return err; - } else if (set->flags & NFT_SET_TIMEOUT) { - timeout = set->timeout; + } else if (set->flags & NFT_SET_TIMEOUT && + !(flags & NFT_SET_ELEM_INTERVAL_END)) { + timeout = READ_ONCE(set->timeout); } expiration = 0; @@ -6024,7 +6152,8 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, err = -EOPNOTSUPP; goto err_set_elem_expr; } - } else if (set->num_exprs > 0) { + } else if (set->num_exprs > 0 && + !(flags & NFT_SET_ELEM_INTERVAL_END)) { err = nft_set_elem_expr_clone(ctx, set, expr_array); if (err < 0) goto err_set_elem_expr_clone; @@ -6059,7 +6188,7 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, if (err < 0) goto err_parse_key_end; - if (timeout != set->timeout) { + if (timeout != READ_ONCE(set->timeout)) { err = nft_set_ext_add(&tmpl, NFT_SET_EXT_TIMEOUT); if (err < 0) goto err_parse_key_end; @@ -8274,6 +8403,12 @@ static const struct nfnl_callback nf_tables_cb[NFT_MSG_MAX] = { .attr_count = NFTA_RULE_MAX, .policy = nft_rule_policy, }, + [NFT_MSG_GETRULE_RESET] = { + .call = nf_tables_getrule, + .type = NFNL_CB_RCU, + .attr_count = NFTA_RULE_MAX, + .policy = nft_rule_policy, + }, [NFT_MSG_DELRULE] = { .call = nf_tables_delrule, .type = NFNL_CB_BATCH, @@ -8975,14 +9110,20 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb) nft_flow_rule_destroy(nft_trans_flow_rule(trans)); break; case NFT_MSG_NEWSET: - nft_clear(net, nft_trans_set(trans)); - /* This avoids hitting -EBUSY when deleting the table - * from the transaction. - */ - if (nft_set_is_anonymous(nft_trans_set(trans)) && - !list_empty(&nft_trans_set(trans)->bindings)) - trans->ctx.table->use--; + if (nft_trans_set_update(trans)) { + struct nft_set *set = nft_trans_set(trans); + WRITE_ONCE(set->timeout, nft_trans_set_timeout(trans)); + WRITE_ONCE(set->gc_int, nft_trans_set_gc_int(trans)); + } else { + nft_clear(net, nft_trans_set(trans)); + /* This avoids hitting -EBUSY when deleting the table + * from the transaction. + */ + if (nft_set_is_anonymous(nft_trans_set(trans)) && + !list_empty(&nft_trans_set(trans)->bindings)) + trans->ctx.table->use--; + } nf_tables_set_notify(&trans->ctx, nft_trans_set(trans), NFT_MSG_NEWSET, GFP_KERNEL); nft_trans_destroy(trans); @@ -9204,6 +9345,10 @@ static int __nf_tables_abort(struct net *net, enum nfnl_abort_action action) nft_trans_destroy(trans); break; case NFT_MSG_NEWSET: + if (nft_trans_set_update(trans)) { + nft_trans_destroy(trans); + break; + } trans->ctx.table->use--; if (nft_trans_set_bound(trans)) { nft_trans_destroy(trans); diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c index cee3e4e905ec..709a736c301c 100644 --- a/net/netfilter/nf_tables_core.c +++ b/net/netfilter/nf_tables_core.c @@ -340,6 +340,8 @@ static struct nft_expr_type *nft_basic_types[] = { &nft_exthdr_type, &nft_last_type, &nft_counter_type, + &nft_objref_type, + &nft_inner_type, }; static struct nft_object_type *nft_basic_objects[] = { diff --git a/net/netfilter/nft_bitwise.c b/net/netfilter/nft_bitwise.c index e6e402b247d0..84eae7cabc67 100644 --- a/net/netfilter/nft_bitwise.c +++ b/net/netfilter/nft_bitwise.c @@ -232,7 +232,8 @@ static int nft_bitwise_dump_shift(struct sk_buff *skb, return 0; } -static int nft_bitwise_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_bitwise_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_bitwise *priv = nft_expr_priv(expr); int err = 0; @@ -393,7 +394,8 @@ static int nft_bitwise_fast_init(const struct nft_ctx *ctx, } static int -nft_bitwise_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +nft_bitwise_fast_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); struct nft_data data; diff --git a/net/netfilter/nft_byteorder.c b/net/netfilter/nft_byteorder.c index f952a80275a8..b66647a5a171 100644 --- a/net/netfilter/nft_byteorder.c +++ b/net/netfilter/nft_byteorder.c @@ -148,7 +148,8 @@ static int nft_byteorder_init(const struct nft_ctx *ctx, priv->len); } -static int nft_byteorder_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_byteorder_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_byteorder *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_cmp.c b/net/netfilter/nft_cmp.c index 963cf831799c..6eb21a4f5698 100644 --- a/net/netfilter/nft_cmp.c +++ b/net/netfilter/nft_cmp.c @@ -92,7 +92,8 @@ static int nft_cmp_init(const struct nft_ctx *ctx, const struct nft_expr *expr, return 0; } -static int nft_cmp_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_cmp_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_cmp_expr *priv = nft_expr_priv(expr); @@ -253,7 +254,8 @@ static int nft_cmp_fast_offload(struct nft_offload_ctx *ctx, return __nft_cmp_offload(ctx, flow, &cmp); } -static int nft_cmp_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_cmp_fast_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_cmp_fast_expr *priv = nft_expr_priv(expr); enum nft_cmp_ops op = priv->inv ? NFT_CMP_NEQ : NFT_CMP_EQ; @@ -347,7 +349,8 @@ static int nft_cmp16_fast_offload(struct nft_offload_ctx *ctx, return __nft_cmp_offload(ctx, flow, &cmp); } -static int nft_cmp16_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_cmp16_fast_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_cmp16_fast_expr *priv = nft_expr_priv(expr); enum nft_cmp_ops op = priv->inv ? NFT_CMP_NEQ : NFT_CMP_EQ; diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c index c16172427622..5284cd2ad532 100644 --- a/net/netfilter/nft_compat.c +++ b/net/netfilter/nft_compat.c @@ -324,7 +324,8 @@ static int nft_extension_dump_info(struct sk_buff *skb, int attr, return 0; } -static int nft_target_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_target_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct xt_target *target = expr->ops->data; void *info = nft_expr_priv(expr); @@ -572,12 +573,14 @@ nla_put_failure: return -1; } -static int nft_match_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_match_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { return __nft_match_dump(skb, expr, nft_expr_priv(expr)); } -static int nft_match_large_dump(struct sk_buff *skb, const struct nft_expr *e) +static int nft_match_large_dump(struct sk_buff *skb, + const struct nft_expr *e, bool reset) { struct nft_xt_match_priv *priv = nft_expr_priv(e); diff --git a/net/netfilter/nft_connlimit.c b/net/netfilter/nft_connlimit.c index d657f999a11b..de9d1980df69 100644 --- a/net/netfilter/nft_connlimit.c +++ b/net/netfilter/nft_connlimit.c @@ -185,7 +185,8 @@ static void nft_connlimit_eval(const struct nft_expr *expr, nft_connlimit_do_eval(priv, regs, pkt, NULL); } -static int nft_connlimit_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_connlimit_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_connlimit *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_counter.c b/net/netfilter/nft_counter.c index f4d3573e8782..dccc68a5135a 100644 --- a/net/netfilter/nft_counter.c +++ b/net/netfilter/nft_counter.c @@ -201,11 +201,12 @@ void nft_counter_eval(const struct nft_expr *expr, struct nft_regs *regs, nft_counter_do_eval(priv, regs, pkt); } -static int nft_counter_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_counter_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); - return nft_counter_do_dump(skb, priv, false); + return nft_counter_do_dump(skb, priv, reset); } static int nft_counter_init(const struct nft_ctx *ctx, diff --git a/net/netfilter/nft_ct.c b/net/netfilter/nft_ct.c index a3f01f209a53..c68e2151defe 100644 --- a/net/netfilter/nft_ct.c +++ b/net/netfilter/nft_ct.c @@ -98,7 +98,7 @@ static void nft_ct_get_eval(const struct nft_expr *expr, return; #ifdef CONFIG_NF_CONNTRACK_MARK case NFT_CT_MARK: - *dest = ct->mark; + *dest = READ_ONCE(ct->mark); return; #endif #ifdef CONFIG_NF_CONNTRACK_SECMARK @@ -297,8 +297,8 @@ static void nft_ct_set_eval(const struct nft_expr *expr, switch (priv->key) { #ifdef CONFIG_NF_CONNTRACK_MARK case NFT_CT_MARK: - if (ct->mark != value) { - ct->mark = value; + if (READ_ONCE(ct->mark) != value) { + WRITE_ONCE(ct->mark, value); nf_conntrack_event_cache(IPCT_MARK, ct); } break; @@ -641,7 +641,8 @@ static void nft_ct_set_destroy(const struct nft_ctx *ctx, nf_ct_netns_put(ctx->net, ctx->family); } -static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_ct_get_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_ct *priv = nft_expr_priv(expr); @@ -703,7 +704,8 @@ static bool nft_ct_get_reduce(struct nft_regs_track *track, return nft_expr_reduce_bitwise(track, expr); } -static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_ct_set_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_ct *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_dup_netdev.c b/net/netfilter/nft_dup_netdev.c index 63507402716d..e5739a59ebf1 100644 --- a/net/netfilter/nft_dup_netdev.c +++ b/net/netfilter/nft_dup_netdev.c @@ -44,7 +44,8 @@ static int nft_dup_netdev_init(const struct nft_ctx *ctx, sizeof(int)); } -static int nft_dup_netdev_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_dup_netdev_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_dup_netdev *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_dynset.c b/net/netfilter/nft_dynset.c index 6983e6ddeef9..274579b1696e 100644 --- a/net/netfilter/nft_dynset.c +++ b/net/netfilter/nft_dynset.c @@ -357,7 +357,8 @@ static void nft_dynset_destroy(const struct nft_ctx *ctx, nf_tables_destroy_set(ctx, priv->set); } -static int nft_dynset_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_dynset_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_dynset *priv = nft_expr_priv(expr); u32 flags = priv->invert ? NFT_DYNSET_F_INV : 0; @@ -379,7 +380,7 @@ static int nft_dynset_dump(struct sk_buff *skb, const struct nft_expr *expr) if (priv->set->num_exprs == 0) { if (priv->num_exprs == 1) { if (nft_expr_dump(skb, NFTA_DYNSET_EXPR, - priv->expr_array[0])) + priv->expr_array[0], reset)) goto nla_put_failure; } else if (priv->num_exprs > 1) { struct nlattr *nest; @@ -390,7 +391,7 @@ static int nft_dynset_dump(struct sk_buff *skb, const struct nft_expr *expr) for (i = 0; i < priv->num_exprs; i++) { if (nft_expr_dump(skb, NFTA_LIST_ELEM, - priv->expr_array[i])) + priv->expr_array[i], reset)) goto nla_put_failure; } nla_nest_end(skb, nest); diff --git a/net/netfilter/nft_exthdr.c b/net/netfilter/nft_exthdr.c index a67ea9c3ae57..a54a7f772cec 100644 --- a/net/netfilter/nft_exthdr.c +++ b/net/netfilter/nft_exthdr.c @@ -13,7 +13,6 @@ #include <linux/sctp.h> #include <net/netfilter/nf_tables_core.h> #include <net/netfilter/nf_tables.h> -#include <net/sctp/sctp.h> #include <net/tcp.h> struct nft_exthdr { @@ -576,7 +575,8 @@ nla_put_failure: return -1; } -static int nft_exthdr_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_exthdr_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_exthdr *priv = nft_expr_priv(expr); @@ -586,7 +586,8 @@ static int nft_exthdr_dump(struct sk_buff *skb, const struct nft_expr *expr) return nft_exthdr_dump_common(skb, priv); } -static int nft_exthdr_dump_set(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_exthdr_dump_set(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_exthdr *priv = nft_expr_priv(expr); @@ -596,7 +597,8 @@ static int nft_exthdr_dump_set(struct sk_buff *skb, const struct nft_expr *expr) return nft_exthdr_dump_common(skb, priv); } -static int nft_exthdr_dump_strip(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_exthdr_dump_strip(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_exthdr *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_fib.c b/net/netfilter/nft_fib.c index 1f12d7ade606..6e049fd48760 100644 --- a/net/netfilter/nft_fib.c +++ b/net/netfilter/nft_fib.c @@ -118,7 +118,7 @@ int nft_fib_init(const struct nft_ctx *ctx, const struct nft_expr *expr, } EXPORT_SYMBOL_GPL(nft_fib_init); -int nft_fib_dump(struct sk_buff *skb, const struct nft_expr *expr) +int nft_fib_dump(struct sk_buff *skb, const struct nft_expr *expr, bool reset) { const struct nft_fib *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_flow_offload.c b/net/netfilter/nft_flow_offload.c index a25c88bc8b75..e860d8fe0e5e 100644 --- a/net/netfilter/nft_flow_offload.c +++ b/net/netfilter/nft_flow_offload.c @@ -433,7 +433,8 @@ static void nft_flow_offload_destroy(const struct nft_ctx *ctx, nf_ct_netns_put(ctx->net, ctx->family); } -static int nft_flow_offload_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_flow_offload_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_flow_offload *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_fwd_netdev.c b/net/netfilter/nft_fwd_netdev.c index 7c5876dc9ff2..7b9d4d1bd17c 100644 --- a/net/netfilter/nft_fwd_netdev.c +++ b/net/netfilter/nft_fwd_netdev.c @@ -56,7 +56,8 @@ static int nft_fwd_netdev_init(const struct nft_ctx *ctx, sizeof(int)); } -static int nft_fwd_netdev_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_fwd_netdev_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_fwd_netdev *priv = nft_expr_priv(expr); @@ -186,7 +187,8 @@ static int nft_fwd_neigh_init(const struct nft_ctx *ctx, addr_len); } -static int nft_fwd_neigh_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_fwd_neigh_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_fwd_neigh *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c index e5631e88b285..ee8d487b69c0 100644 --- a/net/netfilter/nft_hash.c +++ b/net/netfilter/nft_hash.c @@ -139,7 +139,7 @@ static int nft_symhash_init(const struct nft_ctx *ctx, } static int nft_jhash_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_jhash *priv = nft_expr_priv(expr); @@ -176,7 +176,7 @@ static bool nft_jhash_reduce(struct nft_regs_track *track, } static int nft_symhash_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_symhash *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_immediate.c b/net/netfilter/nft_immediate.c index 5f28b21abc7d..c9d2f7c29f53 100644 --- a/net/netfilter/nft_immediate.c +++ b/net/netfilter/nft_immediate.c @@ -147,7 +147,8 @@ static void nft_immediate_destroy(const struct nft_ctx *ctx, } } -static int nft_immediate_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_immediate_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_immediate_expr *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_inner.c b/net/netfilter/nft_inner.c new file mode 100644 index 000000000000..28e2873ba24e --- /dev/null +++ b/net/netfilter/nft_inner.c @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2022 Pablo Neira Ayuso <pablo@netfilter.org> + */ + +#include <linux/kernel.h> +#include <linux/if_vlan.h> +#include <linux/init.h> +#include <linux/module.h> +#include <linux/netlink.h> +#include <linux/netfilter.h> +#include <linux/netfilter/nf_tables.h> +#include <net/netfilter/nf_tables_core.h> +#include <net/netfilter/nf_tables.h> +#include <net/netfilter/nft_meta.h> +#include <net/netfilter/nf_tables_offload.h> +#include <linux/tcp.h> +#include <linux/udp.h> +#include <net/gre.h> +#include <net/geneve.h> +#include <net/ip.h> +#include <linux/icmpv6.h> +#include <linux/ip.h> +#include <linux/ipv6.h> + +static DEFINE_PER_CPU(struct nft_inner_tun_ctx, nft_pcpu_tun_ctx); + +/* Same layout as nft_expr but it embeds the private expression data area. */ +struct __nft_expr { + const struct nft_expr_ops *ops; + union { + struct nft_payload payload; + struct nft_meta meta; + } __attribute__((aligned(__alignof__(u64)))); +}; + +enum { + NFT_INNER_EXPR_PAYLOAD, + NFT_INNER_EXPR_META, +}; + +struct nft_inner { + u8 flags; + u8 hdrsize; + u8 type; + u8 expr_type; + + struct __nft_expr expr; +}; + +static int nft_inner_parse_l2l3(const struct nft_inner *priv, + const struct nft_pktinfo *pkt, + struct nft_inner_tun_ctx *ctx, u32 off) +{ + __be16 llproto, outer_llproto; + u32 nhoff, thoff; + + if (priv->flags & NFT_INNER_LL) { + struct vlan_ethhdr *veth, _veth; + struct ethhdr *eth, _eth; + u32 hdrsize; + + eth = skb_header_pointer(pkt->skb, off, sizeof(_eth), &_eth); + if (!eth) + return -1; + + switch (eth->h_proto) { + case htons(ETH_P_IP): + case htons(ETH_P_IPV6): + llproto = eth->h_proto; + hdrsize = sizeof(_eth); + break; + case htons(ETH_P_8021Q): + veth = skb_header_pointer(pkt->skb, off, sizeof(_veth), &_veth); + if (!veth) + return -1; + + outer_llproto = veth->h_vlan_encapsulated_proto; + llproto = veth->h_vlan_proto; + hdrsize = sizeof(_veth); + break; + default: + return -1; + } + + ctx->inner_lloff = off; + ctx->flags |= NFT_PAYLOAD_CTX_INNER_LL; + off += hdrsize; + } else { + struct iphdr *iph; + u32 _version; + + iph = skb_header_pointer(pkt->skb, off, sizeof(_version), &_version); + if (!iph) + return -1; + + switch (iph->version) { + case 4: + llproto = htons(ETH_P_IP); + break; + case 6: + llproto = htons(ETH_P_IPV6); + break; + default: + return -1; + } + } + + ctx->llproto = llproto; + if (llproto == htons(ETH_P_8021Q)) + llproto = outer_llproto; + + nhoff = off; + + switch (llproto) { + case htons(ETH_P_IP): { + struct iphdr *iph, _iph; + + iph = skb_header_pointer(pkt->skb, nhoff, sizeof(_iph), &_iph); + if (!iph) + return -1; + + if (iph->ihl < 5 || iph->version != 4) + return -1; + + ctx->inner_nhoff = nhoff; + ctx->flags |= NFT_PAYLOAD_CTX_INNER_NH; + + thoff = nhoff + (iph->ihl * 4); + if ((ntohs(iph->frag_off) & IP_OFFSET) == 0) { + ctx->flags |= NFT_PAYLOAD_CTX_INNER_TH; + ctx->inner_thoff = thoff; + ctx->l4proto = iph->protocol; + } + } + break; + case htons(ETH_P_IPV6): { + struct ipv6hdr *ip6h, _ip6h; + int fh_flags = IP6_FH_F_AUTH; + unsigned short fragoff; + int l4proto; + + ip6h = skb_header_pointer(pkt->skb, nhoff, sizeof(_ip6h), &_ip6h); + if (!ip6h) + return -1; + + if (ip6h->version != 6) + return -1; + + ctx->inner_nhoff = nhoff; + ctx->flags |= NFT_PAYLOAD_CTX_INNER_NH; + + thoff = nhoff; + l4proto = ipv6_find_hdr(pkt->skb, &thoff, -1, &fragoff, &fh_flags); + if (l4proto < 0 || thoff > U16_MAX) + return -1; + + if (fragoff == 0) { + thoff = nhoff + sizeof(_ip6h); + ctx->flags |= NFT_PAYLOAD_CTX_INNER_TH; + ctx->inner_thoff = thoff; + ctx->l4proto = l4proto; + } + } + break; + default: + return -1; + } + + return 0; +} + +static int nft_inner_parse_tunhdr(const struct nft_inner *priv, + const struct nft_pktinfo *pkt, + struct nft_inner_tun_ctx *ctx, u32 *off) +{ + if (pkt->tprot == IPPROTO_GRE) { + ctx->inner_tunoff = pkt->thoff; + ctx->flags |= NFT_PAYLOAD_CTX_INNER_TUN; + return 0; + } + + if (pkt->tprot != IPPROTO_UDP) + return -1; + + ctx->inner_tunoff = *off; + ctx->flags |= NFT_PAYLOAD_CTX_INNER_TUN; + *off += priv->hdrsize; + + switch (priv->type) { + case NFT_INNER_GENEVE: { + struct genevehdr *gnvh, _gnvh; + + gnvh = skb_header_pointer(pkt->skb, pkt->inneroff, + sizeof(_gnvh), &_gnvh); + if (!gnvh) + return -1; + + *off += gnvh->opt_len * 4; + } + break; + default: + break; + } + + return 0; +} + +static int nft_inner_parse(const struct nft_inner *priv, + struct nft_pktinfo *pkt, + struct nft_inner_tun_ctx *tun_ctx) +{ + struct nft_inner_tun_ctx ctx = {}; + u32 off = pkt->inneroff; + + if (priv->flags & NFT_INNER_HDRSIZE && + nft_inner_parse_tunhdr(priv, pkt, &ctx, &off) < 0) + return -1; + + if (priv->flags & (NFT_INNER_LL | NFT_INNER_NH)) { + if (nft_inner_parse_l2l3(priv, pkt, &ctx, off) < 0) + return -1; + } else if (priv->flags & NFT_INNER_TH) { + ctx.inner_thoff = off; + ctx.flags |= NFT_PAYLOAD_CTX_INNER_TH; + } + + *tun_ctx = ctx; + tun_ctx->type = priv->type; + pkt->flags |= NFT_PKTINFO_INNER_FULL; + + return 0; +} + +static bool nft_inner_parse_needed(const struct nft_inner *priv, + const struct nft_pktinfo *pkt, + const struct nft_inner_tun_ctx *tun_ctx) +{ + if (!(pkt->flags & NFT_PKTINFO_INNER_FULL)) + return true; + + if (priv->type != tun_ctx->type) + return true; + + return false; +} + +static void nft_inner_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_inner_tun_ctx *tun_ctx = this_cpu_ptr(&nft_pcpu_tun_ctx); + const struct nft_inner *priv = nft_expr_priv(expr); + + if (nft_payload_inner_offset(pkt) < 0) + goto err; + + if (nft_inner_parse_needed(priv, pkt, tun_ctx) && + nft_inner_parse(priv, (struct nft_pktinfo *)pkt, tun_ctx) < 0) + goto err; + + switch (priv->expr_type) { + case NFT_INNER_EXPR_PAYLOAD: + nft_payload_inner_eval((struct nft_expr *)&priv->expr, regs, pkt, tun_ctx); + break; + case NFT_INNER_EXPR_META: + nft_meta_inner_eval((struct nft_expr *)&priv->expr, regs, pkt, tun_ctx); + break; + default: + WARN_ON_ONCE(1); + goto err; + } + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_inner_policy[NFTA_INNER_MAX + 1] = { + [NFTA_INNER_NUM] = { .type = NLA_U32 }, + [NFTA_INNER_FLAGS] = { .type = NLA_U32 }, + [NFTA_INNER_HDRSIZE] = { .type = NLA_U32 }, + [NFTA_INNER_TYPE] = { .type = NLA_U32 }, + [NFTA_INNER_EXPR] = { .type = NLA_NESTED }, +}; + +struct nft_expr_info { + const struct nft_expr_ops *ops; + const struct nlattr *attr; + struct nlattr *tb[NFT_EXPR_MAXATTR + 1]; +}; + +static int nft_inner_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_inner *priv = nft_expr_priv(expr); + u32 flags, hdrsize, type, num; + struct nft_expr_info expr_info; + int err; + + if (!tb[NFTA_INNER_FLAGS] || + !tb[NFTA_INNER_HDRSIZE] || + !tb[NFTA_INNER_TYPE] || + !tb[NFTA_INNER_EXPR]) + return -EINVAL; + + flags = ntohl(nla_get_be32(tb[NFTA_INNER_FLAGS])); + if (flags & ~NFT_INNER_MASK) + return -EOPNOTSUPP; + + num = ntohl(nla_get_be32(tb[NFTA_INNER_NUM])); + if (num != 0) + return -EOPNOTSUPP; + + hdrsize = ntohl(nla_get_be32(tb[NFTA_INNER_HDRSIZE])); + type = ntohl(nla_get_be32(tb[NFTA_INNER_TYPE])); + + if (type > U8_MAX) + return -EINVAL; + + if (flags & NFT_INNER_HDRSIZE) { + if (hdrsize == 0 || hdrsize > 64) + return -EOPNOTSUPP; + } + + priv->flags = flags; + priv->hdrsize = hdrsize; + priv->type = type; + + err = nft_expr_inner_parse(ctx, tb[NFTA_INNER_EXPR], &expr_info); + if (err < 0) + return err; + + priv->expr.ops = expr_info.ops; + + if (!strcmp(expr_info.ops->type->name, "payload")) + priv->expr_type = NFT_INNER_EXPR_PAYLOAD; + else if (!strcmp(expr_info.ops->type->name, "meta")) + priv->expr_type = NFT_INNER_EXPR_META; + else + return -EINVAL; + + err = expr_info.ops->init(ctx, (struct nft_expr *)&priv->expr, + (const struct nlattr * const*)expr_info.tb); + if (err < 0) + return err; + + return 0; +} + +static int nft_inner_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) +{ + const struct nft_inner *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_INNER_NUM, htonl(0)) || + nla_put_be32(skb, NFTA_INNER_TYPE, htonl(priv->type)) || + nla_put_be32(skb, NFTA_INNER_FLAGS, htonl(priv->flags)) || + nla_put_be32(skb, NFTA_INNER_HDRSIZE, htonl(priv->hdrsize))) + goto nla_put_failure; + + if (nft_expr_dump(skb, NFTA_INNER_EXPR, + (struct nft_expr *)&priv->expr, reset) < 0) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static const struct nft_expr_ops nft_inner_ops = { + .type = &nft_inner_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_inner)), + .eval = nft_inner_eval, + .init = nft_inner_init, + .dump = nft_inner_dump, +}; + +struct nft_expr_type nft_inner_type __read_mostly = { + .name = "inner", + .ops = &nft_inner_ops, + .policy = nft_inner_policy, + .maxattr = NFTA_INNER_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_last.c b/net/netfilter/nft_last.c index bb15a55dad5c..7f2bda6641bd 100644 --- a/net/netfilter/nft_last.c +++ b/net/netfilter/nft_last.c @@ -65,7 +65,8 @@ static void nft_last_eval(const struct nft_expr *expr, WRITE_ONCE(last->set, 1); } -static int nft_last_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_last_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_last_priv *priv = nft_expr_priv(expr); struct nft_last *last = priv->last; diff --git a/net/netfilter/nft_limit.c b/net/netfilter/nft_limit.c index 981addb2d051..145dc62c6247 100644 --- a/net/netfilter/nft_limit.c +++ b/net/netfilter/nft_limit.c @@ -193,7 +193,8 @@ static int nft_limit_pkts_init(const struct nft_ctx *ctx, return 0; } -static int nft_limit_pkts_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_limit_pkts_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_limit_priv_pkts *priv = nft_expr_priv(expr); @@ -251,7 +252,7 @@ static int nft_limit_bytes_init(const struct nft_ctx *ctx, } static int nft_limit_bytes_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_limit_priv *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_log.c b/net/netfilter/nft_log.c index 0e13c003f0c1..5defe6e4fd98 100644 --- a/net/netfilter/nft_log.c +++ b/net/netfilter/nft_log.c @@ -241,7 +241,8 @@ static void nft_log_destroy(const struct nft_ctx *ctx, nf_logger_put(ctx->family, li->type); } -static int nft_log_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_log_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_log *priv = nft_expr_priv(expr); const struct nf_loginfo *li = &priv->loginfo; diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c index dfae12759c7c..cae5a6724163 100644 --- a/net/netfilter/nft_lookup.c +++ b/net/netfilter/nft_lookup.c @@ -178,7 +178,8 @@ static void nft_lookup_destroy(const struct nft_ctx *ctx, nf_tables_destroy_set(ctx, priv->set); } -static int nft_lookup_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_lookup_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_lookup *priv = nft_expr_priv(expr); u32 flags = priv->invert ? NFT_LOOKUP_F_INV : 0; diff --git a/net/netfilter/nft_masq.c b/net/netfilter/nft_masq.c index 2a0adc497bbb..e55e455275c4 100644 --- a/net/netfilter/nft_masq.c +++ b/net/netfilter/nft_masq.c @@ -73,7 +73,8 @@ static int nft_masq_init(const struct nft_ctx *ctx, return nf_ct_netns_get(ctx->net, ctx->family); } -static int nft_masq_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_masq_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_masq *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c index 55d2d49c3425..e384e0de7a54 100644 --- a/net/netfilter/nft_meta.c +++ b/net/netfilter/nft_meta.c @@ -669,7 +669,7 @@ int nft_meta_set_init(const struct nft_ctx *ctx, EXPORT_SYMBOL_GPL(nft_meta_set_init); int nft_meta_get_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_meta *priv = nft_expr_priv(expr); @@ -684,7 +684,8 @@ nla_put_failure: } EXPORT_SYMBOL_GPL(nft_meta_get_dump); -int nft_meta_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +int nft_meta_set_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_meta *priv = nft_expr_priv(expr); @@ -831,9 +832,71 @@ nft_meta_select_ops(const struct nft_ctx *ctx, return ERR_PTR(-EINVAL); } +static int nft_meta_inner_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_meta *priv = nft_expr_priv(expr); + unsigned int len; + + priv->key = ntohl(nla_get_be32(tb[NFTA_META_KEY])); + switch (priv->key) { + case NFT_META_PROTOCOL: + len = sizeof(u16); + break; + case NFT_META_L4PROTO: + len = sizeof(u32); + break; + default: + return -EOPNOTSUPP; + } + priv->len = len; + + return nft_parse_register_store(ctx, tb[NFTA_META_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} + +void nft_meta_inner_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt, + struct nft_inner_tun_ctx *tun_ctx) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + + switch (priv->key) { + case NFT_META_PROTOCOL: + nft_reg_store16(dest, (__force u16)tun_ctx->llproto); + break; + case NFT_META_L4PROTO: + if (!(tun_ctx->flags & NFT_PAYLOAD_CTX_INNER_TH)) + goto err; + + nft_reg_store8(dest, tun_ctx->l4proto); + break; + default: + WARN_ON_ONCE(1); + goto err; + } + return; + +err: + regs->verdict.code = NFT_BREAK; +} +EXPORT_SYMBOL_GPL(nft_meta_inner_eval); + +static const struct nft_expr_ops nft_meta_inner_ops = { + .type = &nft_meta_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_meta)), + .init = nft_meta_inner_init, + .dump = nft_meta_get_dump, + /* direct call to nft_meta_inner_eval(). */ +}; + struct nft_expr_type nft_meta_type __read_mostly = { .name = "meta", .select_ops = nft_meta_select_ops, + .inner_ops = &nft_meta_inner_ops, .policy = nft_meta_policy, .maxattr = NFTA_META_MAX, .owner = THIS_MODULE, diff --git a/net/netfilter/nft_nat.c b/net/netfilter/nft_nat.c index e5fd6995e4bf..047999150390 100644 --- a/net/netfilter/nft_nat.c +++ b/net/netfilter/nft_nat.c @@ -255,7 +255,8 @@ static int nft_nat_init(const struct nft_ctx *ctx, const struct nft_expr *expr, return nf_ct_netns_get(ctx->net, family); } -static int nft_nat_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_nat_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_nat *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_numgen.c b/net/netfilter/nft_numgen.c index 45d3dc9e96f2..7d29db7c2ac0 100644 --- a/net/netfilter/nft_numgen.c +++ b/net/netfilter/nft_numgen.c @@ -112,7 +112,8 @@ nla_put_failure: return -1; } -static int nft_ng_inc_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_ng_inc_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_ng_inc *priv = nft_expr_priv(expr); @@ -168,7 +169,8 @@ static int nft_ng_random_init(const struct nft_ctx *ctx, NULL, NFT_DATA_VALUE, sizeof(u32)); } -static int nft_ng_random_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_ng_random_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_ng_random *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_objref.c b/net/netfilter/nft_objref.c index 5d8d91b3904d..7b01aa2ef653 100644 --- a/net/netfilter/nft_objref.c +++ b/net/netfilter/nft_objref.c @@ -47,7 +47,8 @@ static int nft_objref_init(const struct nft_ctx *ctx, return 0; } -static int nft_objref_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_objref_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_object *obj = nft_objref_priv(expr); @@ -82,7 +83,6 @@ static void nft_objref_activate(const struct nft_ctx *ctx, obj->use++; } -static struct nft_expr_type nft_objref_type; static const struct nft_expr_ops nft_objref_ops = { .type = &nft_objref_type, .size = NFT_EXPR_SIZE(sizeof(struct nft_object *)), @@ -156,7 +156,8 @@ static int nft_objref_map_init(const struct nft_ctx *ctx, return 0; } -static int nft_objref_map_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_objref_map_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_objref_map *priv = nft_expr_priv(expr); @@ -195,7 +196,6 @@ static void nft_objref_map_destroy(const struct nft_ctx *ctx, nf_tables_destroy_set(ctx, priv->set); } -static struct nft_expr_type nft_objref_type; static const struct nft_expr_ops nft_objref_map_ops = { .type = &nft_objref_type, .size = NFT_EXPR_SIZE(sizeof(struct nft_objref_map)), @@ -233,28 +233,10 @@ static const struct nla_policy nft_objref_policy[NFTA_OBJREF_MAX + 1] = { [NFTA_OBJREF_SET_ID] = { .type = NLA_U32 }, }; -static struct nft_expr_type nft_objref_type __read_mostly = { +struct nft_expr_type nft_objref_type __read_mostly = { .name = "objref", .select_ops = nft_objref_select_ops, .policy = nft_objref_policy, .maxattr = NFTA_OBJREF_MAX, .owner = THIS_MODULE, }; - -static int __init nft_objref_module_init(void) -{ - return nft_register_expr(&nft_objref_type); -} - -static void __exit nft_objref_module_exit(void) -{ - nft_unregister_expr(&nft_objref_type); -} - -module_init(nft_objref_module_init); -module_exit(nft_objref_module_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Pablo Neira Ayuso <pablo@netfilter.org>"); -MODULE_ALIAS_NFT_EXPR("objref"); -MODULE_DESCRIPTION("nftables stateful object reference module"); diff --git a/net/netfilter/nft_osf.c b/net/netfilter/nft_osf.c index adacf95b6e2b..70820c66b591 100644 --- a/net/netfilter/nft_osf.c +++ b/net/netfilter/nft_osf.c @@ -92,7 +92,8 @@ static int nft_osf_init(const struct nft_ctx *ctx, return 0; } -static int nft_osf_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_osf_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_osf *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_payload.c b/net/netfilter/nft_payload.c index 4edd899aeb9b..3a3c7746e88f 100644 --- a/net/netfilter/nft_payload.c +++ b/net/netfilter/nft_payload.c @@ -19,6 +19,7 @@ /* For layer 4 checksum field offset. */ #include <linux/tcp.h> #include <linux/udp.h> +#include <net/gre.h> #include <linux/icmpv6.h> #include <linux/ip.h> #include <linux/ipv6.h> @@ -62,7 +63,7 @@ nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len) return false; if (offset + len > VLAN_ETH_HLEN + vlan_hlen) - ethlen -= offset + len - VLAN_ETH_HLEN + vlan_hlen; + ethlen -= offset + len - VLAN_ETH_HLEN - vlan_hlen; memcpy(dst_u8, vlanh + offset - vlan_hlen, ethlen); @@ -100,6 +101,41 @@ static int __nft_payload_inner_offset(struct nft_pktinfo *pkt) pkt->inneroff = thoff + __tcp_hdrlen(th); } break; + case IPPROTO_GRE: { + u32 offset = sizeof(struct gre_base_hdr); + struct gre_base_hdr *gre, _gre; + __be16 version; + + gre = skb_header_pointer(pkt->skb, thoff, sizeof(_gre), &_gre); + if (!gre) + return -1; + + version = gre->flags & GRE_VERSION; + switch (version) { + case GRE_VERSION_0: + if (gre->flags & GRE_ROUTING) + return -1; + + if (gre->flags & GRE_CSUM) { + offset += sizeof_field(struct gre_full_hdr, csum) + + sizeof_field(struct gre_full_hdr, reserved1); + } + if (gre->flags & GRE_KEY) + offset += sizeof_field(struct gre_full_hdr, key); + + if (gre->flags & GRE_SEQ) + offset += sizeof_field(struct gre_full_hdr, seq); + break; + default: + return -1; + } + + pkt->inneroff = thoff + offset; + } + break; + case IPPROTO_IPIP: + pkt->inneroff = thoff; + break; default: return -1; } @@ -109,7 +145,7 @@ static int __nft_payload_inner_offset(struct nft_pktinfo *pkt) return 0; } -static int nft_payload_inner_offset(const struct nft_pktinfo *pkt) +int nft_payload_inner_offset(const struct nft_pktinfo *pkt) { if (!(pkt->flags & NFT_PKTINFO_INNER) && __nft_payload_inner_offset((struct nft_pktinfo *)pkt) < 0) @@ -195,7 +231,8 @@ static int nft_payload_init(const struct nft_ctx *ctx, priv->len); } -static int nft_payload_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_payload_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_payload *priv = nft_expr_priv(expr); @@ -552,6 +589,92 @@ const struct nft_expr_ops nft_payload_fast_ops = { .offload = nft_payload_offload, }; +void nft_payload_inner_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt, + struct nft_inner_tun_ctx *tun_ctx) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + const struct sk_buff *skb = pkt->skb; + u32 *dest = ®s->data[priv->dreg]; + int offset; + + if (priv->len % NFT_REG32_SIZE) + dest[priv->len / NFT_REG32_SIZE] = 0; + + switch (priv->base) { + case NFT_PAYLOAD_TUN_HEADER: + if (!(tun_ctx->flags & NFT_PAYLOAD_CTX_INNER_TUN)) + goto err; + + offset = tun_ctx->inner_tunoff; + break; + case NFT_PAYLOAD_LL_HEADER: + if (!(tun_ctx->flags & NFT_PAYLOAD_CTX_INNER_LL)) + goto err; + + offset = tun_ctx->inner_lloff; + break; + case NFT_PAYLOAD_NETWORK_HEADER: + if (!(tun_ctx->flags & NFT_PAYLOAD_CTX_INNER_NH)) + goto err; + + offset = tun_ctx->inner_nhoff; + break; + case NFT_PAYLOAD_TRANSPORT_HEADER: + if (!(tun_ctx->flags & NFT_PAYLOAD_CTX_INNER_TH)) + goto err; + + offset = tun_ctx->inner_thoff; + break; + default: + WARN_ON_ONCE(1); + goto err; + } + offset += priv->offset; + + if (skb_copy_bits(skb, offset, dest, priv->len) < 0) + goto err; + + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static int nft_payload_inner_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_payload *priv = nft_expr_priv(expr); + u32 base; + + base = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_BASE])); + switch (base) { + case NFT_PAYLOAD_TUN_HEADER: + case NFT_PAYLOAD_LL_HEADER: + case NFT_PAYLOAD_NETWORK_HEADER: + case NFT_PAYLOAD_TRANSPORT_HEADER: + break; + default: + return -EOPNOTSUPP; + } + + priv->base = base; + priv->offset = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_OFFSET])); + priv->len = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_LEN])); + + return nft_parse_register_store(ctx, tb[NFTA_PAYLOAD_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + priv->len); +} + +static const struct nft_expr_ops nft_payload_inner_ops = { + .type = &nft_payload_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_payload)), + .init = nft_payload_inner_init, + .dump = nft_payload_dump, + /* direct call to nft_payload_inner_eval(). */ +}; + static inline void nft_csum_replace(__sum16 *sum, __wsum fsum, __wsum tsum) { *sum = csum_fold(csum_add(csum_sub(~csum_unfold(*sum), fsum), tsum)); @@ -665,6 +788,16 @@ static int nft_payload_csum_inet(struct sk_buff *skb, const u32 *src, return 0; } +struct nft_payload_set { + enum nft_payload_bases base:8; + u8 offset; + u8 len; + u8 sreg; + u8 csum_type; + u8 csum_offset; + u8 csum_flags; +}; + static void nft_payload_set_eval(const struct nft_expr *expr, struct nft_regs *regs, const struct nft_pktinfo *pkt) @@ -787,7 +920,8 @@ static int nft_payload_set_init(const struct nft_ctx *ctx, priv->len); } -static int nft_payload_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_payload_set_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_payload_set *priv = nft_expr_priv(expr); @@ -885,6 +1019,7 @@ nft_payload_select_ops(const struct nft_ctx *ctx, struct nft_expr_type nft_payload_type __read_mostly = { .name = "payload", .select_ops = nft_payload_select_ops, + .inner_ops = &nft_payload_inner_ops, .policy = nft_payload_policy, .maxattr = NFTA_PAYLOAD_MAX, .owner = THIS_MODULE, diff --git a/net/netfilter/nft_queue.c b/net/netfilter/nft_queue.c index da29e92c03e2..b2b8127c8d43 100644 --- a/net/netfilter/nft_queue.c +++ b/net/netfilter/nft_queue.c @@ -152,7 +152,8 @@ static int nft_queue_sreg_init(const struct nft_ctx *ctx, return 0; } -static int nft_queue_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_queue_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_queue *priv = nft_expr_priv(expr); @@ -168,7 +169,8 @@ nla_put_failure: } static int -nft_queue_sreg_dump(struct sk_buff *skb, const struct nft_expr *expr) +nft_queue_sreg_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_queue *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_quota.c b/net/netfilter/nft_quota.c index e6b0df68feea..123578e28917 100644 --- a/net/netfilter/nft_quota.c +++ b/net/netfilter/nft_quota.c @@ -217,11 +217,12 @@ static int nft_quota_init(const struct nft_ctx *ctx, return nft_quota_do_init(tb, priv); } -static int nft_quota_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_quota_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_quota *priv = nft_expr_priv(expr); - return nft_quota_do_dump(skb, priv, false); + return nft_quota_do_dump(skb, priv, reset); } static void nft_quota_destroy(const struct nft_ctx *ctx, diff --git a/net/netfilter/nft_range.c b/net/netfilter/nft_range.c index 832f0d725a9e..0566d6aaf1e5 100644 --- a/net/netfilter/nft_range.c +++ b/net/netfilter/nft_range.c @@ -111,7 +111,8 @@ err1: return err; } -static int nft_range_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_range_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_range_expr *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_redir.c b/net/netfilter/nft_redir.c index 5086adfe731c..5f7739987559 100644 --- a/net/netfilter/nft_redir.c +++ b/net/netfilter/nft_redir.c @@ -75,7 +75,8 @@ static int nft_redir_init(const struct nft_ctx *ctx, return nf_ct_netns_get(ctx->net, ctx->family); } -static int nft_redir_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_redir_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_redir *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_reject.c b/net/netfilter/nft_reject.c index 927ff8459bd9..f2addc844dd2 100644 --- a/net/netfilter/nft_reject.c +++ b/net/netfilter/nft_reject.c @@ -69,7 +69,8 @@ int nft_reject_init(const struct nft_ctx *ctx, } EXPORT_SYMBOL_GPL(nft_reject_init); -int nft_reject_dump(struct sk_buff *skb, const struct nft_expr *expr) +int nft_reject_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { const struct nft_reject *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_rt.c b/net/netfilter/nft_rt.c index 71931ec91721..5990fdd7b3cc 100644 --- a/net/netfilter/nft_rt.c +++ b/net/netfilter/nft_rt.c @@ -146,7 +146,7 @@ static int nft_rt_get_init(const struct nft_ctx *ctx, } static int nft_rt_get_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_rt *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_set_pipapo.c b/net/netfilter/nft_set_pipapo.c index 4f9299b9dcdd..06d46d182634 100644 --- a/net/netfilter/nft_set_pipapo.c +++ b/net/netfilter/nft_set_pipapo.c @@ -1162,6 +1162,7 @@ static int nft_pipapo_insert(const struct net *net, const struct nft_set *set, struct nft_pipapo_match *m = priv->clone; u8 genmask = nft_genmask_next(net); struct nft_pipapo_field *f; + const u8 *start_p, *end_p; int i, bsize_max, err = 0; if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY_END)) @@ -1202,9 +1203,9 @@ static int nft_pipapo_insert(const struct net *net, const struct nft_set *set, } /* Validate */ + start_p = start; + end_p = end; nft_pipapo_for_each_field(f, i, m) { - const u8 *start_p = start, *end_p = end; - if (f->rules >= (unsigned long)NFT_PIPAPO_RULE0_MAX) return -ENOSPC; diff --git a/net/netfilter/nft_set_rbtree.c b/net/netfilter/nft_set_rbtree.c index 7325bee7d144..19ea4d3c3553 100644 --- a/net/netfilter/nft_set_rbtree.c +++ b/net/netfilter/nft_set_rbtree.c @@ -38,10 +38,12 @@ static bool nft_rbtree_interval_start(const struct nft_rbtree_elem *rbe) return !nft_rbtree_interval_end(rbe); } -static bool nft_rbtree_equal(const struct nft_set *set, const void *this, - const struct nft_rbtree_elem *interval) +static int nft_rbtree_cmp(const struct nft_set *set, + const struct nft_rbtree_elem *e1, + const struct nft_rbtree_elem *e2) { - return memcmp(this, nft_set_ext_key(&interval->ext), set->klen) == 0; + return memcmp(nft_set_ext_key(&e1->ext), nft_set_ext_key(&e2->ext), + set->klen); } static bool __nft_rbtree_lookup(const struct net *net, const struct nft_set *set, @@ -52,7 +54,6 @@ static bool __nft_rbtree_lookup(const struct net *net, const struct nft_set *set const struct nft_rbtree_elem *rbe, *interval = NULL; u8 genmask = nft_genmask_cur(net); const struct rb_node *parent; - const void *this; int d; parent = rcu_dereference_raw(priv->root.rb_node); @@ -62,12 +63,11 @@ static bool __nft_rbtree_lookup(const struct net *net, const struct nft_set *set rbe = rb_entry(parent, struct nft_rbtree_elem, node); - this = nft_set_ext_key(&rbe->ext); - d = memcmp(this, key, set->klen); + d = memcmp(nft_set_ext_key(&rbe->ext), key, set->klen); if (d < 0) { parent = rcu_dereference_raw(parent->rb_left); if (interval && - nft_rbtree_equal(set, this, interval) && + !nft_rbtree_cmp(set, rbe, interval) && nft_rbtree_interval_end(rbe) && nft_rbtree_interval_start(interval)) continue; @@ -215,154 +215,216 @@ static void *nft_rbtree_get(const struct net *net, const struct nft_set *set, return rbe; } +static int nft_rbtree_gc_elem(const struct nft_set *__set, + struct nft_rbtree *priv, + struct nft_rbtree_elem *rbe) +{ + struct nft_set *set = (struct nft_set *)__set; + struct rb_node *prev = rb_prev(&rbe->node); + struct nft_rbtree_elem *rbe_prev; + struct nft_set_gc_batch *gcb; + + gcb = nft_set_gc_batch_check(set, NULL, GFP_ATOMIC); + if (!gcb) + return -ENOMEM; + + /* search for expired end interval coming before this element. */ + do { + rbe_prev = rb_entry(prev, struct nft_rbtree_elem, node); + if (nft_rbtree_interval_end(rbe_prev)) + break; + + prev = rb_prev(prev); + } while (prev != NULL); + + rb_erase(&rbe_prev->node, &priv->root); + rb_erase(&rbe->node, &priv->root); + atomic_sub(2, &set->nelems); + + nft_set_gc_batch_add(gcb, rbe); + nft_set_gc_batch_complete(gcb); + + return 0; +} + +static bool nft_rbtree_update_first(const struct nft_set *set, + struct nft_rbtree_elem *rbe, + struct rb_node *first) +{ + struct nft_rbtree_elem *first_elem; + + first_elem = rb_entry(first, struct nft_rbtree_elem, node); + /* this element is closest to where the new element is to be inserted: + * update the first element for the node list path. + */ + if (nft_rbtree_cmp(set, rbe, first_elem) < 0) + return true; + + return false; +} + static int __nft_rbtree_insert(const struct net *net, const struct nft_set *set, struct nft_rbtree_elem *new, struct nft_set_ext **ext) { - bool overlap = false, dup_end_left = false, dup_end_right = false; + struct nft_rbtree_elem *rbe, *rbe_le = NULL, *rbe_ge = NULL; + struct rb_node *node, *parent, **p, *first = NULL; struct nft_rbtree *priv = nft_set_priv(set); u8 genmask = nft_genmask_next(net); - struct nft_rbtree_elem *rbe; - struct rb_node *parent, **p; - int d; + int d, err; - /* Detect overlaps as we descend the tree. Set the flag in these cases: - * - * a1. _ _ __>| ?_ _ __| (insert end before existing end) - * a2. _ _ ___| ?_ _ _>| (insert end after existing end) - * a3. _ _ ___? >|_ _ __| (insert start before existing end) - * - * and clear it later on, as we eventually reach the points indicated by - * '?' above, in the cases described below. We'll always meet these - * later, locally, due to tree ordering, and overlaps for the intervals - * that are the closest together are always evaluated last. - * - * b1. _ _ __>| !_ _ __| (insert end before existing start) - * b2. _ _ ___| !_ _ _>| (insert end after existing start) - * b3. _ _ ___! >|_ _ __| (insert start after existing end, as a leaf) - * '--' no nodes falling in this range - * b4. >|_ _ ! (insert start before existing start) - * - * Case a3. resolves to b3.: - * - if the inserted start element is the leftmost, because the '0' - * element in the tree serves as end element - * - otherwise, if an existing end is found immediately to the left. If - * there are existing nodes in between, we need to further descend the - * tree before we can conclude the new start isn't causing an overlap - * - * or to b4., which, preceded by a3., means we already traversed one or - * more existing intervals entirely, from the right. - * - * For a new, rightmost pair of elements, we'll hit cases b3. and b2., - * in that order. - * - * The flag is also cleared in two special cases: - * - * b5. |__ _ _!|<_ _ _ (insert start right before existing end) - * b6. |__ _ >|!__ _ _ (insert end right after existing start) - * - * which always happen as last step and imply that no further - * overlapping is possible. - * - * Another special case comes from the fact that start elements matching - * an already existing start element are allowed: insertion is not - * performed but we return -EEXIST in that case, and the error will be - * cleared by the caller if NLM_F_EXCL is not present in the request. - * This way, request for insertion of an exact overlap isn't reported as - * error to userspace if not desired. - * - * However, if the existing start matches a pre-existing start, but the - * end element doesn't match the corresponding pre-existing end element, - * we need to report a partial overlap. This is a local condition that - * can be noticed without need for a tracking flag, by checking for a - * local duplicated end for a corresponding start, from left and right, - * separately. + /* Descend the tree to search for an existing element greater than the + * key value to insert that is greater than the new element. This is the + * first element to walk the ordered elements to find possible overlap. */ - parent = NULL; p = &priv->root.rb_node; while (*p != NULL) { parent = *p; rbe = rb_entry(parent, struct nft_rbtree_elem, node); - d = memcmp(nft_set_ext_key(&rbe->ext), - nft_set_ext_key(&new->ext), - set->klen); + d = nft_rbtree_cmp(set, rbe, new); + if (d < 0) { p = &parent->rb_left; - - if (nft_rbtree_interval_start(new)) { - if (nft_rbtree_interval_end(rbe) && - nft_set_elem_active(&rbe->ext, genmask) && - !nft_set_elem_expired(&rbe->ext) && !*p) - overlap = false; - } else { - if (dup_end_left && !*p) - return -ENOTEMPTY; - - overlap = nft_rbtree_interval_end(rbe) && - nft_set_elem_active(&rbe->ext, - genmask) && - !nft_set_elem_expired(&rbe->ext); - - if (overlap) { - dup_end_right = true; - continue; - } - } } else if (d > 0) { - p = &parent->rb_right; + if (!first || + nft_rbtree_update_first(set, rbe, first)) + first = &rbe->node; - if (nft_rbtree_interval_end(new)) { - if (dup_end_right && !*p) - return -ENOTEMPTY; - - overlap = nft_rbtree_interval_end(rbe) && - nft_set_elem_active(&rbe->ext, - genmask) && - !nft_set_elem_expired(&rbe->ext); - - if (overlap) { - dup_end_left = true; - continue; - } - } else if (nft_set_elem_active(&rbe->ext, genmask) && - !nft_set_elem_expired(&rbe->ext)) { - overlap = nft_rbtree_interval_end(rbe); - } + p = &parent->rb_right; } else { - if (nft_rbtree_interval_end(rbe) && - nft_rbtree_interval_start(new)) { + if (nft_rbtree_interval_end(rbe)) p = &parent->rb_left; - - if (nft_set_elem_active(&rbe->ext, genmask) && - !nft_set_elem_expired(&rbe->ext)) - overlap = false; - } else if (nft_rbtree_interval_start(rbe) && - nft_rbtree_interval_end(new)) { + else p = &parent->rb_right; + } + } + + if (!first) + first = rb_first(&priv->root); + + /* Detect overlap by going through the list of valid tree nodes. + * Values stored in the tree are in reversed order, starting from + * highest to lowest value. + */ + for (node = first; node != NULL; node = rb_next(node)) { + rbe = rb_entry(node, struct nft_rbtree_elem, node); - if (nft_set_elem_active(&rbe->ext, genmask) && - !nft_set_elem_expired(&rbe->ext)) - overlap = false; - } else if (nft_set_elem_active(&rbe->ext, genmask) && - !nft_set_elem_expired(&rbe->ext)) { - *ext = &rbe->ext; - return -EEXIST; - } else { - overlap = false; - if (nft_rbtree_interval_end(rbe)) - p = &parent->rb_left; - else - p = &parent->rb_right; + if (!nft_set_elem_active(&rbe->ext, genmask)) + continue; + + /* perform garbage collection to avoid bogus overlap reports. */ + if (nft_set_elem_expired(&rbe->ext)) { + err = nft_rbtree_gc_elem(set, priv, rbe); + if (err < 0) + return err; + + continue; + } + + d = nft_rbtree_cmp(set, rbe, new); + if (d == 0) { + /* Matching end element: no need to look for an + * overlapping greater or equal element. + */ + if (nft_rbtree_interval_end(rbe)) { + rbe_le = rbe; + break; + } + + /* first element that is greater or equal to key value. */ + if (!rbe_ge) { + rbe_ge = rbe; + continue; + } + + /* this is a closer more or equal element, update it. */ + if (nft_rbtree_cmp(set, rbe_ge, new) != 0) { + rbe_ge = rbe; + continue; + } + + /* element is equal to key value, make sure flags are + * the same, an existing more or equal start element + * must not be replaced by more or equal end element. + */ + if ((nft_rbtree_interval_start(new) && + nft_rbtree_interval_start(rbe_ge)) || + (nft_rbtree_interval_end(new) && + nft_rbtree_interval_end(rbe_ge))) { + rbe_ge = rbe; + continue; } + } else if (d > 0) { + /* annotate element greater than the new element. */ + rbe_ge = rbe; + continue; + } else if (d < 0) { + /* annotate element less than the new element. */ + rbe_le = rbe; + break; } + } - dup_end_left = dup_end_right = false; + /* - new start element matching existing start element: full overlap + * reported as -EEXIST, cleared by caller if NLM_F_EXCL is not given. + */ + if (rbe_ge && !nft_rbtree_cmp(set, new, rbe_ge) && + nft_rbtree_interval_start(rbe_ge) == nft_rbtree_interval_start(new)) { + *ext = &rbe_ge->ext; + return -EEXIST; } - if (overlap) + /* - new end element matching existing end element: full overlap + * reported as -EEXIST, cleared by caller if NLM_F_EXCL is not given. + */ + if (rbe_le && !nft_rbtree_cmp(set, new, rbe_le) && + nft_rbtree_interval_end(rbe_le) == nft_rbtree_interval_end(new)) { + *ext = &rbe_le->ext; + return -EEXIST; + } + + /* - new start element with existing closest, less or equal key value + * being a start element: partial overlap, reported as -ENOTEMPTY. + * Anonymous sets allow for two consecutive start element since they + * are constant, skip them to avoid bogus overlap reports. + */ + if (!nft_set_is_anonymous(set) && rbe_le && + nft_rbtree_interval_start(rbe_le) && nft_rbtree_interval_start(new)) + return -ENOTEMPTY; + + /* - new end element with existing closest, less or equal key value + * being a end element: partial overlap, reported as -ENOTEMPTY. + */ + if (rbe_le && + nft_rbtree_interval_end(rbe_le) && nft_rbtree_interval_end(new)) return -ENOTEMPTY; + /* - new end element with existing closest, greater or equal key value + * being an end element: partial overlap, reported as -ENOTEMPTY + */ + if (rbe_ge && + nft_rbtree_interval_end(rbe_ge) && nft_rbtree_interval_end(new)) + return -ENOTEMPTY; + + /* Accepted element: pick insertion point depending on key value */ + parent = NULL; + p = &priv->root.rb_node; + while (*p != NULL) { + parent = *p; + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + d = nft_rbtree_cmp(set, rbe, new); + + if (d < 0) + p = &parent->rb_left; + else if (d > 0) + p = &parent->rb_right; + else if (nft_rbtree_interval_end(rbe)) + p = &parent->rb_left; + else + p = &parent->rb_right; + } + rb_link_node_rcu(&new->node, parent, p); rb_insert_color(&new->node, &priv->root); return 0; @@ -501,23 +563,37 @@ static void nft_rbtree_gc(struct work_struct *work) struct nft_rbtree *priv; struct rb_node *node; struct nft_set *set; + struct net *net; + u8 genmask; priv = container_of(work, struct nft_rbtree, gc_work.work); set = nft_set_container_of(priv); + net = read_pnet(&set->net); + genmask = nft_genmask_cur(net); write_lock_bh(&priv->lock); write_seqcount_begin(&priv->count); for (node = rb_first(&priv->root); node != NULL; node = rb_next(node)) { rbe = rb_entry(node, struct nft_rbtree_elem, node); + if (!nft_set_elem_active(&rbe->ext, genmask)) + continue; + + /* elements are reversed in the rbtree for historical reasons, + * from highest to lowest value, that is why end element is + * always visited before the start element. + */ if (nft_rbtree_interval_end(rbe)) { rbe_end = rbe; continue; } if (!nft_set_elem_expired(&rbe->ext)) continue; - if (nft_set_elem_mark_busy(&rbe->ext)) + + if (nft_set_elem_mark_busy(&rbe->ext)) { + rbe_end = NULL; continue; + } if (rbe_prev) { rb_erase(&rbe_prev->node, &priv->root); diff --git a/net/netfilter/nft_socket.c b/net/netfilter/nft_socket.c index 49a5348a6a14..85f8df87efda 100644 --- a/net/netfilter/nft_socket.c +++ b/net/netfilter/nft_socket.c @@ -199,7 +199,7 @@ static int nft_socket_init(const struct nft_ctx *ctx, } static int nft_socket_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_socket *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_synproxy.c b/net/netfilter/nft_synproxy.c index 6cf9a04fbfe2..13da882669a4 100644 --- a/net/netfilter/nft_synproxy.c +++ b/net/netfilter/nft_synproxy.c @@ -272,7 +272,8 @@ static void nft_synproxy_destroy(const struct nft_ctx *ctx, nft_synproxy_do_destroy(ctx); } -static int nft_synproxy_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_synproxy_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_synproxy *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_tproxy.c b/net/netfilter/nft_tproxy.c index 62da25ad264b..ea83f661417e 100644 --- a/net/netfilter/nft_tproxy.c +++ b/net/netfilter/nft_tproxy.c @@ -294,7 +294,7 @@ static void nft_tproxy_destroy(const struct nft_ctx *ctx, } static int nft_tproxy_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_tproxy *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_tunnel.c b/net/netfilter/nft_tunnel.c index 983ade4be3b3..b059aa541798 100644 --- a/net/netfilter/nft_tunnel.c +++ b/net/netfilter/nft_tunnel.c @@ -108,7 +108,7 @@ static int nft_tunnel_get_init(const struct nft_ctx *ctx, } static int nft_tunnel_get_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_tunnel *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_xfrm.c b/net/netfilter/nft_xfrm.c index 1c5343c936a8..c88fd078a9ae 100644 --- a/net/netfilter/nft_xfrm.c +++ b/net/netfilter/nft_xfrm.c @@ -212,7 +212,7 @@ static void nft_xfrm_get_eval(const struct nft_expr *expr, } static int nft_xfrm_get_dump(struct sk_buff *skb, - const struct nft_expr *expr) + const struct nft_expr *expr, bool reset) { const struct nft_xfrm *priv = nft_expr_priv(expr); diff --git a/net/netfilter/xt_IDLETIMER.c b/net/netfilter/xt_IDLETIMER.c index 0f8bb0bf558f..8d36303f3935 100644 --- a/net/netfilter/xt_IDLETIMER.c +++ b/net/netfilter/xt_IDLETIMER.c @@ -413,7 +413,7 @@ static void idletimer_tg_destroy(const struct xt_tgdtor_param *par) pr_debug("deleting timer %s\n", info->label); list_del(&info->timer->entry); - del_timer_sync(&info->timer->timer); + timer_shutdown_sync(&info->timer->timer); cancel_work_sync(&info->timer->work); sysfs_remove_file(idletimer_tg_kobj, &info->timer->attr.attr); kfree(info->timer->attr.attr.name); @@ -441,7 +441,7 @@ static void idletimer_tg_destroy_v1(const struct xt_tgdtor_param *par) if (info->timer->timer_type & XT_IDLETIMER_ALARM) { alarm_cancel(&info->timer->alarm); } else { - del_timer_sync(&info->timer->timer); + timer_shutdown_sync(&info->timer->timer); } cancel_work_sync(&info->timer->work); sysfs_remove_file(idletimer_tg_kobj, &info->timer->attr.attr); diff --git a/net/netfilter/xt_LED.c b/net/netfilter/xt_LED.c index 0371c387b0d1..66b0f941d8fb 100644 --- a/net/netfilter/xt_LED.c +++ b/net/netfilter/xt_LED.c @@ -166,7 +166,7 @@ static void led_tg_destroy(const struct xt_tgdtor_param *par) list_del(&ledinternal->list); - del_timer_sync(&ledinternal->timer); + timer_shutdown_sync(&ledinternal->timer); led_trigger_unregister(&ledinternal->netfilter_led_trigger); diff --git a/net/netfilter/xt_connmark.c b/net/netfilter/xt_connmark.c index e5ebc0810675..ad3c033db64e 100644 --- a/net/netfilter/xt_connmark.c +++ b/net/netfilter/xt_connmark.c @@ -30,6 +30,7 @@ connmark_tg_shift(struct sk_buff *skb, const struct xt_connmark_tginfo2 *info) u_int32_t new_targetmark; struct nf_conn *ct; u_int32_t newmark; + u_int32_t oldmark; ct = nf_ct_get(skb, &ctinfo); if (ct == NULL) @@ -37,14 +38,15 @@ connmark_tg_shift(struct sk_buff *skb, const struct xt_connmark_tginfo2 *info) switch (info->mode) { case XT_CONNMARK_SET: - newmark = (ct->mark & ~info->ctmask) ^ info->ctmark; + oldmark = READ_ONCE(ct->mark); + newmark = (oldmark & ~info->ctmask) ^ info->ctmark; if (info->shift_dir == D_SHIFT_RIGHT) newmark >>= info->shift_bits; else newmark <<= info->shift_bits; - if (ct->mark != newmark) { - ct->mark = newmark; + if (READ_ONCE(ct->mark) != newmark) { + WRITE_ONCE(ct->mark, newmark); nf_conntrack_event_cache(IPCT_MARK, ct); } break; @@ -55,15 +57,15 @@ connmark_tg_shift(struct sk_buff *skb, const struct xt_connmark_tginfo2 *info) else new_targetmark <<= info->shift_bits; - newmark = (ct->mark & ~info->ctmask) ^ + newmark = (READ_ONCE(ct->mark) & ~info->ctmask) ^ new_targetmark; - if (ct->mark != newmark) { - ct->mark = newmark; + if (READ_ONCE(ct->mark) != newmark) { + WRITE_ONCE(ct->mark, newmark); nf_conntrack_event_cache(IPCT_MARK, ct); } break; case XT_CONNMARK_RESTORE: - new_targetmark = (ct->mark & info->ctmask); + new_targetmark = (READ_ONCE(ct->mark) & info->ctmask); if (info->shift_dir == D_SHIFT_RIGHT) new_targetmark >>= info->shift_bits; else @@ -126,7 +128,7 @@ connmark_mt(const struct sk_buff *skb, struct xt_action_param *par) if (ct == NULL) return false; - return ((ct->mark & info->mask) == info->mark) ^ info->invert; + return ((READ_ONCE(ct->mark) & info->mask) == info->mark) ^ info->invert; } static int connmark_mt_check(const struct xt_mtchk_param *par) diff --git a/net/netfilter/xt_sctp.c b/net/netfilter/xt_sctp.c index 680015ba7cb6..e8961094a282 100644 --- a/net/netfilter/xt_sctp.c +++ b/net/netfilter/xt_sctp.c @@ -4,7 +4,6 @@ #include <linux/skbuff.h> #include <net/ip.h> #include <net/ipv6.h> -#include <net/sctp/sctp.h> #include <linux/sctp.h> #include <linux/netfilter/x_tables.h> diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index a662e8a5ff84..c64277659753 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -580,7 +580,9 @@ static int netlink_insert(struct sock *sk, u32 portid) if (nlk_sk(sk)->bound) goto err; - nlk_sk(sk)->portid = portid; + /* portid can be read locklessly from netlink_getname(). */ + WRITE_ONCE(nlk_sk(sk)->portid, portid); + sock_hold(sk); err = __netlink_insert(table, sk); @@ -812,6 +814,17 @@ static int netlink_release(struct socket *sock) } sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); + + /* Because struct net might disappear soon, do not keep a pointer. */ + if (!sk->sk_net_refcnt && sock_net(sk) != &init_net) { + __netns_tracker_free(sock_net(sk), &sk->ns_tracker, false); + /* Because of deferred_put_nlk_sk and use of work queue, + * it is possible netns will be freed before this socket. + */ + sock_net_set(sk, &init_net); + __netns_tracker_alloc(&init_net, &sk->ns_tracker, + false, GFP_KERNEL); + } call_rcu(&nlk->rcu, deferred_put_nlk_sk); return 0; } @@ -835,7 +848,7 @@ retry: /* Bind collision, search negative portid values. */ if (rover == -4096) /* rover will be in range [S32_MIN, -4097] */ - rover = S32_MIN + prandom_u32_max(-4096 - S32_MIN); + rover = S32_MIN + get_random_u32_below(-4096 - S32_MIN); else if (rover >= -4096) rover = -4097; portid = rover--; @@ -1085,9 +1098,11 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr, return -EINVAL; if (addr->sa_family == AF_UNSPEC) { - sk->sk_state = NETLINK_UNCONNECTED; - nlk->dst_portid = 0; - nlk->dst_group = 0; + /* paired with READ_ONCE() in netlink_getsockbyportid() */ + WRITE_ONCE(sk->sk_state, NETLINK_UNCONNECTED); + /* dst_portid and dst_group can be read locklessly */ + WRITE_ONCE(nlk->dst_portid, 0); + WRITE_ONCE(nlk->dst_group, 0); return 0; } if (addr->sa_family != AF_NETLINK) @@ -1108,9 +1123,11 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr, err = netlink_autobind(sock); if (err == 0) { - sk->sk_state = NETLINK_CONNECTED; - nlk->dst_portid = nladdr->nl_pid; - nlk->dst_group = ffs(nladdr->nl_groups); + /* paired with READ_ONCE() in netlink_getsockbyportid() */ + WRITE_ONCE(sk->sk_state, NETLINK_CONNECTED); + /* dst_portid and dst_group can be read locklessly */ + WRITE_ONCE(nlk->dst_portid, nladdr->nl_pid); + WRITE_ONCE(nlk->dst_group, ffs(nladdr->nl_groups)); } return err; @@ -1127,10 +1144,12 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, nladdr->nl_pad = 0; if (peer) { - nladdr->nl_pid = nlk->dst_portid; - nladdr->nl_groups = netlink_group_mask(nlk->dst_group); + /* Paired with WRITE_ONCE() in netlink_connect() */ + nladdr->nl_pid = READ_ONCE(nlk->dst_portid); + nladdr->nl_groups = netlink_group_mask(READ_ONCE(nlk->dst_group)); } else { - nladdr->nl_pid = nlk->portid; + /* Paired with WRITE_ONCE() in netlink_insert() */ + nladdr->nl_pid = READ_ONCE(nlk->portid); netlink_lock_table(); nladdr->nl_groups = nlk->groups ? nlk->groups[0] : 0; netlink_unlock_table(); @@ -1157,8 +1176,9 @@ static struct sock *netlink_getsockbyportid(struct sock *ssk, u32 portid) /* Don't bother queuing skb if kernel socket has no input function */ nlk = nlk_sk(sock); - if (sock->sk_state == NETLINK_CONNECTED && - nlk->dst_portid != nlk_sk(ssk)->portid) { + /* dst_portid and sk_state can be changed in netlink_connect() */ + if (READ_ONCE(sock->sk_state) == NETLINK_CONNECTED && + READ_ONCE(nlk->dst_portid) != nlk_sk(ssk)->portid) { sock_put(sock); return ERR_PTR(-ECONNREFUSED); } @@ -1875,8 +1895,9 @@ static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) goto out; netlink_skb_flags |= NETLINK_SKB_DST; } else { - dst_portid = nlk->dst_portid; - dst_group = nlk->dst_group; + /* Paired with WRITE_ONCE() in netlink_connect() */ + dst_portid = READ_ONCE(nlk->dst_portid); + dst_group = READ_ONCE(nlk->dst_group); } /* Paired with WRITE_ONCE() in netlink_insert() */ @@ -2488,19 +2509,24 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err, flags |= NLM_F_ACK_TLVS; skb = nlmsg_new(payload + tlvlen, GFP_KERNEL); - if (!skb) { - NETLINK_CB(in_skb).sk->sk_err = ENOBUFS; - sk_error_report(NETLINK_CB(in_skb).sk); - return; - } + if (!skb) + goto err_skb; rep = nlmsg_put(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, - NLMSG_ERROR, payload, flags); + NLMSG_ERROR, sizeof(*errmsg), flags); + if (!rep) + goto err_bad_put; errmsg = nlmsg_data(rep); errmsg->error = err; - unsafe_memcpy(&errmsg->msg, nlh, payload > sizeof(*errmsg) - ? nlh->nlmsg_len : sizeof(*nlh), - /* Bounds checked by the skb layer. */); + errmsg->msg = *nlh; + + if (!(flags & NLM_F_CAPPED)) { + if (!nlmsg_append(skb, nlmsg_len(nlh))) + goto err_bad_put; + + memcpy(nlmsg_data(&errmsg->msg), nlmsg_data(nlh), + nlmsg_len(nlh)); + } if (tlvlen) netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack); @@ -2508,6 +2534,14 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err, nlmsg_end(skb, rep); nlmsg_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).portid); + + return; + +err_bad_put: + nlmsg_free(skb); +err_skb: + NETLINK_CB(in_skb).sk->sk_err = ENOBUFS; + sk_error_report(NETLINK_CB(in_skb).sk); } EXPORT_SYMBOL(netlink_ack); diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c index 3e16527beb91..600993c80050 100644 --- a/net/netlink/genetlink.c +++ b/net/netlink/genetlink.c @@ -101,6 +101,17 @@ genl_op_fill_in_reject_policy(const struct genl_family *family, op->maxattr = 1; } +static void +genl_op_fill_in_reject_policy_split(const struct genl_family *family, + struct genl_split_ops *op) +{ + if (op->policy) + return; + + op->policy = genl_policy_reject_all; + op->maxattr = 1; +} + static const struct genl_family *genl_family_find_byid(unsigned int id) { return idr_find(&genl_fam_idr, id); @@ -118,10 +129,15 @@ static const struct genl_family *genl_family_find_byname(char *name) return NULL; } -static int genl_get_cmd_cnt(const struct genl_family *family) -{ - return family->n_ops + family->n_small_ops; -} +struct genl_op_iter { + const struct genl_family *family; + struct genl_split_ops doit; + struct genl_split_ops dumpit; + int cmd_idx; + int entry_idx; + u32 cmd; + u8 flags; +}; static void genl_op_from_full(const struct genl_family *family, unsigned int i, struct genl_ops *op) @@ -181,24 +197,187 @@ static int genl_get_cmd_small(u32 cmd, const struct genl_family *family, return -ENOENT; } -static int genl_get_cmd(u32 cmd, const struct genl_family *family, - struct genl_ops *op) +static void genl_op_from_split(struct genl_op_iter *iter) { - if (!genl_get_cmd_full(cmd, family, op)) - return 0; - return genl_get_cmd_small(cmd, family, op); + const struct genl_family *family = iter->family; + int i, cnt = 0; + + i = iter->entry_idx - family->n_ops - family->n_small_ops; + + if (family->split_ops[i + cnt].flags & GENL_CMD_CAP_DO) { + iter->doit = family->split_ops[i + cnt]; + genl_op_fill_in_reject_policy_split(family, &iter->doit); + cnt++; + } else { + memset(&iter->doit, 0, sizeof(iter->doit)); + } + + if (i + cnt < family->n_split_ops && + family->split_ops[i + cnt].flags & GENL_CMD_CAP_DUMP) { + iter->dumpit = family->split_ops[i + cnt]; + genl_op_fill_in_reject_policy_split(family, &iter->dumpit); + cnt++; + } else { + memset(&iter->dumpit, 0, sizeof(iter->dumpit)); + } + + WARN_ON(!cnt); + iter->entry_idx += cnt; } -static void genl_get_cmd_by_index(unsigned int i, - const struct genl_family *family, - struct genl_ops *op) +static int +genl_get_cmd_split(u32 cmd, u8 flag, const struct genl_family *family, + struct genl_split_ops *op) { - if (i < family->n_ops) - genl_op_from_full(family, i, op); - else if (i < family->n_ops + family->n_small_ops) - genl_op_from_small(family, i - family->n_ops, op); - else - WARN_ON_ONCE(1); + int i; + + for (i = 0; i < family->n_split_ops; i++) + if (family->split_ops[i].cmd == cmd && + family->split_ops[i].flags & flag) { + *op = family->split_ops[i]; + return 0; + } + + return -ENOENT; +} + +static int +genl_cmd_full_to_split(struct genl_split_ops *op, + const struct genl_family *family, + const struct genl_ops *full, u8 flags) +{ + if ((flags & GENL_CMD_CAP_DO && !full->doit) || + (flags & GENL_CMD_CAP_DUMP && !full->dumpit)) { + memset(op, 0, sizeof(*op)); + return -ENOENT; + } + + if (flags & GENL_CMD_CAP_DUMP) { + op->start = full->start; + op->dumpit = full->dumpit; + op->done = full->done; + } else { + op->pre_doit = family->pre_doit; + op->doit = full->doit; + op->post_doit = family->post_doit; + } + + if (flags & GENL_CMD_CAP_DUMP && + full->validate & GENL_DONT_VALIDATE_DUMP) { + op->policy = NULL; + op->maxattr = 0; + } else { + op->policy = full->policy; + op->maxattr = full->maxattr; + } + + op->cmd = full->cmd; + op->internal_flags = full->internal_flags; + op->flags = full->flags; + op->validate = full->validate; + + /* Make sure flags include the GENL_CMD_CAP_DO / GENL_CMD_CAP_DUMP */ + op->flags |= flags; + + return 0; +} + +/* Must make sure that op is initialized to 0 on failure */ +static int +genl_get_cmd(u32 cmd, u8 flags, const struct genl_family *family, + struct genl_split_ops *op) +{ + struct genl_ops full; + int err; + + err = genl_get_cmd_full(cmd, family, &full); + if (err == -ENOENT) + err = genl_get_cmd_small(cmd, family, &full); + /* Found one of legacy forms */ + if (err == 0) + return genl_cmd_full_to_split(op, family, &full, flags); + + err = genl_get_cmd_split(cmd, flags, family, op); + if (err) + memset(op, 0, sizeof(*op)); + return err; +} + +/* For policy dumping only, get ops of both do and dump. + * Fail if both are missing, genl_get_cmd() will zero-init in case of failure. + */ +static int +genl_get_cmd_both(u32 cmd, const struct genl_family *family, + struct genl_split_ops *doit, struct genl_split_ops *dumpit) +{ + int err1, err2; + + err1 = genl_get_cmd(cmd, GENL_CMD_CAP_DO, family, doit); + err2 = genl_get_cmd(cmd, GENL_CMD_CAP_DUMP, family, dumpit); + + return err1 && err2 ? -ENOENT : 0; +} + +static bool +genl_op_iter_init(const struct genl_family *family, struct genl_op_iter *iter) +{ + iter->family = family; + iter->cmd_idx = 0; + iter->entry_idx = 0; + + iter->flags = 0; + + return iter->family->n_ops + + iter->family->n_small_ops + + iter->family->n_split_ops; +} + +static bool genl_op_iter_next(struct genl_op_iter *iter) +{ + const struct genl_family *family = iter->family; + bool legacy_op = true; + struct genl_ops op; + + if (iter->entry_idx < family->n_ops) { + genl_op_from_full(family, iter->entry_idx, &op); + } else if (iter->entry_idx < family->n_ops + family->n_small_ops) { + genl_op_from_small(family, iter->entry_idx - family->n_ops, + &op); + } else if (iter->entry_idx < + family->n_ops + family->n_small_ops + family->n_split_ops) { + legacy_op = false; + /* updates entry_idx */ + genl_op_from_split(iter); + } else { + return false; + } + + iter->cmd_idx++; + + if (legacy_op) { + iter->entry_idx++; + + genl_cmd_full_to_split(&iter->doit, family, + &op, GENL_CMD_CAP_DO); + genl_cmd_full_to_split(&iter->dumpit, family, + &op, GENL_CMD_CAP_DUMP); + } + + iter->cmd = iter->doit.cmd | iter->dumpit.cmd; + iter->flags = iter->doit.flags | iter->dumpit.flags; + + return true; +} + +static void +genl_op_iter_copy(struct genl_op_iter *dst, struct genl_op_iter *src) +{ + *dst = *src; +} + +static unsigned int genl_op_iter_idx(struct genl_op_iter *iter) +{ + return iter->cmd_idx; } static int genl_allocate_reserve_groups(int n_groups, int *first_id) @@ -366,31 +545,72 @@ static void genl_unregister_mc_groups(const struct genl_family *family) } } +static bool genl_split_op_check(const struct genl_split_ops *op) +{ + if (WARN_ON(hweight8(op->flags & (GENL_CMD_CAP_DO | + GENL_CMD_CAP_DUMP)) != 1)) + return true; + return false; +} + static int genl_validate_ops(const struct genl_family *family) { - int i, j; + struct genl_op_iter i, j; + unsigned int s; if (WARN_ON(family->n_ops && !family->ops) || - WARN_ON(family->n_small_ops && !family->small_ops)) + WARN_ON(family->n_small_ops && !family->small_ops) || + WARN_ON(family->n_split_ops && !family->split_ops)) return -EINVAL; - for (i = 0; i < genl_get_cmd_cnt(family); i++) { - struct genl_ops op; - - genl_get_cmd_by_index(i, family, &op); - if (op.dumpit == NULL && op.doit == NULL) + for (genl_op_iter_init(family, &i); genl_op_iter_next(&i); ) { + if (!(i.flags & (GENL_CMD_CAP_DO | GENL_CMD_CAP_DUMP))) return -EINVAL; - if (WARN_ON(op.cmd >= family->resv_start_op && op.validate)) + + if (WARN_ON(i.cmd >= family->resv_start_op && + (i.doit.validate || i.dumpit.validate))) return -EINVAL; - for (j = i + 1; j < genl_get_cmd_cnt(family); j++) { - struct genl_ops op2; - genl_get_cmd_by_index(j, family, &op2); - if (op.cmd == op2.cmd) + genl_op_iter_copy(&j, &i); + while (genl_op_iter_next(&j)) { + if (i.cmd == j.cmd) return -EINVAL; } } + if (family->n_split_ops) { + if (genl_split_op_check(&family->split_ops[0])) + return -EINVAL; + } + + for (s = 1; s < family->n_split_ops; s++) { + const struct genl_split_ops *a, *b; + + a = &family->split_ops[s - 1]; + b = &family->split_ops[s]; + + if (genl_split_op_check(b)) + return -EINVAL; + + /* Check sort order */ + if (a->cmd < b->cmd) + continue; + + if (a->internal_flags != b->internal_flags || + ((a->flags ^ b->flags) & ~(GENL_CMD_CAP_DO | + GENL_CMD_CAP_DUMP))) { + WARN_ON(1); + return -EINVAL; + } + + if ((a->flags & GENL_CMD_CAP_DO) && + (b->flags & GENL_CMD_CAP_DUMP)) + continue; + + WARN_ON(1); + return -EINVAL; + } + return 0; } @@ -544,7 +764,7 @@ static struct nlattr ** genl_family_rcv_msg_attrs_parse(const struct genl_family *family, struct nlmsghdr *nlh, struct netlink_ext_ack *extack, - const struct genl_ops *ops, + const struct genl_split_ops *ops, int hdrlen, enum genl_validate_flags no_strict_flag) { @@ -580,22 +800,21 @@ struct genl_start_context { const struct genl_family *family; struct nlmsghdr *nlh; struct netlink_ext_ack *extack; - const struct genl_ops *ops; + const struct genl_split_ops *ops; int hdrlen; }; static int genl_start(struct netlink_callback *cb) { struct genl_start_context *ctx = cb->data; - const struct genl_ops *ops = ctx->ops; + const struct genl_split_ops *ops; struct genl_dumpit_info *info; struct nlattr **attrs = NULL; int rc = 0; - if (ops->validate & GENL_DONT_VALIDATE_DUMP) - goto no_attrs; - - if (ctx->nlh->nlmsg_len < nlmsg_msg_size(ctx->hdrlen)) + ops = ctx->ops; + if (!(ops->validate & GENL_DONT_VALIDATE_DUMP) && + ctx->nlh->nlmsg_len < nlmsg_msg_size(ctx->hdrlen)) return -EINVAL; attrs = genl_family_rcv_msg_attrs_parse(ctx->family, ctx->nlh, ctx->extack, @@ -604,7 +823,6 @@ static int genl_start(struct netlink_callback *cb) if (IS_ERR(attrs)) return PTR_ERR(attrs); -no_attrs: info = genl_dumpit_info_alloc(); if (!info) { genl_family_rcv_msg_attrs_free(attrs); @@ -633,7 +851,7 @@ no_attrs: static int genl_lock_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { - const struct genl_ops *ops = &genl_dumpit_info(cb)->op; + const struct genl_split_ops *ops = &genl_dumpit_info(cb)->op; int rc; genl_lock(); @@ -645,7 +863,7 @@ static int genl_lock_dumpit(struct sk_buff *skb, struct netlink_callback *cb) static int genl_lock_done(struct netlink_callback *cb) { const struct genl_dumpit_info *info = genl_dumpit_info(cb); - const struct genl_ops *ops = &info->op; + const struct genl_split_ops *ops = &info->op; int rc = 0; if (ops->done) { @@ -661,7 +879,7 @@ static int genl_lock_done(struct netlink_callback *cb) static int genl_parallel_done(struct netlink_callback *cb) { const struct genl_dumpit_info *info = genl_dumpit_info(cb); - const struct genl_ops *ops = &info->op; + const struct genl_split_ops *ops = &info->op; int rc = 0; if (ops->done) @@ -675,15 +893,12 @@ static int genl_family_rcv_msg_dumpit(const struct genl_family *family, struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack, - const struct genl_ops *ops, + const struct genl_split_ops *ops, int hdrlen, struct net *net) { struct genl_start_context ctx; int err; - if (!ops->dumpit) - return -EOPNOTSUPP; - ctx.family = family; ctx.nlh = nlh; ctx.extack = extack; @@ -721,16 +936,13 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family, struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack, - const struct genl_ops *ops, + const struct genl_split_ops *ops, int hdrlen, struct net *net) { struct nlattr **attrbuf; struct genl_info info; int err; - if (!ops->doit) - return -EOPNOTSUPP; - attrbuf = genl_family_rcv_msg_attrs_parse(family, nlh, extack, ops, hdrlen, GENL_DONT_VALIDATE_STRICT); @@ -747,16 +959,16 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family, genl_info_net_set(&info, net); memset(&info.user_ptr, 0, sizeof(info.user_ptr)); - if (family->pre_doit) { - err = family->pre_doit(ops, skb, &info); + if (ops->pre_doit) { + err = ops->pre_doit(ops, skb, &info); if (err) goto out; } err = ops->doit(skb, &info); - if (family->post_doit) - family->post_doit(ops, skb, &info); + if (ops->post_doit) + ops->post_doit(ops, skb, &info); out: genl_family_rcv_msg_attrs_free(attrbuf); @@ -801,8 +1013,9 @@ static int genl_family_rcv_msg(const struct genl_family *family, { struct net *net = sock_net(skb->sk); struct genlmsghdr *hdr = nlmsg_data(nlh); - struct genl_ops op; + struct genl_split_ops op; int hdrlen; + u8 flags; /* this family doesn't exist in this netns */ if (!family->netnsok && !net_eq(net, &init_net)) @@ -815,7 +1028,9 @@ static int genl_family_rcv_msg(const struct genl_family *family, if (genl_header_check(family, nlh, hdr, extack)) return -EINVAL; - if (genl_get_cmd(hdr->cmd, family, &op)) + flags = (nlh->nlmsg_flags & NLM_F_DUMP) == NLM_F_DUMP ? + GENL_CMD_CAP_DUMP : GENL_CMD_CAP_DO; + if (genl_get_cmd(hdr->cmd, flags, family, &op)) return -EOPNOTSUPP; if ((op.flags & GENL_ADMIN_PERM) && @@ -826,7 +1041,7 @@ static int genl_family_rcv_msg(const struct genl_family *family, !netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; - if ((nlh->nlmsg_flags & NLM_F_DUMP) == NLM_F_DUMP) + if (flags & GENL_CMD_CAP_DUMP) return genl_family_rcv_msg_dumpit(family, skb, nlh, extack, &op, hdrlen, net); else @@ -871,6 +1086,7 @@ static struct genl_family genl_ctrl; static int ctrl_fill_info(const struct genl_family *family, u32 portid, u32 seq, u32 flags, struct sk_buff *skb, u8 cmd) { + struct genl_op_iter i; void *hdr; hdr = genlmsg_put(skb, portid, seq, &genl_ctrl, flags, cmd); @@ -884,33 +1100,26 @@ static int ctrl_fill_info(const struct genl_family *family, u32 portid, u32 seq, nla_put_u32(skb, CTRL_ATTR_MAXATTR, family->maxattr)) goto nla_put_failure; - if (genl_get_cmd_cnt(family)) { + if (genl_op_iter_init(family, &i)) { struct nlattr *nla_ops; - int i; nla_ops = nla_nest_start_noflag(skb, CTRL_ATTR_OPS); if (nla_ops == NULL) goto nla_put_failure; - for (i = 0; i < genl_get_cmd_cnt(family); i++) { + while (genl_op_iter_next(&i)) { struct nlattr *nest; - struct genl_ops op; u32 op_flags; - genl_get_cmd_by_index(i, family, &op); - op_flags = op.flags; - if (op.dumpit) - op_flags |= GENL_CMD_CAP_DUMP; - if (op.doit) - op_flags |= GENL_CMD_CAP_DO; - if (op.policy) + op_flags = i.flags; + if (i.doit.policy || i.dumpit.policy) op_flags |= GENL_CMD_CAP_HASPOL; - nest = nla_nest_start_noflag(skb, i + 1); + nest = nla_nest_start_noflag(skb, genl_op_iter_idx(&i)); if (nest == NULL) goto nla_put_failure; - if (nla_put_u32(skb, CTRL_ATTR_OP_ID, op.cmd) || + if (nla_put_u32(skb, CTRL_ATTR_OP_ID, i.cmd) || nla_put_u32(skb, CTRL_ATTR_OP_FLAGS, op_flags)) goto nla_put_failure; @@ -1163,10 +1372,10 @@ static int genl_ctrl_event(int event, const struct genl_family *family, struct ctrl_dump_policy_ctx { struct netlink_policy_dump_state *state; const struct genl_family *rt; - unsigned int opidx; + struct genl_op_iter *op_iter; u32 op; u16 fam_id; - u8 policies:1, + u8 dump_map:1, single_op:1; }; @@ -1183,8 +1392,8 @@ static int ctrl_dumppolicy_start(struct netlink_callback *cb) struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; struct nlattr **tb = info->attrs; const struct genl_family *rt; - struct genl_ops op; - int err, i; + struct genl_op_iter i; + int err; BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); @@ -1208,40 +1417,73 @@ static int ctrl_dumppolicy_start(struct netlink_callback *cb) ctx->rt = rt; if (tb[CTRL_ATTR_OP]) { + struct genl_split_ops doit, dump; + ctx->single_op = true; ctx->op = nla_get_u32(tb[CTRL_ATTR_OP]); - err = genl_get_cmd(ctx->op, rt, &op); + err = genl_get_cmd_both(ctx->op, rt, &doit, &dump); if (err) { NL_SET_BAD_ATTR(cb->extack, tb[CTRL_ATTR_OP]); return err; } - if (!op.policy) + if (doit.policy) { + err = netlink_policy_dump_add_policy(&ctx->state, + doit.policy, + doit.maxattr); + if (err) + goto err_free_state; + } + if (dump.policy) { + err = netlink_policy_dump_add_policy(&ctx->state, + dump.policy, + dump.maxattr); + if (err) + goto err_free_state; + } + + if (!ctx->state) return -ENODATA; - return netlink_policy_dump_add_policy(&ctx->state, op.policy, - op.maxattr); + ctx->dump_map = 1; + return 0; } - for (i = 0; i < genl_get_cmd_cnt(rt); i++) { - genl_get_cmd_by_index(i, rt, &op); + ctx->op_iter = kmalloc(sizeof(*ctx->op_iter), GFP_KERNEL); + if (!ctx->op_iter) + return -ENOMEM; - if (op.policy) { + genl_op_iter_init(rt, ctx->op_iter); + ctx->dump_map = genl_op_iter_next(ctx->op_iter); + + for (genl_op_iter_init(rt, &i); genl_op_iter_next(&i); ) { + if (i.doit.policy) { err = netlink_policy_dump_add_policy(&ctx->state, - op.policy, - op.maxattr); + i.doit.policy, + i.doit.maxattr); + if (err) + goto err_free_state; + } + if (i.dumpit.policy) { + err = netlink_policy_dump_add_policy(&ctx->state, + i.dumpit.policy, + i.dumpit.maxattr); if (err) goto err_free_state; } } - if (!ctx->state) - return -ENODATA; + if (!ctx->state) { + err = -ENODATA; + goto err_free_op_iter; + } return 0; err_free_state: netlink_policy_dump_free(ctx->state); +err_free_op_iter: + kfree(ctx->op_iter); return err; } @@ -1265,7 +1507,8 @@ static void *ctrl_dumppolicy_prep(struct sk_buff *skb, static int ctrl_dumppolicy_put_op(struct sk_buff *skb, struct netlink_callback *cb, - struct genl_ops *op) + struct genl_split_ops *doit, + struct genl_split_ops *dumpit) { struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; struct nlattr *nest_pol, *nest_op; @@ -1273,10 +1516,7 @@ static int ctrl_dumppolicy_put_op(struct sk_buff *skb, int idx; /* skip if we have nothing to show */ - if (!op->policy) - return 0; - if (!op->doit && - (!op->dumpit || op->validate & GENL_DONT_VALIDATE_DUMP)) + if (!doit->policy && !dumpit->policy) return 0; hdr = ctrl_dumppolicy_prep(skb, cb); @@ -1287,21 +1527,26 @@ static int ctrl_dumppolicy_put_op(struct sk_buff *skb, if (!nest_pol) goto err; - nest_op = nla_nest_start(skb, op->cmd); + nest_op = nla_nest_start(skb, doit->cmd); if (!nest_op) goto err; - /* for now both do/dump are always the same */ - idx = netlink_policy_dump_get_policy_idx(ctx->state, - op->policy, - op->maxattr); + if (doit->policy) { + idx = netlink_policy_dump_get_policy_idx(ctx->state, + doit->policy, + doit->maxattr); - if (op->doit && nla_put_u32(skb, CTRL_ATTR_POLICY_DO, idx)) - goto err; + if (nla_put_u32(skb, CTRL_ATTR_POLICY_DO, idx)) + goto err; + } + if (dumpit->policy) { + idx = netlink_policy_dump_get_policy_idx(ctx->state, + dumpit->policy, + dumpit->maxattr); - if (op->dumpit && !(op->validate & GENL_DONT_VALIDATE_DUMP) && - nla_put_u32(skb, CTRL_ATTR_POLICY_DUMP, idx)) - goto err; + if (nla_put_u32(skb, CTRL_ATTR_POLICY_DUMP, idx)) + goto err; + } nla_nest_end(skb, nest_op); nla_nest_end(skb, nest_pol); @@ -1318,31 +1563,29 @@ static int ctrl_dumppolicy(struct sk_buff *skb, struct netlink_callback *cb) struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; void *hdr; - if (!ctx->policies) { - while (ctx->opidx < genl_get_cmd_cnt(ctx->rt)) { - struct genl_ops op; + if (ctx->dump_map) { + if (ctx->single_op) { + struct genl_split_ops doit, dumpit; - if (ctx->single_op) { - int err; + if (WARN_ON(genl_get_cmd_both(ctx->op, ctx->rt, + &doit, &dumpit))) + return -ENOENT; - err = genl_get_cmd(ctx->op, ctx->rt, &op); - if (WARN_ON(err)) - return skb->len; + if (ctrl_dumppolicy_put_op(skb, cb, &doit, &dumpit)) + return skb->len; - /* break out of the loop after this one */ - ctx->opidx = genl_get_cmd_cnt(ctx->rt); - } else { - genl_get_cmd_by_index(ctx->opidx, ctx->rt, &op); - } + /* done with the per-op policy index list */ + ctx->dump_map = 0; + } - if (ctrl_dumppolicy_put_op(skb, cb, &op)) + while (ctx->dump_map) { + if (ctrl_dumppolicy_put_op(skb, cb, + &ctx->op_iter->doit, + &ctx->op_iter->dumpit)) return skb->len; - ctx->opidx++; + ctx->dump_map = genl_op_iter_next(ctx->op_iter); } - - /* completed with the per-op policy index list */ - ctx->policies = true; } while (netlink_policy_dump_loop(ctx->state)) { @@ -1375,18 +1618,27 @@ static int ctrl_dumppolicy_done(struct netlink_callback *cb) { struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + kfree(ctx->op_iter); netlink_policy_dump_free(ctx->state); return 0; } -static const struct genl_ops genl_ctrl_ops[] = { +static const struct genl_split_ops genl_ctrl_ops[] = { { .cmd = CTRL_CMD_GETFAMILY, - .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .validate = GENL_DONT_VALIDATE_STRICT, .policy = ctrl_policy_family, .maxattr = ARRAY_SIZE(ctrl_policy_family) - 1, .doit = ctrl_getfamily, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = CTRL_CMD_GETFAMILY, + .validate = GENL_DONT_VALIDATE_DUMP, + .policy = ctrl_policy_family, + .maxattr = ARRAY_SIZE(ctrl_policy_family) - 1, .dumpit = ctrl_dumpfamily, + .flags = GENL_CMD_CAP_DUMP, }, { .cmd = CTRL_CMD_GETPOLICY, @@ -1395,6 +1647,7 @@ static const struct genl_ops genl_ctrl_ops[] = { .start = ctrl_dumppolicy_start, .dumpit = ctrl_dumppolicy, .done = ctrl_dumppolicy_done, + .flags = GENL_CMD_CAP_DUMP, }, }; @@ -1404,8 +1657,8 @@ static const struct genl_multicast_group genl_ctrl_groups[] = { static struct genl_family genl_ctrl __ro_after_init = { .module = THIS_MODULE, - .ops = genl_ctrl_ops, - .n_ops = ARRAY_SIZE(genl_ctrl_ops), + .split_ops = genl_ctrl_ops, + .n_split_ops = ARRAY_SIZE(genl_ctrl_ops), .resv_start_op = CTRL_CMD_GETPOLICY + 1, .mcgrps = genl_ctrl_groups, .n_mcgrps = ARRAY_SIZE(genl_ctrl_groups), diff --git a/net/netrom/af_netrom.c b/net/netrom/af_netrom.c index 6f7f4392cffb..5a4cb796150f 100644 --- a/net/netrom/af_netrom.c +++ b/net/netrom/af_netrom.c @@ -400,6 +400,11 @@ static int nr_listen(struct socket *sock, int backlog) struct sock *sk = sock->sk; lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + release_sock(sk); + return -EINVAL; + } + if (sk->sk_state != TCP_LISTEN) { memset(&nr_sk(sk)->user_addr, 0, AX25_ADDR_LEN); sk->sk_max_ack_backlog = backlog; diff --git a/net/netrom/nr_timer.c b/net/netrom/nr_timer.c index a8da88db7893..4e7c968cde2d 100644 --- a/net/netrom/nr_timer.c +++ b/net/netrom/nr_timer.c @@ -121,6 +121,7 @@ static void nr_heartbeat_expiry(struct timer_list *t) is accepted() it isn't 'dead' so doesn't get removed. */ if (sock_flag(sk, SOCK_DESTROY) || (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_DEAD))) { + sock_hold(sk); bh_unlock_sock(sk); nr_destroy_socket(sk); goto out; diff --git a/net/nfc/llcp_core.c b/net/nfc/llcp_core.c index 3364caabef8b..a27e1842b2a0 100644 --- a/net/nfc/llcp_core.c +++ b/net/nfc/llcp_core.c @@ -157,6 +157,7 @@ static void local_cleanup(struct nfc_llcp_local *local) cancel_work_sync(&local->rx_work); cancel_work_sync(&local->timeout_work); kfree_skb(local->rx_pending); + local->rx_pending = NULL; del_timer_sync(&local->sdreq_timer); cancel_work_sync(&local->sdreq_timeout_work); nfc_llcp_free_sdp_tlv_list(&local->pending_sdreqs); diff --git a/net/nfc/nci/core.c b/net/nfc/nci/core.c index 6a193cce2a75..fff755dde30d 100644 --- a/net/nfc/nci/core.c +++ b/net/nfc/nci/core.c @@ -24,6 +24,7 @@ #include <linux/sched.h> #include <linux/bitops.h> #include <linux/skbuff.h> +#include <linux/kcov.h> #include "../nfc.h" #include <net/nfc/nci.h> @@ -542,7 +543,7 @@ static int nci_open_device(struct nci_dev *ndev) skb_queue_purge(&ndev->tx_q); ndev->ops->close(ndev); - ndev->flags = 0; + ndev->flags &= BIT(NCI_UNREG); } done: @@ -1472,6 +1473,7 @@ static void nci_tx_work(struct work_struct *work) skb = skb_dequeue(&ndev->tx_q); if (!skb) return; + kcov_remote_start_common(skb_get_kcov_handle(skb)); /* Check if data flow control is used */ if (atomic_read(&conn_info->credits_cnt) != @@ -1487,6 +1489,7 @@ static void nci_tx_work(struct work_struct *work) mod_timer(&ndev->data_timer, jiffies + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + kcov_remote_stop(); } } @@ -1497,7 +1500,8 @@ static void nci_rx_work(struct work_struct *work) struct nci_dev *ndev = container_of(work, struct nci_dev, rx_work); struct sk_buff *skb; - while ((skb = skb_dequeue(&ndev->rx_q))) { + for (; (skb = skb_dequeue(&ndev->rx_q)); kcov_remote_stop()) { + kcov_remote_start_common(skb_get_kcov_handle(skb)); /* Send copy to sniffer */ nfc_send_to_raw_sock(ndev->nfc_dev, skb, @@ -1551,6 +1555,7 @@ static void nci_cmd_work(struct work_struct *work) if (!skb) return; + kcov_remote_start_common(skb_get_kcov_handle(skb)); atomic_dec(&ndev->cmd_cnt); pr_debug("NCI TX: MT=cmd, PBF=%d, GID=0x%x, OID=0x%x, plen=%d\n", @@ -1563,6 +1568,7 @@ static void nci_cmd_work(struct work_struct *work) mod_timer(&ndev->cmd_timer, jiffies + msecs_to_jiffies(NCI_CMD_TIMEOUT)); + kcov_remote_stop(); } } diff --git a/net/nfc/nci/data.c b/net/nfc/nci/data.c index aa5e712adf07..3d36ea5701f0 100644 --- a/net/nfc/nci/data.c +++ b/net/nfc/nci/data.c @@ -279,8 +279,10 @@ void nci_rx_data_packet(struct nci_dev *ndev, struct sk_buff *skb) nci_plen(skb->data)); conn_info = nci_get_conn_info_by_conn_id(ndev, nci_conn_id(skb->data)); - if (!conn_info) + if (!conn_info) { + kfree_skb(skb); return; + } /* strip the nci data header */ skb_pull(skb, NCI_DATA_HDR_SIZE); diff --git a/net/nfc/nci/hci.c b/net/nfc/nci/hci.c index 78c4b6addf15..de175318a3a0 100644 --- a/net/nfc/nci/hci.c +++ b/net/nfc/nci/hci.c @@ -14,6 +14,7 @@ #include <net/nfc/nci.h> #include <net/nfc/nci_core.h> #include <linux/nfc.h> +#include <linux/kcov.h> struct nci_data { u8 conn_id; @@ -409,7 +410,8 @@ static void nci_hci_msg_rx_work(struct work_struct *work) const struct nci_hcp_message *message; u8 pipe, type, instruction; - while ((skb = skb_dequeue(&hdev->msg_rx_queue)) != NULL) { + for (; (skb = skb_dequeue(&hdev->msg_rx_queue)); kcov_remote_stop()) { + kcov_remote_start_common(skb_get_kcov_handle(skb)); pipe = NCI_HCP_MSG_GET_PIPE(skb->data[0]); skb_pull(skb, NCI_HCI_HCP_PACKET_HEADER_LEN); message = (struct nci_hcp_message *)skb->data; diff --git a/net/nfc/nci/ntf.c b/net/nfc/nci/ntf.c index 282c51051dcc..994a0a1efb58 100644 --- a/net/nfc/nci/ntf.c +++ b/net/nfc/nci/ntf.c @@ -240,6 +240,8 @@ static int nci_add_new_protocol(struct nci_dev *ndev, target->sens_res = nfca_poll->sens_res; target->sel_res = nfca_poll->sel_res; target->nfcid1_len = nfca_poll->nfcid1_len; + if (target->nfcid1_len > ARRAY_SIZE(target->nfcid1)) + return -EPROTO; if (target->nfcid1_len > 0) { memcpy(target->nfcid1, nfca_poll->nfcid1, target->nfcid1_len); @@ -248,6 +250,8 @@ static int nci_add_new_protocol(struct nci_dev *ndev, nfcb_poll = (struct rf_tech_specific_params_nfcb_poll *)params; target->sensb_res_len = nfcb_poll->sensb_res_len; + if (target->sensb_res_len > ARRAY_SIZE(target->sensb_res)) + return -EPROTO; if (target->sensb_res_len > 0) { memcpy(target->sensb_res, nfcb_poll->sensb_res, target->sensb_res_len); @@ -256,6 +260,8 @@ static int nci_add_new_protocol(struct nci_dev *ndev, nfcf_poll = (struct rf_tech_specific_params_nfcf_poll *)params; target->sensf_res_len = nfcf_poll->sensf_res_len; + if (target->sensf_res_len > ARRAY_SIZE(target->sensf_res)) + return -EPROTO; if (target->sensf_res_len > 0) { memcpy(target->sensf_res, nfcf_poll->sensf_res, target->sensf_res_len); diff --git a/net/nfc/netlink.c b/net/nfc/netlink.c index 9d91087b9399..1fc339084d89 100644 --- a/net/nfc/netlink.c +++ b/net/nfc/netlink.c @@ -1497,6 +1497,7 @@ static int nfc_genl_se_io(struct sk_buff *skb, struct genl_info *info) u32 dev_idx, se_idx; u8 *apdu; size_t apdu_len; + int rc; if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || !info->attrs[NFC_ATTR_SE_INDEX] || @@ -1510,25 +1511,37 @@ static int nfc_genl_se_io(struct sk_buff *skb, struct genl_info *info) if (!dev) return -ENODEV; - if (!dev->ops || !dev->ops->se_io) - return -ENOTSUPP; + if (!dev->ops || !dev->ops->se_io) { + rc = -EOPNOTSUPP; + goto put_dev; + } apdu_len = nla_len(info->attrs[NFC_ATTR_SE_APDU]); - if (apdu_len == 0) - return -EINVAL; + if (apdu_len == 0) { + rc = -EINVAL; + goto put_dev; + } apdu = nla_data(info->attrs[NFC_ATTR_SE_APDU]); - if (!apdu) - return -EINVAL; + if (!apdu) { + rc = -EINVAL; + goto put_dev; + } ctx = kzalloc(sizeof(struct se_io_ctx), GFP_KERNEL); - if (!ctx) - return -ENOMEM; + if (!ctx) { + rc = -ENOMEM; + goto put_dev; + } ctx->dev_idx = dev_idx; ctx->se_idx = se_idx; - return nfc_se_io(dev, se_idx, apdu, apdu_len, se_io_cb, ctx); + rc = nfc_se_io(dev, se_idx, apdu, apdu_len, se_io_cb, ctx); + +put_dev: + nfc_put_device(dev); + return rc; } static int nfc_genl_vendor_cmd(struct sk_buff *skb, @@ -1551,14 +1564,21 @@ static int nfc_genl_vendor_cmd(struct sk_buff *skb, subcmd = nla_get_u32(info->attrs[NFC_ATTR_VENDOR_SUBCMD]); dev = nfc_get_device(dev_idx); - if (!dev || !dev->vendor_cmds || !dev->n_vendor_cmds) + if (!dev) return -ENODEV; + if (!dev->vendor_cmds || !dev->n_vendor_cmds) { + err = -ENODEV; + goto put_dev; + } + if (info->attrs[NFC_ATTR_VENDOR_DATA]) { data = nla_data(info->attrs[NFC_ATTR_VENDOR_DATA]); data_len = nla_len(info->attrs[NFC_ATTR_VENDOR_DATA]); - if (data_len == 0) - return -EINVAL; + if (data_len == 0) { + err = -EINVAL; + goto put_dev; + } } else { data = NULL; data_len = 0; @@ -1573,10 +1593,14 @@ static int nfc_genl_vendor_cmd(struct sk_buff *skb, dev->cur_cmd_info = info; err = cmd->doit(dev, data, data_len); dev->cur_cmd_info = NULL; - return err; + goto put_dev; } - return -EOPNOTSUPP; + err = -EOPNOTSUPP; + +put_dev: + nfc_put_device(dev); + return err; } /* message building helper */ diff --git a/net/nfc/rawsock.c b/net/nfc/rawsock.c index 8dd569765f96..5125392bb68e 100644 --- a/net/nfc/rawsock.c +++ b/net/nfc/rawsock.c @@ -12,6 +12,7 @@ #include <net/tcp_states.h> #include <linux/nfc.h> #include <linux/export.h> +#include <linux/kcov.h> #include "nfc.h" @@ -189,6 +190,7 @@ static void rawsock_tx_work(struct work_struct *work) } skb = skb_dequeue(&sk->sk_write_queue); + kcov_remote_start_common(skb_get_kcov_handle(skb)); sock_hold(sk); rc = nfc_data_exchange(dev, target_idx, skb, @@ -197,6 +199,7 @@ static void rawsock_tx_work(struct work_struct *work) rawsock_report_error(sk, rc); sock_put(sk); } + kcov_remote_stop(); } static int rawsock_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) diff --git a/net/openvswitch/Kconfig b/net/openvswitch/Kconfig index 15bd287f5cbd..747d537a3f06 100644 --- a/net/openvswitch/Kconfig +++ b/net/openvswitch/Kconfig @@ -15,6 +15,7 @@ config OPENVSWITCH select NET_MPLS_GSO select DST_CACHE select NET_NSH + select NF_NAT_OVS if NF_NAT help Open vSwitch is a multilayer Ethernet switch targeted at virtualized environments. In addition to supporting a variety of features diff --git a/net/openvswitch/conntrack.c b/net/openvswitch/conntrack.c index c7b10234cf7c..c8b137649ca4 100644 --- a/net/openvswitch/conntrack.c +++ b/net/openvswitch/conntrack.c @@ -152,7 +152,7 @@ static u8 ovs_ct_get_state(enum ip_conntrack_info ctinfo) static u32 ovs_ct_get_mark(const struct nf_conn *ct) { #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) - return ct ? ct->mark : 0; + return ct ? READ_ONCE(ct->mark) : 0; #else return 0; #endif @@ -340,9 +340,9 @@ static int ovs_ct_set_mark(struct nf_conn *ct, struct sw_flow_key *key, #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) u32 new_mark; - new_mark = ct_mark | (ct->mark & ~(mask)); - if (ct->mark != new_mark) { - ct->mark = new_mark; + new_mark = ct_mark | (READ_ONCE(ct->mark) & ~(mask)); + if (READ_ONCE(ct->mark) != new_mark) { + WRITE_ONCE(ct->mark, new_mark); if (nf_ct_is_confirmed(ct)) nf_conntrack_event_cache(IPCT_MARK, ct); key->ct.mark = new_mark; @@ -434,65 +434,6 @@ static int ovs_ct_set_labels(struct nf_conn *ct, struct sw_flow_key *key, return 0; } -/* 'skb' should already be pulled to nh_ofs. */ -static int ovs_ct_helper(struct sk_buff *skb, u16 proto) -{ - const struct nf_conntrack_helper *helper; - const struct nf_conn_help *help; - enum ip_conntrack_info ctinfo; - unsigned int protoff; - struct nf_conn *ct; - int err; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return NF_ACCEPT; - - help = nfct_help(ct); - if (!help) - return NF_ACCEPT; - - helper = rcu_dereference(help->helper); - if (!helper) - return NF_ACCEPT; - - switch (proto) { - case NFPROTO_IPV4: - protoff = ip_hdrlen(skb); - break; - case NFPROTO_IPV6: { - u8 nexthdr = ipv6_hdr(skb)->nexthdr; - __be16 frag_off; - int ofs; - - ofs = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, - &frag_off); - if (ofs < 0 || (frag_off & htons(~0x7)) != 0) { - pr_debug("proto header not found\n"); - return NF_ACCEPT; - } - protoff = ofs; - break; - } - default: - WARN_ONCE(1, "helper invoked on non-IP family!"); - return NF_DROP; - } - - err = helper->help(skb, protoff, ct, ctinfo); - if (err != NF_ACCEPT) - return err; - - /* Adjust seqs after helper. This is needed due to some helpers (e.g., - * FTP with NAT) adusting the TCP payload size when mangling IP - * addresses and/or port numbers in the text-based control connection. - */ - if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && - !nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) - return NF_DROP; - return NF_ACCEPT; -} - /* Returns 0 on success, -EINPROGRESS if 'skb' is stolen, or other nonzero * value if 'skb' is freed. */ @@ -785,147 +726,27 @@ static void ovs_nat_update_key(struct sw_flow_key *key, } } -/* Modelled after nf_nat_ipv[46]_fn(). - * range is only used for new, uninitialized NAT state. - * Returns either NF_ACCEPT or NF_DROP. - */ -static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, - enum ip_conntrack_info ctinfo, - const struct nf_nat_range2 *range, - enum nf_nat_manip_type maniptype, struct sw_flow_key *key) -{ - int hooknum, nh_off, err = NF_ACCEPT; - - nh_off = skb_network_offset(skb); - skb_pull_rcsum(skb, nh_off); - - /* See HOOK2MANIP(). */ - if (maniptype == NF_NAT_MANIP_SRC) - hooknum = NF_INET_LOCAL_IN; /* Source NAT */ - else - hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */ - - switch (ctinfo) { - case IP_CT_RELATED: - case IP_CT_RELATED_REPLY: - if (IS_ENABLED(CONFIG_NF_NAT) && - skb->protocol == htons(ETH_P_IP) && - ip_hdr(skb)->protocol == IPPROTO_ICMP) { - if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, - hooknum)) - err = NF_DROP; - goto push; - } else if (IS_ENABLED(CONFIG_IPV6) && - skb->protocol == htons(ETH_P_IPV6)) { - __be16 frag_off; - u8 nexthdr = ipv6_hdr(skb)->nexthdr; - int hdrlen = ipv6_skip_exthdr(skb, - sizeof(struct ipv6hdr), - &nexthdr, &frag_off); - - if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { - if (!nf_nat_icmpv6_reply_translation(skb, ct, - ctinfo, - hooknum, - hdrlen)) - err = NF_DROP; - goto push; - } - } - /* Non-ICMP, fall thru to initialize if needed. */ - fallthrough; - case IP_CT_NEW: - /* Seen it before? This can happen for loopback, retrans, - * or local packets. - */ - if (!nf_nat_initialized(ct, maniptype)) { - /* Initialize according to the NAT action. */ - err = (range && range->flags & NF_NAT_RANGE_MAP_IPS) - /* Action is set up to establish a new - * mapping. - */ - ? nf_nat_setup_info(ct, range, maniptype) - : nf_nat_alloc_null_binding(ct, hooknum); - if (err != NF_ACCEPT) - goto push; - } - break; - - case IP_CT_ESTABLISHED: - case IP_CT_ESTABLISHED_REPLY: - break; - - default: - err = NF_DROP; - goto push; - } - - err = nf_nat_packet(ct, ctinfo, hooknum, skb); -push: - skb_push_rcsum(skb, nh_off); - - /* Update the flow key if NAT successful. */ - if (err == NF_ACCEPT) - ovs_nat_update_key(key, skb, maniptype); - - return err; -} - /* Returns NF_DROP if the packet should be dropped, NF_ACCEPT otherwise. */ static int ovs_ct_nat(struct net *net, struct sw_flow_key *key, const struct ovs_conntrack_info *info, struct sk_buff *skb, struct nf_conn *ct, enum ip_conntrack_info ctinfo) { - enum nf_nat_manip_type maniptype; - int err; + int err, action = 0; - /* Add NAT extension if not confirmed yet. */ - if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct)) - return NF_ACCEPT; /* Can't NAT. */ + if (!(info->nat & OVS_CT_NAT)) + return NF_ACCEPT; + if (info->nat & OVS_CT_SRC_NAT) + action |= BIT(NF_NAT_MANIP_SRC); + if (info->nat & OVS_CT_DST_NAT) + action |= BIT(NF_NAT_MANIP_DST); - /* Determine NAT type. - * Check if the NAT type can be deduced from the tracked connection. - * Make sure new expected connections (IP_CT_RELATED) are NATted only - * when committing. - */ - if (info->nat & OVS_CT_NAT && ctinfo != IP_CT_NEW && - ct->status & IPS_NAT_MASK && - (ctinfo != IP_CT_RELATED || info->commit)) { - /* NAT an established or related connection like before. */ - if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) - /* This is the REPLY direction for a connection - * for which NAT was applied in the forward - * direction. Do the reverse NAT. - */ - maniptype = ct->status & IPS_SRC_NAT - ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC; - else - maniptype = ct->status & IPS_SRC_NAT - ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST; - } else if (info->nat & OVS_CT_SRC_NAT) { - maniptype = NF_NAT_MANIP_SRC; - } else if (info->nat & OVS_CT_DST_NAT) { - maniptype = NF_NAT_MANIP_DST; - } else { - return NF_ACCEPT; /* Connection is not NATed. */ - } - err = ovs_ct_nat_execute(skb, ct, ctinfo, &info->range, maniptype, key); - - if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) { - if (ct->status & IPS_SRC_NAT) { - if (maniptype == NF_NAT_MANIP_SRC) - maniptype = NF_NAT_MANIP_DST; - else - maniptype = NF_NAT_MANIP_SRC; - - err = ovs_ct_nat_execute(skb, ct, ctinfo, &info->range, - maniptype, key); - } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { - err = ovs_ct_nat_execute(skb, ct, ctinfo, NULL, - NF_NAT_MANIP_SRC, key); - } - } + err = nf_ct_nat(skb, ct, ctinfo, &action, &info->range, info->commit); + + if (action & BIT(NF_NAT_MANIP_SRC)) + ovs_nat_update_key(key, skb, NF_NAT_MANIP_SRC); + if (action & BIT(NF_NAT_MANIP_DST)) + ovs_nat_update_key(key, skb, NF_NAT_MANIP_DST); return err; } @@ -1038,7 +859,7 @@ static int __ovs_ct_lookup(struct net *net, struct sw_flow_key *key, */ if ((nf_ct_is_confirmed(ct) ? !cached || add_helper : info->commit) && - ovs_ct_helper(skb, info->family) != NF_ACCEPT) { + nf_ct_helper(skb, ct, ctinfo, info->family) != NF_ACCEPT) { return -EINVAL; } @@ -1350,43 +1171,6 @@ int ovs_ct_clear(struct sk_buff *skb, struct sw_flow_key *key) return 0; } -static int ovs_ct_add_helper(struct ovs_conntrack_info *info, const char *name, - const struct sw_flow_key *key, bool log) -{ - struct nf_conntrack_helper *helper; - struct nf_conn_help *help; - int ret = 0; - - helper = nf_conntrack_helper_try_module_get(name, info->family, - key->ip.proto); - if (!helper) { - OVS_NLERR(log, "Unknown helper \"%s\"", name); - return -EINVAL; - } - - help = nf_ct_helper_ext_add(info->ct, GFP_KERNEL); - if (!help) { - nf_conntrack_helper_put(helper); - return -ENOMEM; - } - -#if IS_ENABLED(CONFIG_NF_NAT) - if (info->nat) { - ret = nf_nat_helper_try_module_get(name, info->family, - key->ip.proto); - if (ret) { - nf_conntrack_helper_put(helper); - OVS_NLERR(log, "Failed to load \"%s\" NAT helper, error: %d", - name, ret); - return ret; - } - } -#endif - rcu_assign_pointer(help->helper, helper); - info->helper = helper; - return ret; -} - #if IS_ENABLED(CONFIG_NF_NAT) static int parse_nat(const struct nlattr *attr, struct ovs_conntrack_info *info, bool log) @@ -1720,9 +1504,12 @@ int ovs_ct_copy_action(struct net *net, const struct nlattr *attr, } if (helper) { - err = ovs_ct_add_helper(&ct_info, helper, key, log); - if (err) + err = nf_ct_add_helper(ct_info.ct, helper, ct_info.family, + key->ip.proto, ct_info.nat, &ct_info.helper); + if (err) { + OVS_NLERR(log, "Failed to add %s helper %d", helper, err); goto err_free_ct; + } } err = ovs_nla_add_action(sfa, OVS_ACTION_ATTR_CT, &ct_info, diff --git a/net/openvswitch/datapath.c b/net/openvswitch/datapath.c index 8b84869eb2ac..fcee6012293b 100644 --- a/net/openvswitch/datapath.c +++ b/net/openvswitch/datapath.c @@ -209,6 +209,26 @@ static struct vport *new_vport(const struct vport_parms *parms) return vport; } +static void ovs_vport_update_upcall_stats(struct sk_buff *skb, + const struct dp_upcall_info *upcall_info, + bool upcall_result) +{ + struct vport *p = OVS_CB(skb)->input_vport; + struct vport_upcall_stats_percpu *stats; + + if (upcall_info->cmd != OVS_PACKET_CMD_MISS && + upcall_info->cmd != OVS_PACKET_CMD_ACTION) + return; + + stats = this_cpu_ptr(p->upcall_stats); + u64_stats_update_begin(&stats->syncp); + if (upcall_result) + u64_stats_inc(&stats->n_success); + else + u64_stats_inc(&stats->n_fail); + u64_stats_update_end(&stats->syncp); +} + void ovs_dp_detach_port(struct vport *p) { ASSERT_OVSL(); @@ -216,6 +236,9 @@ void ovs_dp_detach_port(struct vport *p) /* First drop references to device. */ hlist_del_rcu(&p->dp_hash_node); + /* Free percpu memory */ + free_percpu(p->upcall_stats); + /* Then destroy it. */ ovs_vport_del(p); } @@ -305,6 +328,8 @@ int ovs_dp_upcall(struct datapath *dp, struct sk_buff *skb, err = queue_userspace_packet(dp, skb, key, upcall_info, cutlen); else err = queue_gso_packets(dp, skb, key, upcall_info, cutlen); + + ovs_vport_update_upcall_stats(skb, upcall_info, !err); if (err) goto err; @@ -716,9 +741,9 @@ static void get_dp_stats(const struct datapath *dp, struct ovs_dp_stats *stats, percpu_stats = per_cpu_ptr(dp->stats_percpu, i); do { - start = u64_stats_fetch_begin_irq(&percpu_stats->syncp); + start = u64_stats_fetch_begin(&percpu_stats->syncp); local_stats = *percpu_stats; - } while (u64_stats_fetch_retry_irq(&percpu_stats->syncp, start)); + } while (u64_stats_fetch_retry(&percpu_stats->syncp, start)); stats->n_hit += local_stats.n_hit; stats->n_missed += local_stats.n_missed; @@ -948,6 +973,7 @@ static int ovs_flow_cmd_new(struct sk_buff *skb, struct genl_info *info) struct sw_flow_mask mask; struct sk_buff *reply; struct datapath *dp; + struct sw_flow_key *key; struct sw_flow_actions *acts; struct sw_flow_match match; u32 ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); @@ -975,30 +1001,32 @@ static int ovs_flow_cmd_new(struct sk_buff *skb, struct genl_info *info) } /* Extract key. */ - ovs_match_init(&match, &new_flow->key, false, &mask); + key = kzalloc(sizeof(*key), GFP_KERNEL); + if (!key) { + error = -ENOMEM; + goto err_kfree_flow; + } + + ovs_match_init(&match, key, false, &mask); error = ovs_nla_get_match(net, &match, a[OVS_FLOW_ATTR_KEY], a[OVS_FLOW_ATTR_MASK], log); if (error) - goto err_kfree_flow; + goto err_kfree_key; + + ovs_flow_mask_key(&new_flow->key, key, true, &mask); /* Extract flow identifier. */ error = ovs_nla_get_identifier(&new_flow->id, a[OVS_FLOW_ATTR_UFID], - &new_flow->key, log); + key, log); if (error) - goto err_kfree_flow; - - /* unmasked key is needed to match when ufid is not used. */ - if (ovs_identifier_is_key(&new_flow->id)) - match.key = new_flow->id.unmasked_key; - - ovs_flow_mask_key(&new_flow->key, &new_flow->key, true, &mask); + goto err_kfree_key; /* Validate actions. */ error = ovs_nla_copy_actions(net, a[OVS_FLOW_ATTR_ACTIONS], &new_flow->key, &acts, log); if (error) { OVS_NLERR(log, "Flow actions may not be safe on all matching packets."); - goto err_kfree_flow; + goto err_kfree_key; } reply = ovs_flow_cmd_alloc_info(acts, &new_flow->id, info, false, @@ -1019,7 +1047,7 @@ static int ovs_flow_cmd_new(struct sk_buff *skb, struct genl_info *info) if (ovs_identifier_is_ufid(&new_flow->id)) flow = ovs_flow_tbl_lookup_ufid(&dp->table, &new_flow->id); if (!flow) - flow = ovs_flow_tbl_lookup(&dp->table, &new_flow->key); + flow = ovs_flow_tbl_lookup(&dp->table, key); if (likely(!flow)) { rcu_assign_pointer(new_flow->sf_acts, acts); @@ -1089,6 +1117,8 @@ static int ovs_flow_cmd_new(struct sk_buff *skb, struct genl_info *info) if (reply) ovs_notify(&dp_flow_genl_family, reply, info); + + kfree(key); return 0; err_unlock_ovs: @@ -1096,6 +1126,8 @@ err_unlock_ovs: kfree_skb(reply); err_kfree_acts: ovs_nla_free_flow_actions(acts); +err_kfree_key: + kfree(key); err_kfree_flow: ovs_flow_free(new_flow, false); error: @@ -1826,6 +1858,12 @@ static int ovs_dp_cmd_new(struct sk_buff *skb, struct genl_info *info) goto err_destroy_portids; } + vport->upcall_stats = netdev_alloc_pcpu_stats(struct vport_upcall_stats_percpu); + if (!vport->upcall_stats) { + err = -ENOMEM; + goto err_destroy_vport; + } + err = ovs_dp_cmd_fill_info(dp, reply, info->snd_portid, info->snd_seq, 0, OVS_DP_CMD_NEW); BUG_ON(err < 0); @@ -1838,6 +1876,8 @@ static int ovs_dp_cmd_new(struct sk_buff *skb, struct genl_info *info) ovs_notify(&dp_datapath_genl_family, reply, info); return 0; +err_destroy_vport: + ovs_dp_detach_port(vport); err_destroy_portids: kfree(rcu_dereference_raw(dp->upcall_portids)); err_unlock_and_destroy_meters: @@ -2098,6 +2138,9 @@ static int ovs_vport_cmd_fill_info(struct vport *vport, struct sk_buff *skb, OVS_VPORT_ATTR_PAD)) goto nla_put_failure; + if (ovs_vport_get_upcall_stats(vport, skb)) + goto nla_put_failure; + if (ovs_vport_get_upcall_portids(vport, skb)) goto nla_put_failure; @@ -2279,6 +2322,12 @@ restart: goto exit_unlock_free; } + vport->upcall_stats = netdev_alloc_pcpu_stats(struct vport_upcall_stats_percpu); + if (!vport->upcall_stats) { + err = -ENOMEM; + goto exit_unlock_free_vport; + } + err = ovs_vport_cmd_fill_info(vport, reply, genl_info_net(info), info->snd_portid, info->snd_seq, 0, OVS_VPORT_CMD_NEW, GFP_KERNEL); @@ -2296,6 +2345,8 @@ restart: ovs_notify(&dp_vport_genl_family, reply, info); return 0; +exit_unlock_free_vport: + ovs_dp_detach_port(vport); exit_unlock_free: ovs_unlock(); kfree_skb(reply); @@ -2508,6 +2559,7 @@ static const struct nla_policy vport_policy[OVS_VPORT_ATTR_MAX + 1] = { [OVS_VPORT_ATTR_OPTIONS] = { .type = NLA_NESTED }, [OVS_VPORT_ATTR_IFINDEX] = { .type = NLA_U32 }, [OVS_VPORT_ATTR_NETNSID] = { .type = NLA_S32 }, + [OVS_VPORT_ATTR_UPCALL_STATS] = { .type = NLA_NESTED }, }; static const struct genl_small_ops dp_vport_genl_ops[] = { diff --git a/net/openvswitch/flow_netlink.c b/net/openvswitch/flow_netlink.c index 4a07ab094a84..ead5418c126e 100644 --- a/net/openvswitch/flow_netlink.c +++ b/net/openvswitch/flow_netlink.c @@ -2309,7 +2309,7 @@ static struct sw_flow_actions *nla_alloc_flow_actions(int size) WARN_ON_ONCE(size > MAX_ACTIONS_BUFSIZE); - sfa = kmalloc(sizeof(*sfa) + size, GFP_KERNEL); + sfa = kmalloc(kmalloc_size_roundup(sizeof(*sfa) + size), GFP_KERNEL); if (!sfa) return ERR_PTR(-ENOMEM); diff --git a/net/openvswitch/flow_table.c b/net/openvswitch/flow_table.c index d4a2db0b2299..0a0e4c283f02 100644 --- a/net/openvswitch/flow_table.c +++ b/net/openvswitch/flow_table.c @@ -205,9 +205,9 @@ static void tbl_mask_array_reset_counters(struct mask_array *ma) stats = per_cpu_ptr(ma->masks_usage_stats, cpu); do { - start = u64_stats_fetch_begin_irq(&stats->syncp); + start = u64_stats_fetch_begin(&stats->syncp); counter = stats->usage_cntrs[i]; - } while (u64_stats_fetch_retry_irq(&stats->syncp, start)); + } while (u64_stats_fetch_retry(&stats->syncp, start)); ma->masks_usage_zero_cntr[i] += counter; } @@ -1136,10 +1136,9 @@ void ovs_flow_masks_rebalance(struct flow_table *table) stats = per_cpu_ptr(ma->masks_usage_stats, cpu); do { - start = u64_stats_fetch_begin_irq(&stats->syncp); + start = u64_stats_fetch_begin(&stats->syncp); counter = stats->usage_cntrs[i]; - } while (u64_stats_fetch_retry_irq(&stats->syncp, - start)); + } while (u64_stats_fetch_retry(&stats->syncp, start)); masks_and_count[i].counter += counter; } diff --git a/net/openvswitch/vport-geneve.c b/net/openvswitch/vport-geneve.c index 89a8e1501809..b10e1602c6b1 100644 --- a/net/openvswitch/vport-geneve.c +++ b/net/openvswitch/vport-geneve.c @@ -91,7 +91,7 @@ static struct vport *geneve_tnl_create(const struct vport_parms *parms) err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); if (err < 0) { - rtnl_delete_link(dev); + rtnl_delete_link(dev, 0, NULL); rtnl_unlock(); ovs_vport_free(vport); goto error; diff --git a/net/openvswitch/vport-gre.c b/net/openvswitch/vport-gre.c index e6b5e76a962a..4014c9b5eb79 100644 --- a/net/openvswitch/vport-gre.c +++ b/net/openvswitch/vport-gre.c @@ -57,7 +57,7 @@ static struct vport *gre_tnl_create(const struct vport_parms *parms) err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); if (err < 0) { - rtnl_delete_link(dev); + rtnl_delete_link(dev, 0, NULL); rtnl_unlock(); ovs_vport_free(vport); return ERR_PTR(err); diff --git a/net/openvswitch/vport-netdev.c b/net/openvswitch/vport-netdev.c index 2f61d5bdce1a..903537a5da22 100644 --- a/net/openvswitch/vport-netdev.c +++ b/net/openvswitch/vport-netdev.c @@ -172,7 +172,7 @@ void ovs_netdev_tunnel_destroy(struct vport *vport) * if it's not already shutting down. */ if (vport->dev->reg_state == NETREG_REGISTERED) - rtnl_delete_link(vport->dev); + rtnl_delete_link(vport->dev, 0, NULL); netdev_put(vport->dev, &vport->dev_tracker); vport->dev = NULL; rtnl_unlock(); diff --git a/net/openvswitch/vport-vxlan.c b/net/openvswitch/vport-vxlan.c index 188e9c1360a1..0b881b043bcf 100644 --- a/net/openvswitch/vport-vxlan.c +++ b/net/openvswitch/vport-vxlan.c @@ -120,7 +120,7 @@ static struct vport *vxlan_tnl_create(const struct vport_parms *parms) err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); if (err < 0) { - rtnl_delete_link(dev); + rtnl_delete_link(dev, 0, NULL); rtnl_unlock(); ovs_vport_free(vport); goto error; diff --git a/net/openvswitch/vport.c b/net/openvswitch/vport.c index 82a74f998966..7e0f5c45b512 100644 --- a/net/openvswitch/vport.c +++ b/net/openvswitch/vport.c @@ -285,6 +285,56 @@ void ovs_vport_get_stats(struct vport *vport, struct ovs_vport_stats *stats) } /** + * ovs_vport_get_upcall_stats - retrieve upcall stats + * + * @vport: vport from which to retrieve the stats. + * @skb: sk_buff where upcall stats should be appended. + * + * Retrieves upcall stats for the given device. + * + * Must be called with ovs_mutex or rcu_read_lock. + */ +int ovs_vport_get_upcall_stats(struct vport *vport, struct sk_buff *skb) +{ + struct nlattr *nla; + int i; + + __u64 tx_success = 0; + __u64 tx_fail = 0; + + for_each_possible_cpu(i) { + const struct vport_upcall_stats_percpu *stats; + unsigned int start; + + stats = per_cpu_ptr(vport->upcall_stats, i); + do { + start = u64_stats_fetch_begin(&stats->syncp); + tx_success += u64_stats_read(&stats->n_success); + tx_fail += u64_stats_read(&stats->n_fail); + } while (u64_stats_fetch_retry(&stats->syncp, start)); + } + + nla = nla_nest_start_noflag(skb, OVS_VPORT_ATTR_UPCALL_STATS); + if (!nla) + return -EMSGSIZE; + + if (nla_put_u64_64bit(skb, OVS_VPORT_UPCALL_ATTR_SUCCESS, tx_success, + OVS_VPORT_ATTR_PAD)) { + nla_nest_cancel(skb, nla); + return -EMSGSIZE; + } + + if (nla_put_u64_64bit(skb, OVS_VPORT_UPCALL_ATTR_FAIL, tx_fail, + OVS_VPORT_ATTR_PAD)) { + nla_nest_cancel(skb, nla); + return -EMSGSIZE; + } + nla_nest_end(skb, nla); + + return 0; +} + +/** * ovs_vport_get_options - retrieve device options * * @vport: vport from which to retrieve the options. diff --git a/net/openvswitch/vport.h b/net/openvswitch/vport.h index 6ff45e8a0868..3e71ca8ad8a7 100644 --- a/net/openvswitch/vport.h +++ b/net/openvswitch/vport.h @@ -32,6 +32,8 @@ struct vport *ovs_vport_locate(const struct net *net, const char *name); void ovs_vport_get_stats(struct vport *, struct ovs_vport_stats *); +int ovs_vport_get_upcall_stats(struct vport *vport, struct sk_buff *skb); + int ovs_vport_set_options(struct vport *, struct nlattr *options); int ovs_vport_get_options(const struct vport *, struct sk_buff *); @@ -65,6 +67,7 @@ struct vport_portids { * @hash_node: Element in @dev_table hash table in vport.c. * @dp_hash_node: Element in @datapath->ports hash table in datapath.c. * @ops: Class structure. + * @upcall_stats: Upcall stats of every ports. * @detach_list: list used for detaching vport in net-exit call. * @rcu: RCU callback head for deferred destruction. */ @@ -78,6 +81,7 @@ struct vport { struct hlist_node hash_node; struct hlist_node dp_hash_node; const struct vport_ops *ops; + struct vport_upcall_stats_percpu __percpu *upcall_stats; struct list_head detach_list; struct rcu_head rcu; @@ -137,6 +141,18 @@ struct vport_ops { struct list_head list; }; +/** + * struct vport_upcall_stats_percpu - per-cpu packet upcall statistics for + * a given vport. + * @n_success: Number of packets that upcall to userspace succeed. + * @n_fail: Number of packets that upcall to userspace failed. + */ +struct vport_upcall_stats_percpu { + struct u64_stats_sync syncp; + u64_stats_t n_success; + u64_stats_t n_fail; +}; + struct vport *ovs_vport_alloc(int priv_size, const struct vport_ops *, const struct vport_parms *); void ovs_vport_free(struct vport *); diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index 6ce8dd19f33c..b5ab98ca2511 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -1350,7 +1350,7 @@ static bool fanout_flow_is_huge(struct packet_sock *po, struct sk_buff *skb) if (READ_ONCE(history[i]) == rxhash) count++; - victim = prandom_u32_max(ROLLOVER_HLEN); + victim = get_random_u32_below(ROLLOVER_HLEN); /* Avoid dirtying the cache line if possible */ if (READ_ONCE(history[victim]) != rxhash) @@ -1386,7 +1386,7 @@ static unsigned int fanout_demux_rnd(struct packet_fanout *f, struct sk_buff *skb, unsigned int num) { - return prandom_u32_max(num); + return get_random_u32_below(num); } static unsigned int fanout_demux_rollover(struct packet_fanout *f, @@ -1777,6 +1777,7 @@ static int fanout_add(struct sock *sk, struct fanout_args *args) match->prot_hook.af_packet_net = read_pnet(&match->net); match->prot_hook.id_match = match_fanout_group; match->max_num_members = args->max_num_members; + match->prot_hook.ignore_outgoing = type_flags & PACKET_FANOUT_FLAG_IGNORE_OUTGOING; list_add(&match->list, &fanout_list); } err = -EINVAL; @@ -2293,8 +2294,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, if (skb->ip_summed == CHECKSUM_PARTIAL) status |= TP_STATUS_CSUMNOTREADY; else if (skb->pkt_type != PACKET_OUTGOING && - (skb->ip_summed == CHECKSUM_COMPLETE || - skb_csum_unnecessary(skb))) + skb_csum_unnecessary(skb)) status |= TP_STATUS_CSUM_VALID; if (snaplen > res) @@ -3277,7 +3277,7 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr, int addr_len) { struct sock *sk = sock->sk; - char name[sizeof(uaddr->sa_data) + 1]; + char name[sizeof(uaddr->sa_data_min) + 1]; /* * Check legality @@ -3288,8 +3288,8 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr, /* uaddr->sa_data comes from the userspace, it's not guaranteed to be * zero-terminated. */ - memcpy(name, uaddr->sa_data, sizeof(uaddr->sa_data)); - name[sizeof(uaddr->sa_data)] = 0; + memcpy(name, uaddr->sa_data, sizeof(uaddr->sa_data_min)); + name[sizeof(uaddr->sa_data_min)] = 0; return packet_do_bind(sk, name, 0, pkt_sk(sk)->num); } @@ -3520,8 +3520,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, if (skb->ip_summed == CHECKSUM_PARTIAL) aux.tp_status |= TP_STATUS_CSUMNOTREADY; else if (skb->pkt_type != PACKET_OUTGOING && - (skb->ip_summed == CHECKSUM_COMPLETE || - skb_csum_unnecessary(skb))) + skb_csum_unnecessary(skb)) aux.tp_status |= TP_STATUS_CSUM_VALID; aux.tp_len = origlen; @@ -3561,11 +3560,11 @@ static int packet_getname_spkt(struct socket *sock, struct sockaddr *uaddr, return -EOPNOTSUPP; uaddr->sa_family = AF_PACKET; - memset(uaddr->sa_data, 0, sizeof(uaddr->sa_data)); + memset(uaddr->sa_data, 0, sizeof(uaddr->sa_data_min)); rcu_read_lock(); dev = dev_get_by_index_rcu(sock_net(sk), READ_ONCE(pkt_sk(sk)->ifindex)); if (dev) - strscpy(uaddr->sa_data, dev->name, sizeof(uaddr->sa_data)); + strscpy(uaddr->sa_data, dev->name, sizeof(uaddr->sa_data_min)); rcu_read_unlock(); return sizeof(*uaddr); diff --git a/net/qrtr/ns.c b/net/qrtr/ns.c index 1990d496fcfc..e595079c2caf 100644 --- a/net/qrtr/ns.c +++ b/net/qrtr/ns.c @@ -83,7 +83,10 @@ static struct qrtr_node *node_get(unsigned int node_id) node->id = node_id; - radix_tree_insert(&nodes, node_id, node); + if (radix_tree_insert(&nodes, node_id, node)) { + kfree(node); + return NULL; + } return node; } diff --git a/net/rds/message.c b/net/rds/message.c index 44dbc612ef54..b47e4f0a1639 100644 --- a/net/rds/message.c +++ b/net/rds/message.c @@ -366,7 +366,6 @@ static int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter * struct scatterlist *sg; int ret = 0; int length = iov_iter_count(from); - int total_copied = 0; struct rds_msg_zcopy_info *info; rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); @@ -404,7 +403,6 @@ static int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter * ret = -EFAULT; goto err; } - total_copied += copied; length -= copied; sg_set_page(sg, pages, copied, start); rm->data.op_nents++; diff --git a/net/rds/send.c b/net/rds/send.c index 0c5504068e3c..5e57a1581dc6 100644 --- a/net/rds/send.c +++ b/net/rds/send.c @@ -1114,7 +1114,7 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) struct rds_conn_path *cpath; struct in6_addr daddr; __u32 scope_id = 0; - size_t total_payload_len = payload_len, rdma_payload_len = 0; + size_t rdma_payload_len = 0; bool zcopy = ((msg->msg_flags & MSG_ZEROCOPY) && sock_flag(rds_rs_to_sk(rs), SOCK_ZEROCOPY)); int num_sgs = DIV_ROUND_UP(payload_len, PAGE_SIZE); @@ -1243,7 +1243,6 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) if (ret) goto out; - total_payload_len += rdma_payload_len; if (max_t(size_t, payload_len, rdma_payload_len) > RDS_MAX_MSG_SIZE) { ret = -EMSGSIZE; goto out; diff --git a/net/rds/tcp.c b/net/rds/tcp.c index 4444fd82b66d..c5b86066ff66 100644 --- a/net/rds/tcp.c +++ b/net/rds/tcp.c @@ -503,6 +503,9 @@ bool rds_tcp_tune(struct socket *sock) release_sock(sk); return false; } + /* Update ns_tracker to current stack trace and refcounted tracker */ + __netns_tracker_free(net, &sk->ns_tracker, false); + sk->sk_net_refcnt = 1; netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL); sock_inuse_add(net, 1); diff --git a/net/rfkill/core.c b/net/rfkill/core.c index dac4fdc7488a..b390ff245d5e 100644 --- a/net/rfkill/core.c +++ b/net/rfkill/core.c @@ -832,7 +832,7 @@ static void rfkill_release(struct device *dev) kfree(rfkill); } -static int rfkill_dev_uevent(struct device *dev, struct kobj_uevent_env *env) +static int rfkill_dev_uevent(const struct device *dev, struct kobj_uevent_env *env) { struct rfkill *rfkill = to_rfkill(dev); unsigned long flags; diff --git a/net/rose/af_rose.c b/net/rose/af_rose.c index 36fefc3957d7..ca2b17f32670 100644 --- a/net/rose/af_rose.c +++ b/net/rose/af_rose.c @@ -488,6 +488,12 @@ static int rose_listen(struct socket *sock, int backlog) { struct sock *sk = sock->sk; + lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + release_sock(sk); + return -EINVAL; + } + if (sk->sk_state != TCP_LISTEN) { struct rose_sock *rose = rose_sk(sk); @@ -497,8 +503,10 @@ static int rose_listen(struct socket *sock, int backlog) memset(rose->dest_digis, 0, AX25_ADDR_LEN * ROSE_MAX_DIGIS); sk->sk_max_ack_backlog = backlog; sk->sk_state = TCP_LISTEN; + release_sock(sk); return 0; } + release_sock(sk); return -EOPNOTSUPP; } diff --git a/net/rxrpc/Kconfig b/net/rxrpc/Kconfig index accd35c05577..7ae023b37a83 100644 --- a/net/rxrpc/Kconfig +++ b/net/rxrpc/Kconfig @@ -58,4 +58,11 @@ config RXKAD See Documentation/networking/rxrpc.rst. +config RXPERF + tristate "RxRPC test service" + help + Provide an rxperf service tester. This listens on UDP port 7009 for + incoming calls from the rxperf program (an example of which can be + found in OpenAFS). + endif diff --git a/net/rxrpc/Makefile b/net/rxrpc/Makefile index b11281bed2a4..ac5caf5a48e1 100644 --- a/net/rxrpc/Makefile +++ b/net/rxrpc/Makefile @@ -10,12 +10,14 @@ rxrpc-y := \ call_accept.o \ call_event.o \ call_object.o \ + call_state.o \ conn_client.o \ conn_event.o \ conn_object.o \ conn_service.o \ input.o \ insecure.o \ + io_thread.o \ key.o \ local_event.o \ local_object.o \ @@ -30,8 +32,12 @@ rxrpc-y := \ sendmsg.o \ server_key.o \ skbuff.o \ + txbuf.o \ utils.o rxrpc-$(CONFIG_PROC_FS) += proc.o rxrpc-$(CONFIG_RXKAD) += rxkad.o rxrpc-$(CONFIG_SYSCTL) += sysctl.o + + +obj-$(CONFIG_RXPERF) += rxperf.o diff --git a/net/rxrpc/af_rxrpc.c b/net/rxrpc/af_rxrpc.c index ceba28e9dce6..ebbd4a1c3f86 100644 --- a/net/rxrpc/af_rxrpc.c +++ b/net/rxrpc/af_rxrpc.c @@ -39,7 +39,7 @@ atomic_t rxrpc_debug_id; EXPORT_SYMBOL(rxrpc_debug_id); /* count of skbs currently in use */ -atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs; +atomic_t rxrpc_n_rx_skbs; struct workqueue_struct *rxrpc_workqueue; @@ -93,12 +93,11 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx, srx->transport_len > len) return -EINVAL; - if (srx->transport.family != rx->family && - srx->transport.family == AF_INET && rx->family != AF_INET6) - return -EAFNOSUPPORT; - switch (srx->transport.family) { case AF_INET: + if (rx->family != AF_INET && + rx->family != AF_INET6) + return -EAFNOSUPPORT; if (srx->transport_len < sizeof(struct sockaddr_in)) return -EINVAL; tail = offsetof(struct sockaddr_rxrpc, transport.sin.__pad); @@ -106,6 +105,8 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx, #ifdef CONFIG_AF_RXRPC_IPV6 case AF_INET6: + if (rx->family != AF_INET6) + return -EAFNOSUPPORT; if (srx->transport_len < sizeof(struct sockaddr_in6)) return -EINVAL; tail = offsetof(struct sockaddr_rxrpc, transport) + @@ -154,10 +155,10 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) if (service_id) { write_lock(&local->services_lock); - if (rcu_access_pointer(local->service)) + if (local->service) goto service_in_use; rx->local = local; - rcu_assign_pointer(local->service, rx); + local->service = rx; write_unlock(&local->services_lock); rx->sk.sk_state = RXRPC_SERVER_BOUND; @@ -193,8 +194,8 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) service_in_use: write_unlock(&local->services_lock); - rxrpc_unuse_local(local); - rxrpc_put_local(local); + rxrpc_unuse_local(local, rxrpc_local_unuse_bind); + rxrpc_put_local(local, rxrpc_local_put_bind); ret = -EADDRINUSE; error_unlock: release_sock(&rx->sk); @@ -327,7 +328,6 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, mutex_unlock(&call->user_mutex); } - rxrpc_put_peer(cp.peer); _leave(" = %p", call); return call; } @@ -358,9 +358,9 @@ void rxrpc_kernel_end_call(struct socket *sock, struct rxrpc_call *call) /* Make sure we're not going to call back into a kernel service */ if (call->notify_rx) { - spin_lock_bh(&call->notify_lock); + spin_lock(&call->notify_lock); call->notify_rx = rxrpc_dummy_notify_rx; - spin_unlock_bh(&call->notify_lock); + spin_unlock(&call->notify_lock); } mutex_unlock(&call->user_mutex); @@ -373,13 +373,17 @@ EXPORT_SYMBOL(rxrpc_kernel_end_call); * @sock: The socket the call is on * @call: The call to check * - * Allow a kernel service to find out whether a call is still alive - - * ie. whether it has completed. + * Allow a kernel service to find out whether a call is still alive - whether + * it has completed successfully and all received data has been consumed. */ bool rxrpc_kernel_check_life(const struct socket *sock, const struct rxrpc_call *call) { - return call->state != RXRPC_CALL_COMPLETE; + if (!rxrpc_call_is_complete(call)) + return true; + if (call->completion != RXRPC_CALL_SUCCEEDED) + return false; + return !skb_queue_empty(&call->recvmsg_queue); } EXPORT_SYMBOL(rxrpc_kernel_check_life); @@ -811,14 +815,12 @@ static int rxrpc_shutdown(struct socket *sock, int flags) lock_sock(sk); - spin_lock_bh(&sk->sk_receive_queue.lock); if (sk->sk_state < RXRPC_CLOSE) { sk->sk_state = RXRPC_CLOSE; sk->sk_shutdown = SHUTDOWN_MASK; } else { ret = -ESHUTDOWN; } - spin_unlock_bh(&sk->sk_receive_queue.lock); rxrpc_discard_prealloc(rx); @@ -871,13 +873,11 @@ static int rxrpc_release_sock(struct sock *sk) break; } - spin_lock_bh(&sk->sk_receive_queue.lock); sk->sk_state = RXRPC_CLOSE; - spin_unlock_bh(&sk->sk_receive_queue.lock); - if (rx->local && rcu_access_pointer(rx->local->service) == rx) { + if (rx->local && rx->local->service == rx) { write_lock(&rx->local->services_lock); - rcu_assign_pointer(rx->local->service, NULL); + rx->local->service = NULL; write_unlock(&rx->local->services_lock); } @@ -887,8 +887,8 @@ static int rxrpc_release_sock(struct sock *sk) flush_workqueue(rxrpc_workqueue); rxrpc_purge_queue(&sk->sk_receive_queue); - rxrpc_unuse_local(rx->local); - rxrpc_put_local(rx->local); + rxrpc_unuse_local(rx->local, rxrpc_local_unuse_release_sock); + rxrpc_put_local(rx->local, rxrpc_local_put_release_sock); rx->local = NULL; key_put(rx->key); rx->key = NULL; @@ -960,16 +960,9 @@ static const struct net_proto_family rxrpc_family_ops = { static int __init af_rxrpc_init(void) { int ret = -1; - unsigned int tmp; BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > sizeof_field(struct sk_buff, cb)); - get_random_bytes(&tmp, sizeof(tmp)); - tmp &= 0x3fffffff; - if (tmp == 0) - tmp = 1; - idr_set_cursor(&rxrpc_client_conn_ids, tmp); - ret = -ENOMEM; rxrpc_call_jar = kmem_cache_create( "rxrpc_call_jar", sizeof(struct rxrpc_call), 0, @@ -979,7 +972,7 @@ static int __init af_rxrpc_init(void) goto error_call_jar; } - rxrpc_workqueue = alloc_workqueue("krxrpcd", 0, 1); + rxrpc_workqueue = alloc_workqueue("krxrpcd", WQ_HIGHPRI | WQ_MEM_RECLAIM | WQ_UNBOUND, 1); if (!rxrpc_workqueue) { pr_notice("Failed to allocate work queue\n"); goto error_work_queue; @@ -1059,14 +1052,12 @@ static void __exit af_rxrpc_exit(void) sock_unregister(PF_RXRPC); proto_unregister(&rxrpc_proto); unregister_pernet_device(&rxrpc_net_ops); - ASSERTCMP(atomic_read(&rxrpc_n_tx_skbs), ==, 0); ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0); /* Make sure the local and peer records pinned by any dying connections * are released. */ rcu_barrier(); - rxrpc_destroy_client_conn_ids(); destroy_workqueue(rxrpc_workqueue); rxrpc_exit_security(); diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h index 1ad0ec5afb50..433060cade03 100644 --- a/net/rxrpc/ar-internal.h +++ b/net/rxrpc/ar-internal.h @@ -29,12 +29,16 @@ struct rxrpc_crypt { struct key_preparsed_payload; struct rxrpc_connection; +struct rxrpc_txbuf; /* * Mark applied to socket buffers in skb->mark. skb->priority is used * to pass supplementary information. */ enum rxrpc_skb_mark { + RXRPC_SKB_MARK_PACKET, /* Received packet */ + RXRPC_SKB_MARK_ERROR, /* Error notification */ + RXRPC_SKB_MARK_SERVICE_CONN_SECURED, /* Service connection response has been verified */ RXRPC_SKB_MARK_REJECT_BUSY, /* Reject with BUSY */ RXRPC_SKB_MARK_REJECT_ABORT, /* Reject with ABORT (code in skb->priority) */ }; @@ -72,13 +76,7 @@ struct rxrpc_net { bool live; - bool kill_all_client_conns; atomic_t nr_client_conns; - spinlock_t client_conn_cache_lock; /* Lock for ->*_client_conns */ - spinlock_t client_conn_discard_lock; /* Prevent multiple discarders */ - struct list_head idle_client_conns; - struct work_struct client_conn_reaper; - struct timer_list client_conn_reap_timer; struct hlist_head local_endpoints; struct mutex local_mutex; /* Lock for ->local_endpoints */ @@ -93,6 +91,27 @@ struct rxrpc_net { struct list_head peer_keepalive_new; struct timer_list peer_keepalive_timer; struct work_struct peer_keepalive_work; + + atomic_t stat_tx_data; + atomic_t stat_tx_data_retrans; + atomic_t stat_tx_data_send; + atomic_t stat_tx_data_send_frag; + atomic_t stat_tx_data_send_fail; + atomic_t stat_tx_data_underflow; + atomic_t stat_tx_data_cwnd_reset; + atomic_t stat_rx_data; + atomic_t stat_rx_data_reqack; + atomic_t stat_rx_data_jumbo; + + atomic_t stat_tx_ack_fill; + atomic_t stat_tx_ack_send; + atomic_t stat_tx_ack_skip; + atomic_t stat_tx_acks[256]; + atomic_t stat_rx_acks[256]; + + atomic_t stat_why_req_ack[8]; + + atomic_t stat_io_loop; }; /* @@ -178,20 +197,13 @@ struct rxrpc_host_header { * - max 48 bytes (struct sk_buff::cb) */ struct rxrpc_skb_priv { - atomic_t nr_ring_pins; /* Number of rxtx ring pins */ - u8 nr_subpackets; /* Number of subpackets */ - u8 rx_flags; /* Received packet flags */ -#define RXRPC_SKB_INCL_LAST 0x01 /* - Includes last packet */ -#define RXRPC_SKB_TX_BUFFER 0x02 /* - Is transmit buffer */ - union { - int remain; /* amount of space remaining for next write */ + struct rxrpc_connection *conn; /* Connection referred to (poke packet) */ + u16 offset; /* Offset of data */ + u16 len; /* Length of data */ + u8 flags; +#define RXRPC_RX_VERIFIED 0x01 - /* List of requested ACKs on subpackets */ - unsigned long rx_req_ack[(RXRPC_MAX_NR_JUMBO + BITS_PER_LONG - 1) / - BITS_PER_LONG]; - }; - - struct rxrpc_host_header hdr; /* RxRPC packet header from this packet */ + struct rxrpc_host_header hdr; /* RxRPC packet header from this packet */ }; #define rxrpc_skb(__skb) ((struct rxrpc_skb_priv *) &(__skb)->cb) @@ -233,31 +245,24 @@ struct rxrpc_security { size_t *, size_t *, size_t *); /* impose security on a packet */ - int (*secure_packet)(struct rxrpc_call *, struct sk_buff *, size_t); + int (*secure_packet)(struct rxrpc_call *, struct rxrpc_txbuf *); /* verify the security on a received packet */ - int (*verify_packet)(struct rxrpc_call *, struct sk_buff *, - unsigned int, unsigned int, rxrpc_seq_t, u16); + int (*verify_packet)(struct rxrpc_call *, struct sk_buff *); /* Free crypto request on a call */ void (*free_call_crypto)(struct rxrpc_call *); - /* Locate the data in a received packet that has been verified. */ - void (*locate_data)(struct rxrpc_call *, struct sk_buff *, - unsigned int *, unsigned int *); - /* issue a challenge */ int (*issue_challenge)(struct rxrpc_connection *); /* respond to a challenge */ int (*respond_to_challenge)(struct rxrpc_connection *, - struct sk_buff *, - u32 *); + struct sk_buff *); /* verify a response */ int (*verify_response)(struct rxrpc_connection *, - struct sk_buff *, - u32 *); + struct sk_buff *); /* clear connection security */ void (*clear)(struct rxrpc_connection *); @@ -272,21 +277,34 @@ struct rxrpc_local { struct rcu_head rcu; atomic_t active_users; /* Number of users of the local endpoint */ refcount_t ref; /* Number of references to the structure */ - struct rxrpc_net *rxnet; /* The network ns in which this resides */ + struct net *net; /* The network namespace */ + struct rxrpc_net *rxnet; /* Our bits in the network namespace */ struct hlist_node link; struct socket *socket; /* my UDP socket */ - struct work_struct processor; - struct rxrpc_sock __rcu *service; /* Service(s) listening on this endpoint */ + struct task_struct *io_thread; + struct completion io_thread_ready; /* Indication that the I/O thread started */ + struct rxrpc_sock *service; /* Service(s) listening on this endpoint */ struct rw_semaphore defrag_sem; /* control re-enablement of IP DF bit */ - struct sk_buff_head reject_queue; /* packets awaiting rejection */ - struct sk_buff_head event_queue; /* endpoint event packets awaiting processing */ + struct sk_buff_head rx_queue; /* Received packets */ + struct list_head conn_attend_q; /* Conns requiring immediate attention */ + struct list_head call_attend_q; /* Calls requiring immediate attention */ + struct rb_root client_bundles; /* Client connection bundles by socket params */ spinlock_t client_bundles_lock; /* Lock for client_bundles */ + bool kill_all_client_conns; + struct list_head idle_client_conns; + struct timer_list client_conn_reap_timer; + unsigned long client_conn_flags; +#define RXRPC_CLIENT_CONN_REAP_TIMER 0 /* The client conn reap timer expired */ + spinlock_t lock; /* access lock */ rwlock_t services_lock; /* lock for services list */ int debug_id; /* debug ID for printks */ bool dead; bool service_closed; /* Service socket closed */ + struct idr conn_ids; /* List of connection IDs */ + struct list_head new_client_calls; /* Newly created client calls need connection */ + spinlock_t client_call_lock; /* Lock for ->new_client_calls */ struct sockaddr_rxrpc srx; /* local address */ }; @@ -326,7 +344,7 @@ struct rxrpc_peer { u32 rto_j; /* Retransmission timeout in jiffies */ u8 backoff; /* Backoff timeout */ - u8 cong_cwnd; /* Congestion window size */ + u8 cong_ssthresh; /* Congestion slow-start threshold */ }; /* @@ -344,7 +362,6 @@ struct rxrpc_conn_proto { struct rxrpc_conn_parameters { struct rxrpc_local *local; /* Representation of local endpoint */ - struct rxrpc_peer *peer; /* Remote endpoint */ struct key *key; /* Security details */ bool exclusive; /* T if conn is exclusive */ bool upgrade; /* T if service ID can be upgraded */ @@ -353,10 +370,21 @@ struct rxrpc_conn_parameters { }; /* + * Call completion condition (state == RXRPC_CALL_COMPLETE). + */ +enum rxrpc_call_completion { + RXRPC_CALL_SUCCEEDED, /* - Normal termination */ + RXRPC_CALL_REMOTELY_ABORTED, /* - call aborted by peer */ + RXRPC_CALL_LOCALLY_ABORTED, /* - call aborted locally on error or close */ + RXRPC_CALL_LOCAL_ERROR, /* - call failed due to local error */ + RXRPC_CALL_NETWORK_ERROR, /* - call terminated by network error */ + NR__RXRPC_CALL_COMPLETIONS +}; + +/* * Bits in the connection flags. */ enum rxrpc_conn_flag { - RXRPC_CONN_HAS_IDR, /* Has a client conn ID assigned */ RXRPC_CONN_IN_SERVICE_CONNS, /* Conn is in peer->service_conns */ RXRPC_CONN_DONT_REUSE, /* Don't reuse this connection */ RXRPC_CONN_PROBING_FOR_UPGRADE, /* Probing for service upgrade */ @@ -376,6 +404,7 @@ enum rxrpc_conn_flag { */ enum rxrpc_conn_event { RXRPC_CONN_EV_CHALLENGE, /* Send challenge packet */ + RXRPC_CONN_EV_ABORT_CALLS, /* Abort attached calls */ }; /* @@ -383,13 +412,13 @@ enum rxrpc_conn_event { */ enum rxrpc_conn_proto_state { RXRPC_CONN_UNUSED, /* Connection not yet attempted */ + RXRPC_CONN_CLIENT_UNSECURED, /* Client connection needs security init */ RXRPC_CONN_CLIENT, /* Client connection */ RXRPC_CONN_SERVICE_PREALLOC, /* Service connection preallocation */ RXRPC_CONN_SERVICE_UNSECURED, /* Service unsecured connection */ RXRPC_CONN_SERVICE_CHALLENGING, /* Service challenging for security */ RXRPC_CONN_SERVICE, /* Service secured connection */ - RXRPC_CONN_REMOTELY_ABORTED, /* Conn aborted by peer */ - RXRPC_CONN_LOCALLY_ABORTED, /* Conn aborted locally */ + RXRPC_CONN_ABORTED, /* Conn aborted */ RXRPC_CONN__NR_STATES }; @@ -397,13 +426,19 @@ enum rxrpc_conn_proto_state { * RxRPC client connection bundle. */ struct rxrpc_bundle { - struct rxrpc_conn_parameters params; + struct rxrpc_local *local; /* Representation of local endpoint */ + struct rxrpc_peer *peer; /* Remote endpoint */ + struct key *key; /* Security details */ + const struct rxrpc_security *security; /* applied security module */ refcount_t ref; + atomic_t active; /* Number of active users */ unsigned int debug_id; + u32 security_level; /* Security level selected */ + u16 service_id; /* Service ID for this connection */ bool try_upgrade; /* True if the bundle is attempting upgrade */ - bool alloc_conn; /* True if someone's getting a conn */ - short alloc_error; /* Error from last conn allocation */ - spinlock_t channel_lock; + bool exclusive; /* T if conn is exclusive */ + bool upgrade; /* T if service ID can be upgraded */ + unsigned short alloc_error; /* Error from last conn allocation */ struct rb_node local_node; /* Node in local->client_conns */ struct list_head waiting_calls; /* Calls waiting for channels */ unsigned long avail_chans; /* Mask of available channels */ @@ -417,16 +452,21 @@ struct rxrpc_bundle { */ struct rxrpc_connection { struct rxrpc_conn_proto proto; - struct rxrpc_conn_parameters params; + struct rxrpc_local *local; /* Representation of local endpoint */ + struct rxrpc_peer *peer; /* Remote endpoint */ + struct rxrpc_net *rxnet; /* Network namespace to which call belongs */ + struct key *key; /* Security details */ + struct list_head attend_link; /* Link in local->conn_attend_q */ refcount_t ref; + atomic_t active; /* Active count for service conns */ struct rcu_head rcu; struct list_head cache_link; unsigned char act_chans; /* Mask of active channels */ struct rxrpc_channel { unsigned long final_ack_at; /* Time at which to issue final ACK */ - struct rxrpc_call __rcu *call; /* Active call */ + struct rxrpc_call *call; /* Active call */ unsigned int call_debug_id; /* call->debug_id */ u32 call_id; /* ID of current call */ u32 call_counter; /* Call ID counter */ @@ -440,12 +480,14 @@ struct rxrpc_connection { struct timer_list timer; /* Conn event timer */ struct work_struct processor; /* connection event processor */ + struct work_struct destructor; /* In-process-context destroyer */ struct rxrpc_bundle *bundle; /* Client connection bundle */ struct rb_node service_node; /* Node in peer->service_conns */ struct list_head proc_link; /* link in procfs list */ struct list_head link; /* link in master connection list */ struct sk_buff_head rx_queue; /* received conn-level packets */ + struct mutex security_lock; /* Lock for security management */ const struct rxrpc_security *security; /* applied security module */ union { struct { @@ -459,14 +501,19 @@ struct rxrpc_connection { unsigned long idle_timestamp; /* Time at which last became idle */ spinlock_t state_lock; /* state-change lock */ enum rxrpc_conn_proto_state state; /* current state of connection */ - u32 abort_code; /* Abort code of connection abort */ + enum rxrpc_call_completion completion; /* Completion condition */ + s32 abort_code; /* Abort code of connection abort */ int debug_id; /* debug ID for printks */ atomic_t serial; /* packet serial number counter */ unsigned int hi_serial; /* highest serial number received */ u32 service_id; /* Service ID, possibly upgraded */ + u32 security_level; /* Security level selected */ u8 security_ix; /* security type */ u8 out_clientflag; /* RXRPC_CLIENT_INITIATED if we are client */ u8 bundle_shift; /* Index into bundle->avail_chans */ + bool exclusive; /* T if conn is exclusive */ + bool upgrade; /* T if service ID can be upgraded */ + u16 orig_service_id; /* Originally requested service ID */ short error; /* Local error code */ }; @@ -490,26 +537,25 @@ enum rxrpc_call_flag { RXRPC_CALL_EXPOSED, /* The call was exposed to the world */ RXRPC_CALL_RX_LAST, /* Received the last packet (at rxtx_top) */ RXRPC_CALL_TX_LAST, /* Last packet in Tx buffer (at rxtx_top) */ + RXRPC_CALL_TX_ALL_ACKED, /* Last packet has been hard-acked */ RXRPC_CALL_SEND_PING, /* A ping will need to be sent */ RXRPC_CALL_RETRANS_TIMEOUT, /* Retransmission due to timeout occurred */ RXRPC_CALL_BEGAN_RX_TIMER, /* We began the expect_rx_by timer */ RXRPC_CALL_RX_HEARD, /* The peer responded at least once to this call */ - RXRPC_CALL_RX_UNDERRUN, /* Got data underrun */ RXRPC_CALL_DISCONNECTED, /* The call has been disconnected */ RXRPC_CALL_KERNEL, /* The call was made by the kernel */ RXRPC_CALL_UPGRADE, /* Service upgrade was requested for the call */ + RXRPC_CALL_EXCLUSIVE, /* The call uses a once-only connection */ + RXRPC_CALL_RX_IS_IDLE, /* recvmsg() is idle - send an ACK */ + RXRPC_CALL_RECVMSG_READ_ALL, /* recvmsg() read all of the received data */ }; /* * Events that can be raised on a call. */ enum rxrpc_call_event { - RXRPC_CALL_EV_ACK, /* need to generate ACK */ - RXRPC_CALL_EV_ABORT, /* need to generate abort */ - RXRPC_CALL_EV_RESEND, /* Tx resend required */ - RXRPC_CALL_EV_PING, /* Ping send required */ - RXRPC_CALL_EV_EXPIRED, /* Expiry occurred */ RXRPC_CALL_EV_ACK_LOST, /* ACK may be lost, send ping */ + RXRPC_CALL_EV_INITIAL_PING, /* Send initial ping for a new service call */ }; /* @@ -532,18 +578,6 @@ enum rxrpc_call_state { }; /* - * Call completion condition (state == RXRPC_CALL_COMPLETE). - */ -enum rxrpc_call_completion { - RXRPC_CALL_SUCCEEDED, /* - Normal termination */ - RXRPC_CALL_REMOTELY_ABORTED, /* - call aborted by peer */ - RXRPC_CALL_LOCALLY_ABORTED, /* - call aborted locally on error or close */ - RXRPC_CALL_LOCAL_ERROR, /* - call failed due to local error */ - RXRPC_CALL_NETWORK_ERROR, /* - call terminated by network error */ - NR__RXRPC_CALL_COMPLETIONS -}; - -/* * Call Tx congestion management modes. */ enum rxrpc_congest_mode { @@ -561,12 +595,16 @@ enum rxrpc_congest_mode { struct rxrpc_call { struct rcu_head rcu; struct rxrpc_connection *conn; /* connection carrying call */ + struct rxrpc_bundle *bundle; /* Connection bundle to use */ struct rxrpc_peer *peer; /* Peer record for remote address */ + struct rxrpc_local *local; /* Representation of local endpoint */ struct rxrpc_sock __rcu *socket; /* socket responsible */ struct rxrpc_net *rxnet; /* Network namespace to which call belongs */ + struct key *key; /* Security details */ const struct rxrpc_security *security; /* applied security module */ struct mutex user_mutex; /* User access mutex */ - unsigned long ack_at; /* When deferred ACK needs to happen */ + struct sockaddr_rxrpc dest_srx; /* Destination address */ + unsigned long delay_ack_at; /* When DELAY ACK needs to happen */ unsigned long ack_lost_at; /* When ACK is figured as lost */ unsigned long resend_at; /* When next resend needs to happen */ unsigned long ping_at; /* When next to send a ping */ @@ -576,77 +614,70 @@ struct rxrpc_call { unsigned long expect_term_by; /* When we expect call termination by */ u32 next_rx_timo; /* Timeout for next Rx packet (jif) */ u32 next_req_timo; /* Timeout for next Rx request packet (jif) */ - struct skcipher_request *cipher_req; /* Packet cipher request buffer */ struct timer_list timer; /* Combined event timer */ - struct work_struct processor; /* Event processor */ + struct work_struct destroyer; /* In-process-context destroyer */ rxrpc_notify_rx_t notify_rx; /* kernel service Rx notification function */ struct list_head link; /* link in master call list */ - struct list_head chan_wait_link; /* Link in conn->bundle->waiting_calls */ + struct list_head wait_link; /* Link in local->new_client_calls */ struct hlist_node error_link; /* link in error distribution list */ struct list_head accept_link; /* Link in rx->acceptq */ struct list_head recvmsg_link; /* Link in rx->recvmsg_q */ struct list_head sock_link; /* Link in rx->sock_calls */ struct rb_node sock_node; /* Node in rx->calls */ - struct sk_buff *tx_pending; /* Tx socket buffer being filled */ + struct list_head attend_link; /* Link in local->call_attend_q */ + struct rxrpc_txbuf *tx_pending; /* Tx buffer being filled */ wait_queue_head_t waitq; /* Wait queue for channel or Tx */ s64 tx_total_len; /* Total length left to be transmitted (or -1) */ - __be32 crypto_buf[2]; /* Temporary packet crypto buffer */ unsigned long user_call_ID; /* user-defined call ID */ unsigned long flags; unsigned long events; - spinlock_t lock; spinlock_t notify_lock; /* Kernel notification lock */ - rwlock_t state_lock; /* lock for state transition */ - u32 abort_code; /* Local/remote abort code */ + unsigned int send_abort_why; /* Why the abort [enum rxrpc_abort_reason] */ + s32 send_abort; /* Abort code to be sent */ + short send_abort_err; /* Error to be associated with the abort */ + rxrpc_seq_t send_abort_seq; /* DATA packet that incurred the abort (or 0) */ + s32 abort_code; /* Local/remote abort code */ int error; /* Local error incurred */ - enum rxrpc_call_state state; /* current state of call */ + enum rxrpc_call_state _state; /* Current state of call (needs barrier) */ enum rxrpc_call_completion completion; /* Call completion condition */ refcount_t ref; - u16 service_id; /* service ID */ u8 security_ix; /* Security type */ enum rxrpc_interruptibility interruptibility; /* At what point call may be interrupted */ u32 call_id; /* call ID on connection */ u32 cid; /* connection ID plus channel index */ + u32 security_level; /* Security level selected */ int debug_id; /* debug ID for printks */ unsigned short rx_pkt_offset; /* Current recvmsg packet offset */ unsigned short rx_pkt_len; /* Current recvmsg packet len */ - bool rx_pkt_last; /* Current recvmsg packet is last */ - - /* Rx/Tx circular buffer, depending on phase. - * - * In the Rx phase, packets are annotated with 0 or the number of the - * segment of a jumbo packet each buffer refers to. There can be up to - * 47 segments in a maximum-size UDP packet. - * - * In the Tx phase, packets are annotated with which buffers have been - * acked. - */ -#define RXRPC_RXTX_BUFF_SIZE 64 -#define RXRPC_RXTX_BUFF_MASK (RXRPC_RXTX_BUFF_SIZE - 1) -#define RXRPC_INIT_RX_WINDOW_SIZE 63 - struct sk_buff **rxtx_buffer; - u8 *rxtx_annotations; -#define RXRPC_TX_ANNO_ACK 0 -#define RXRPC_TX_ANNO_UNACK 1 -#define RXRPC_TX_ANNO_NAK 2 -#define RXRPC_TX_ANNO_RETRANS 3 -#define RXRPC_TX_ANNO_MASK 0x03 -#define RXRPC_TX_ANNO_LAST 0x04 -#define RXRPC_TX_ANNO_RESENT 0x08 - -#define RXRPC_RX_ANNO_SUBPACKET 0x3f /* Subpacket number in jumbogram */ -#define RXRPC_RX_ANNO_VERIFIED 0x80 /* Set if verified and decrypted */ - rxrpc_seq_t tx_hard_ack; /* Dead slot in buffer; the first transmitted but - * not hard-ACK'd packet follows this. - */ + + /* Transmitted data tracking. */ + spinlock_t tx_lock; /* Transmit queue lock */ + struct list_head tx_sendmsg; /* Sendmsg prepared packets */ + struct list_head tx_buffer; /* Buffer of transmissible packets */ + rxrpc_seq_t tx_bottom; /* First packet in buffer */ + rxrpc_seq_t tx_transmitted; /* Highest packet transmitted */ + rxrpc_seq_t tx_prepared; /* Highest Tx slot prepared. */ rxrpc_seq_t tx_top; /* Highest Tx slot allocated. */ u16 tx_backoff; /* Delay to insert due to Tx failure */ + u8 tx_winsize; /* Maximum size of Tx window */ +#define RXRPC_TX_MAX_WINDOW 128 + ktime_t tx_last_sent; /* Last time a transmission occurred */ + + /* Received data tracking */ + struct sk_buff_head recvmsg_queue; /* Queue of packets ready for recvmsg() */ + struct sk_buff_head rx_oos_queue; /* Queue of out of sequence packets */ + + rxrpc_seq_t rx_highest_seq; /* Higest sequence number received */ + rxrpc_seq_t rx_consumed; /* Highest packet consumed */ + rxrpc_serial_t rx_serial; /* Highest serial received for this call */ + u8 rx_winsize; /* Size of Rx window */ /* TCP-style slow-start congestion control [RFC5681]. Since the SMSS * is fixed, we keep these numbers in terms of segments (ie. DATA * packets) rather than bytes. */ #define RXRPC_TX_SMSS RXRPC_JUMBO_DATALEN +#define RXRPC_MIN_CWND (RXRPC_TX_SMSS > 2190 ? 2 : RXRPC_TX_SMSS > 1095 ? 3 : 4) u8 cong_cwnd; /* Congestion window size */ u8 cong_extra; /* Extra to send for congestion management */ u8 cong_ssthresh; /* Slow-start threshold */ @@ -655,25 +686,17 @@ struct rxrpc_call { u8 cong_cumul_acks; /* Cumulative ACK count */ ktime_t cong_tstamp; /* Last time cwnd was changed */ - rxrpc_seq_t rx_hard_ack; /* Dead slot in buffer; the first received but not - * consumed packet follows this. - */ - rxrpc_seq_t rx_top; /* Highest Rx slot allocated. */ - rxrpc_seq_t rx_expect_next; /* Expected next packet sequence number */ - rxrpc_serial_t rx_serial; /* Highest serial received for this call */ - u8 rx_winsize; /* Size of Rx window */ - u8 tx_winsize; /* Maximum size of Tx window */ - bool tx_phase; /* T if transmission phase, F if receive phase */ - u8 nr_jumbo_bad; /* Number of jumbo dups/exceeds-windows */ - - spinlock_t input_lock; /* Lock for packet input to this call */ - /* Receive-phase ACK management (ACKs we send). */ u8 ackr_reason; /* reason to ACK */ rxrpc_serial_t ackr_serial; /* serial of packet being ACK'd */ - rxrpc_seq_t ackr_highest_seq; /* Higest sequence number received */ + atomic64_t ackr_window; /* Base (in LSW) and top (in MSW) of SACK window */ atomic_t ackr_nr_unacked; /* Number of unacked packets */ atomic_t ackr_nr_consumed; /* Number of packets needing hard ACK */ + struct { +#define RXRPC_SACK_SIZE 256 + /* SACK table for soft-acked packets */ + u8 ackr_sack_table[RXRPC_SACK_SIZE]; + } __aligned(8); /* RTT management */ rxrpc_serial_t rtt_serial[4]; /* Serial number of DATA or PING sent */ @@ -687,21 +710,20 @@ struct rxrpc_call { ktime_t acks_latest_ts; /* Timestamp of latest ACK received */ rxrpc_seq_t acks_first_seq; /* first sequence number received */ rxrpc_seq_t acks_prev_seq; /* Highest previousPacket received */ + rxrpc_seq_t acks_hard_ack; /* Latest hard-ack point */ rxrpc_seq_t acks_lowest_nak; /* Lowest NACK in the buffer (or ==tx_hard_ack) */ - rxrpc_seq_t acks_lost_top; /* tx_top at the time lost-ack ping sent */ - rxrpc_serial_t acks_lost_ping; /* Serial number of probe ACK */ + rxrpc_serial_t acks_highest_serial; /* Highest serial number ACK'd */ }; /* * Summary of a new ACK and the changes it made to the Tx buffer packet states. */ struct rxrpc_ack_summary { + u16 nr_acks; /* Number of ACKs in packet */ + u16 nr_new_acks; /* Number of new ACKs in packet */ + u16 nr_rot_new_acks; /* Number of rotated new ACKs */ u8 ack_reason; - u8 nr_acks; /* Number of ACKs in packet */ - u8 nr_nacks; /* Number of NACKs in packet */ - u8 nr_new_acks; /* Number of new ACKs in packet */ - u8 nr_new_nacks; /* Number of new NACKs in packet */ - u8 nr_rot_new_acks; /* Number of rotated new ACKs */ + bool saw_nacks; /* Saw NACKs in packet */ bool new_low_nack; /* T if new low NACK found */ bool retrans_timeo; /* T if reTx due to timeout happened */ u8 flight_size; /* Number of unreceived transmissions */ @@ -744,12 +766,57 @@ struct rxrpc_send_params { bool upgrade; /* If the connection is upgradeable */ }; +/* + * Buffer of data to be output as a packet. + */ +struct rxrpc_txbuf { + struct rcu_head rcu; + struct list_head call_link; /* Link in call->tx_sendmsg/tx_buffer */ + struct list_head tx_link; /* Link in live Enc queue or Tx queue */ + ktime_t last_sent; /* Time at which last transmitted */ + refcount_t ref; + rxrpc_seq_t seq; /* Sequence number of this packet */ + unsigned int call_debug_id; + unsigned int debug_id; + unsigned int len; /* Amount of data in buffer */ + unsigned int space; /* Remaining data space */ + unsigned int offset; /* Offset of fill point */ + unsigned long flags; +#define RXRPC_TXBUF_LAST 0 /* Set if last packet in Tx phase */ +#define RXRPC_TXBUF_RESENT 1 /* Set if has been resent */ + u8 /*enum rxrpc_propose_ack_trace*/ ack_why; /* If ack, why */ + struct { + /* The packet for encrypting and DMA'ing. We align it such + * that data[] aligns correctly for any crypto blocksize. + */ + u8 pad[64 - sizeof(struct rxrpc_wire_header)]; + struct rxrpc_wire_header wire; /* Network-ready header */ + union { + u8 data[RXRPC_JUMBO_DATALEN]; /* Data packet */ + struct { + struct rxrpc_ackpacket ack; + u8 acks[0]; + }; + }; + } __aligned(64); +}; + +static inline bool rxrpc_sending_to_server(const struct rxrpc_txbuf *txb) +{ + return txb->wire.flags & RXRPC_CLIENT_INITIATED; +} + +static inline bool rxrpc_sending_to_client(const struct rxrpc_txbuf *txb) +{ + return !rxrpc_sending_to_server(txb); +} + #include <trace/events/rxrpc.h> /* * af_rxrpc.c */ -extern atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs; +extern atomic_t rxrpc_n_rx_skbs; extern struct workqueue_struct *rxrpc_workqueue; /* @@ -757,25 +824,31 @@ extern struct workqueue_struct *rxrpc_workqueue; */ int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t); void rxrpc_discard_prealloc(struct rxrpc_sock *); -struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *, - struct rxrpc_sock *, - struct sk_buff *); +bool rxrpc_new_incoming_call(struct rxrpc_local *local, + struct rxrpc_peer *peer, + struct rxrpc_connection *conn, + struct sockaddr_rxrpc *peer_srx, + struct sk_buff *skb); void rxrpc_accept_incoming_calls(struct rxrpc_local *); int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long); /* * call_event.c */ -void rxrpc_propose_ACK(struct rxrpc_call *, u8, u32, bool, bool, - enum rxrpc_propose_ack_trace); -void rxrpc_process_call(struct work_struct *); +void rxrpc_propose_ping(struct rxrpc_call *call, u32 serial, + enum rxrpc_propose_ack_trace why); +void rxrpc_send_ACK(struct rxrpc_call *, u8, rxrpc_serial_t, enum rxrpc_propose_ack_trace); +void rxrpc_propose_delay_ACK(struct rxrpc_call *, rxrpc_serial_t, + enum rxrpc_propose_ack_trace); +void rxrpc_shrink_call_tx_buffer(struct rxrpc_call *); +void rxrpc_resend(struct rxrpc_call *call, struct sk_buff *ack_skb); void rxrpc_reduce_call_timer(struct rxrpc_call *call, unsigned long expire_at, unsigned long now, enum rxrpc_timer_trace why); -void rxrpc_delete_call_timer(struct rxrpc_call *call); +bool rxrpc_input_call_event(struct rxrpc_call *call, struct sk_buff *skb); /* * call_object.c @@ -784,6 +857,7 @@ extern const char *const rxrpc_call_states[]; extern const char *const rxrpc_call_completions[]; extern struct kmem_cache *rxrpc_call_jar; +void rxrpc_poke_call(struct rxrpc_call *call, enum rxrpc_call_poke_trace what); struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *, unsigned long); struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *, gfp_t, unsigned int); struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *, @@ -791,14 +865,13 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *, struct sockaddr_rxrpc *, struct rxrpc_call_params *, gfp_t, unsigned int); +void rxrpc_start_call_timer(struct rxrpc_call *call); void rxrpc_incoming_call(struct rxrpc_sock *, struct rxrpc_call *, struct sk_buff *); void rxrpc_release_call(struct rxrpc_sock *, struct rxrpc_call *); void rxrpc_release_calls_on_socket(struct rxrpc_sock *); -bool __rxrpc_queue_call(struct rxrpc_call *); -bool rxrpc_queue_call(struct rxrpc_call *); -void rxrpc_see_call(struct rxrpc_call *); -bool rxrpc_try_get_call(struct rxrpc_call *call, enum rxrpc_call_trace op); +void rxrpc_see_call(struct rxrpc_call *, enum rxrpc_call_trace); +struct rxrpc_call *rxrpc_try_get_call(struct rxrpc_call *, enum rxrpc_call_trace); void rxrpc_get_call(struct rxrpc_call *, enum rxrpc_call_trace); void rxrpc_put_call(struct rxrpc_call *, enum rxrpc_call_trace); void rxrpc_cleanup_call(struct rxrpc_call *); @@ -815,31 +888,88 @@ static inline bool rxrpc_is_client_call(const struct rxrpc_call *call) } /* + * call_state.c + */ +bool rxrpc_set_call_completion(struct rxrpc_call *call, + enum rxrpc_call_completion compl, + u32 abort_code, + int error); +bool rxrpc_call_completed(struct rxrpc_call *call); +bool rxrpc_abort_call(struct rxrpc_call *call, rxrpc_seq_t seq, + u32 abort_code, int error, enum rxrpc_abort_reason why); +void rxrpc_prefail_call(struct rxrpc_call *call, enum rxrpc_call_completion compl, + int error); + +static inline void rxrpc_set_call_state(struct rxrpc_call *call, + enum rxrpc_call_state state) +{ + /* Order write of completion info before write of ->state. */ + smp_store_release(&call->_state, state); + wake_up(&call->waitq); +} + +static inline enum rxrpc_call_state __rxrpc_call_state(const struct rxrpc_call *call) +{ + return call->_state; /* Only inside I/O thread */ +} + +static inline bool __rxrpc_call_is_complete(const struct rxrpc_call *call) +{ + return __rxrpc_call_state(call) == RXRPC_CALL_COMPLETE; +} + +static inline enum rxrpc_call_state rxrpc_call_state(const struct rxrpc_call *call) +{ + /* Order read ->state before read of completion info. */ + return smp_load_acquire(&call->_state); +} + +static inline bool rxrpc_call_is_complete(const struct rxrpc_call *call) +{ + return rxrpc_call_state(call) == RXRPC_CALL_COMPLETE; +} + +static inline bool rxrpc_call_has_failed(const struct rxrpc_call *call) +{ + return rxrpc_call_is_complete(call) && call->completion != RXRPC_CALL_SUCCEEDED; +} + +/* * conn_client.c */ extern unsigned int rxrpc_reap_client_connections; extern unsigned long rxrpc_conn_idle_client_expiry; extern unsigned long rxrpc_conn_idle_client_fast_expiry; -extern struct idr rxrpc_client_conn_ids; - -void rxrpc_destroy_client_conn_ids(void); -struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *); -void rxrpc_put_bundle(struct rxrpc_bundle *); -int rxrpc_connect_call(struct rxrpc_sock *, struct rxrpc_call *, - struct rxrpc_conn_parameters *, struct sockaddr_rxrpc *, - gfp_t); + +void rxrpc_purge_client_connections(struct rxrpc_local *local); +struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *, enum rxrpc_bundle_trace); +void rxrpc_put_bundle(struct rxrpc_bundle *, enum rxrpc_bundle_trace); +int rxrpc_look_up_bundle(struct rxrpc_call *call, gfp_t gfp); +void rxrpc_connect_client_calls(struct rxrpc_local *local); void rxrpc_expose_client_call(struct rxrpc_call *); void rxrpc_disconnect_client_call(struct rxrpc_bundle *, struct rxrpc_call *); -void rxrpc_put_client_conn(struct rxrpc_connection *); -void rxrpc_discard_expired_client_conns(struct work_struct *); -void rxrpc_destroy_all_client_connections(struct rxrpc_net *); +void rxrpc_deactivate_bundle(struct rxrpc_bundle *bundle); +void rxrpc_put_client_conn(struct rxrpc_connection *, enum rxrpc_conn_trace); +void rxrpc_discard_expired_client_conns(struct rxrpc_local *local); void rxrpc_clean_up_local_conns(struct rxrpc_local *); /* * conn_event.c */ +void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, struct sk_buff *skb, + unsigned int channel); +int rxrpc_abort_conn(struct rxrpc_connection *conn, struct sk_buff *skb, + s32 abort_code, int err, enum rxrpc_abort_reason why); void rxrpc_process_connection(struct work_struct *); void rxrpc_process_delayed_final_acks(struct rxrpc_connection *, bool); +bool rxrpc_input_conn_packet(struct rxrpc_connection *conn, struct sk_buff *skb); +void rxrpc_input_conn_event(struct rxrpc_connection *conn, struct sk_buff *skb); + +static inline bool rxrpc_is_conn_aborted(const struct rxrpc_connection *conn) +{ + /* Order reading the abort info after the state check. */ + return smp_load_acquire(&conn->state) == RXRPC_CONN_ABORTED; +} /* * conn_object.c @@ -847,18 +977,21 @@ void rxrpc_process_delayed_final_acks(struct rxrpc_connection *, bool); extern unsigned int rxrpc_connection_expiry; extern unsigned int rxrpc_closed_conn_expiry; -struct rxrpc_connection *rxrpc_alloc_connection(gfp_t); -struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *, - struct sk_buff *, - struct rxrpc_peer **); +void rxrpc_poke_conn(struct rxrpc_connection *conn, enum rxrpc_conn_trace why); +struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *, gfp_t); +struct rxrpc_connection *rxrpc_find_client_connection_rcu(struct rxrpc_local *, + struct sockaddr_rxrpc *, + struct sk_buff *); void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *); void rxrpc_disconnect_call(struct rxrpc_call *); -void rxrpc_kill_connection(struct rxrpc_connection *); -bool rxrpc_queue_conn(struct rxrpc_connection *); -void rxrpc_see_connection(struct rxrpc_connection *); -struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *); -struct rxrpc_connection *rxrpc_get_connection_maybe(struct rxrpc_connection *); -void rxrpc_put_service_conn(struct rxrpc_connection *); +void rxrpc_kill_client_conn(struct rxrpc_connection *); +void rxrpc_queue_conn(struct rxrpc_connection *, enum rxrpc_conn_trace); +void rxrpc_see_connection(struct rxrpc_connection *, enum rxrpc_conn_trace); +struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *, + enum rxrpc_conn_trace); +struct rxrpc_connection *rxrpc_get_connection_maybe(struct rxrpc_connection *, + enum rxrpc_conn_trace); +void rxrpc_put_connection(struct rxrpc_connection *, enum rxrpc_conn_trace); void rxrpc_service_connection_reaper(struct work_struct *); void rxrpc_destroy_all_connections(struct rxrpc_net *); @@ -872,17 +1005,6 @@ static inline bool rxrpc_conn_is_service(const struct rxrpc_connection *conn) return !rxrpc_conn_is_client(conn); } -static inline void rxrpc_put_connection(struct rxrpc_connection *conn) -{ - if (!conn) - return; - - if (rxrpc_conn_is_client(conn)) - rxrpc_put_client_conn(conn); - else - rxrpc_put_service_conn(conn); -} - static inline void rxrpc_reduce_conn_timer(struct rxrpc_connection *conn, unsigned long expire_at) { @@ -902,7 +1024,27 @@ void rxrpc_unpublish_service_conn(struct rxrpc_connection *); /* * input.c */ -int rxrpc_input_packet(struct sock *, struct sk_buff *); +void rxrpc_congestion_degrade(struct rxrpc_call *); +void rxrpc_input_call_packet(struct rxrpc_call *, struct sk_buff *); +void rxrpc_implicit_end_call(struct rxrpc_call *, struct sk_buff *); + +/* + * io_thread.c + */ +int rxrpc_encap_rcv(struct sock *, struct sk_buff *); +void rxrpc_error_report(struct sock *); +bool rxrpc_direct_abort(struct sk_buff *skb, enum rxrpc_abort_reason why, + s32 abort_code, int err); +int rxrpc_io_thread(void *data); +static inline void rxrpc_wake_up_io_thread(struct rxrpc_local *local) +{ + wake_up_process(local->io_thread); +} + +static inline bool rxrpc_protocol_error(struct sk_buff *skb, enum rxrpc_abort_reason why) +{ + return rxrpc_direct_abort(skb, why, RX_PROTOCOL_ERROR, -EPROTO); +} /* * insecure.c @@ -921,43 +1063,53 @@ int rxrpc_get_server_data_key(struct rxrpc_connection *, const void *, time64_t, /* * local_event.c */ -extern void rxrpc_process_local_events(struct rxrpc_local *); +void rxrpc_send_version_request(struct rxrpc_local *local, + struct rxrpc_host_header *hdr, + struct sk_buff *skb); /* * local_object.c */ struct rxrpc_local *rxrpc_lookup_local(struct net *, const struct sockaddr_rxrpc *); -struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *); -struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *); -void rxrpc_put_local(struct rxrpc_local *); -struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *); -void rxrpc_unuse_local(struct rxrpc_local *); -void rxrpc_queue_local(struct rxrpc_local *); +struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *, enum rxrpc_local_trace); +struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *, enum rxrpc_local_trace); +void rxrpc_put_local(struct rxrpc_local *, enum rxrpc_local_trace); +struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *, enum rxrpc_local_trace); +void rxrpc_unuse_local(struct rxrpc_local *, enum rxrpc_local_trace); +void rxrpc_destroy_local(struct rxrpc_local *local); void rxrpc_destroy_all_locals(struct rxrpc_net *); -static inline bool __rxrpc_unuse_local(struct rxrpc_local *local) +static inline bool __rxrpc_use_local(struct rxrpc_local *local, + enum rxrpc_local_trace why) { - return atomic_dec_return(&local->active_users) == 0; + int r, u; + + r = refcount_read(&local->ref); + u = atomic_fetch_add_unless(&local->active_users, 1, 0); + trace_rxrpc_local(local->debug_id, why, r, u); + return u != 0; } -static inline bool __rxrpc_use_local(struct rxrpc_local *local) +static inline void rxrpc_see_local(struct rxrpc_local *local, + enum rxrpc_local_trace why) { - return atomic_fetch_add_unless(&local->active_users, 1, 0) != 0; + int r, u; + + r = refcount_read(&local->ref); + u = atomic_read(&local->active_users); + trace_rxrpc_local(local->debug_id, why, r, u); } /* * misc.c */ extern unsigned int rxrpc_max_backlog __read_mostly; -extern unsigned long rxrpc_requested_ack_delay; extern unsigned long rxrpc_soft_ack_delay; extern unsigned long rxrpc_idle_ack_delay; extern unsigned int rxrpc_rx_window_size; extern unsigned int rxrpc_rx_mtu; extern unsigned int rxrpc_rx_jumbo_max; -extern const s8 rxrpc_ack_priority[]; - /* * net_ns.c */ @@ -972,17 +1124,18 @@ static inline struct rxrpc_net *rxrpc_net(struct net *net) /* * output.c */ -int rxrpc_send_ack_packet(struct rxrpc_call *, bool, rxrpc_serial_t *); +int rxrpc_send_ack_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb); int rxrpc_send_abort_packet(struct rxrpc_call *); -int rxrpc_send_data_packet(struct rxrpc_call *, struct sk_buff *, bool); -void rxrpc_reject_packets(struct rxrpc_local *); +int rxrpc_send_data_packet(struct rxrpc_call *, struct rxrpc_txbuf *); +void rxrpc_send_conn_abort(struct rxrpc_connection *conn); +void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb); void rxrpc_send_keepalive(struct rxrpc_peer *); +void rxrpc_transmit_one(struct rxrpc_call *call, struct rxrpc_txbuf *txb); /* * peer_event.c */ -void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, unsigned int udp_offset); -void rxrpc_error_report(struct sock *); +void rxrpc_input_error(struct rxrpc_local *, struct sk_buff *); void rxrpc_peer_keepalive_worker(struct work_struct *); /* @@ -990,16 +1143,15 @@ void rxrpc_peer_keepalive_worker(struct work_struct *); */ struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *, const struct sockaddr_rxrpc *); -struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *, struct rxrpc_local *, - struct sockaddr_rxrpc *, gfp_t); -struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *, gfp_t); -void rxrpc_new_incoming_peer(struct rxrpc_sock *, struct rxrpc_local *, - struct rxrpc_peer *); +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_local *local, + struct sockaddr_rxrpc *srx, gfp_t gfp); +struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *, gfp_t, + enum rxrpc_peer_trace); +void rxrpc_new_incoming_peer(struct rxrpc_local *local, struct rxrpc_peer *peer); void rxrpc_destroy_all_peers(struct rxrpc_net *); -struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *); -struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *); -void rxrpc_put_peer(struct rxrpc_peer *); -void rxrpc_put_peer_locked(struct rxrpc_peer *); +struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *, enum rxrpc_peer_trace); +struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *, enum rxrpc_peer_trace); +void rxrpc_put_peer(struct rxrpc_peer *, enum rxrpc_peer_trace); /* * proc.c @@ -1013,33 +1165,22 @@ extern const struct seq_operations rxrpc_local_seq_ops; * recvmsg.c */ void rxrpc_notify_socket(struct rxrpc_call *); -bool __rxrpc_set_call_completion(struct rxrpc_call *, enum rxrpc_call_completion, u32, int); -bool rxrpc_set_call_completion(struct rxrpc_call *, enum rxrpc_call_completion, u32, int); -bool __rxrpc_call_completed(struct rxrpc_call *); -bool rxrpc_call_completed(struct rxrpc_call *); -bool __rxrpc_abort_call(const char *, struct rxrpc_call *, rxrpc_seq_t, u32, int); -bool rxrpc_abort_call(const char *, struct rxrpc_call *, rxrpc_seq_t, u32, int); int rxrpc_recvmsg(struct socket *, struct msghdr *, size_t, int); /* * Abort a call due to a protocol error. */ -static inline bool __rxrpc_abort_eproto(struct rxrpc_call *call, - struct sk_buff *skb, - const char *eproto_why, - const char *why, - u32 abort_code) +static inline int rxrpc_abort_eproto(struct rxrpc_call *call, + struct sk_buff *skb, + s32 abort_code, + enum rxrpc_abort_reason why) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - trace_rxrpc_rx_eproto(call, sp->hdr.serial, eproto_why); - return rxrpc_abort_call(why, call, sp->hdr.seq, abort_code, -EPROTO); + rxrpc_abort_call(call, sp->hdr.seq, abort_code, -EPROTO, why); + return -EPROTO; } -#define rxrpc_abort_eproto(call, skb, eproto_why, abort_why, abort_code) \ - __rxrpc_abort_eproto((call), (skb), tracepoint_string(eproto_why), \ - (abort_why), (abort_code)) - /* * rtt.c */ @@ -1061,6 +1202,7 @@ extern const struct rxrpc_security rxkad; int __init rxrpc_init_security(void); const struct rxrpc_security *rxrpc_security_lookup(u8); void rxrpc_exit_security(void); +int rxrpc_init_client_call_security(struct rxrpc_call *); int rxrpc_init_client_conn_security(struct rxrpc_connection *); const struct rxrpc_security *rxrpc_get_incoming_security(struct rxrpc_sock *, struct sk_buff *); @@ -1070,6 +1212,8 @@ struct key *rxrpc_look_up_server_security(struct rxrpc_connection *, /* * sendmsg.c */ +bool rxrpc_propose_abort(struct rxrpc_call *call, s32 abort_code, int error, + enum rxrpc_abort_reason why); int rxrpc_do_sendmsg(struct rxrpc_sock *, struct msghdr *, size_t); /* @@ -1083,7 +1227,6 @@ int rxrpc_server_keyring(struct rxrpc_sock *, sockptr_t, int); * skbuff.c */ void rxrpc_kernel_data_consumed(struct rxrpc_call *, struct sk_buff *); -void rxrpc_packet_destructor(struct sk_buff *); void rxrpc_new_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_see_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_eaten_skb(struct sk_buff *, enum rxrpc_skb_trace); @@ -1092,6 +1235,15 @@ void rxrpc_free_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_purge_queue(struct sk_buff_head *); /* + * stats.c + */ +int rxrpc_stats_show(struct seq_file *seq, void *v); +int rxrpc_stats_clear(struct file *file, char *buf, size_t size); + +#define rxrpc_inc_stat(rxnet, s) atomic_inc(&(rxnet)->s) +#define rxrpc_dec_stat(rxnet, s) atomic_dec(&(rxnet)->s) + +/* * sysctl.c */ #ifdef CONFIG_SYSCTL @@ -1103,6 +1255,16 @@ static inline void rxrpc_sysctl_exit(void) {} #endif /* + * txbuf.c + */ +extern atomic_t rxrpc_nr_txbuf; +struct rxrpc_txbuf *rxrpc_alloc_txbuf(struct rxrpc_call *call, u8 packet_type, + gfp_t gfp); +void rxrpc_get_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what); +void rxrpc_see_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what); +void rxrpc_put_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what); + +/* * utils.c */ int rxrpc_extract_addr_from_skb(struct sockaddr_rxrpc *, struct sk_buff *); @@ -1135,23 +1297,17 @@ extern unsigned int rxrpc_debug; #define kenter(FMT,...) dbgprintk("==> %s("FMT")",__func__ ,##__VA_ARGS__) #define kleave(FMT,...) dbgprintk("<== %s()"FMT"",__func__ ,##__VA_ARGS__) #define kdebug(FMT,...) dbgprintk(" "FMT ,##__VA_ARGS__) -#define kproto(FMT,...) dbgprintk("### "FMT ,##__VA_ARGS__) -#define knet(FMT,...) dbgprintk("@@@ "FMT ,##__VA_ARGS__) #if defined(__KDEBUG) #define _enter(FMT,...) kenter(FMT,##__VA_ARGS__) #define _leave(FMT,...) kleave(FMT,##__VA_ARGS__) #define _debug(FMT,...) kdebug(FMT,##__VA_ARGS__) -#define _proto(FMT,...) kproto(FMT,##__VA_ARGS__) -#define _net(FMT,...) knet(FMT,##__VA_ARGS__) #elif defined(CONFIG_AF_RXRPC_DEBUG) #define RXRPC_DEBUG_KENTER 0x01 #define RXRPC_DEBUG_KLEAVE 0x02 #define RXRPC_DEBUG_KDEBUG 0x04 -#define RXRPC_DEBUG_KPROTO 0x08 -#define RXRPC_DEBUG_KNET 0x10 #define _enter(FMT,...) \ do { \ @@ -1171,24 +1327,10 @@ do { \ kdebug(FMT,##__VA_ARGS__); \ } while (0) -#define _proto(FMT,...) \ -do { \ - if (unlikely(rxrpc_debug & RXRPC_DEBUG_KPROTO)) \ - kproto(FMT,##__VA_ARGS__); \ -} while (0) - -#define _net(FMT,...) \ -do { \ - if (unlikely(rxrpc_debug & RXRPC_DEBUG_KNET)) \ - knet(FMT,##__VA_ARGS__); \ -} while (0) - #else #define _enter(FMT,...) no_printk("==> %s("FMT")",__func__ ,##__VA_ARGS__) #define _leave(FMT,...) no_printk("<== %s()"FMT"",__func__ ,##__VA_ARGS__) #define _debug(FMT,...) no_printk(" "FMT ,##__VA_ARGS__) -#define _proto(FMT,...) no_printk("### "FMT ,##__VA_ARGS__) -#define _net(FMT,...) no_printk("@@@ "FMT ,##__VA_ARGS__) #endif /* diff --git a/net/rxrpc/call_accept.c b/net/rxrpc/call_accept.c index 99e10eea3732..3e8689fdc437 100644 --- a/net/rxrpc/call_accept.c +++ b/net/rxrpc/call_accept.c @@ -38,7 +38,6 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, unsigned long user_call_ID, gfp_t gfp, unsigned int debug_id) { - const void *here = __builtin_return_address(0); struct rxrpc_call *call, *xcall; struct rxrpc_net *rxnet = rxrpc_net(sock_net(&rx->sk)); struct rb_node *parent, **pp; @@ -70,7 +69,9 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, head = b->peer_backlog_head; tail = READ_ONCE(b->peer_backlog_tail); if (CIRC_CNT(head, tail, size) < max) { - struct rxrpc_peer *peer = rxrpc_alloc_peer(rx->local, gfp); + struct rxrpc_peer *peer; + + peer = rxrpc_alloc_peer(rx->local, gfp, rxrpc_peer_new_prealloc); if (!peer) return -ENOMEM; b->peer_backlog[head] = peer; @@ -89,9 +90,6 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, b->conn_backlog[head] = conn; smp_store_release(&b->conn_backlog_head, (head + 1) & (size - 1)); - - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_service, - refcount_read(&conn->ref), here); } /* Now it gets complicated, because calls get registered with the @@ -101,11 +99,11 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, if (!call) return -ENOMEM; call->flags |= (1 << RXRPC_CALL_IS_SERVICE); - call->state = RXRPC_CALL_SERVER_PREALLOC; + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_PREALLOC); + __set_bit(RXRPC_CALL_EV_INITIAL_PING, &call->events); - trace_rxrpc_call(call->debug_id, rxrpc_call_new_service, - refcount_read(&call->ref), - here, (const void *)user_call_ID); + trace_rxrpc_call(call->debug_id, refcount_read(&call->ref), + user_call_ID, rxrpc_call_new_prealloc_service); write_lock(&rx->call_lock); @@ -126,11 +124,11 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, call->user_call_ID = user_call_ID; call->notify_rx = notify_rx; if (user_attach_call) { - rxrpc_get_call(call, rxrpc_call_got_kernel); + rxrpc_get_call(call, rxrpc_call_get_kernel_service); user_attach_call(call, user_call_ID); } - rxrpc_get_call(call, rxrpc_call_got_userid); + rxrpc_get_call(call, rxrpc_call_get_userid); rb_link_node(&call->sock_node, parent, pp); rb_insert_color(&call->sock_node, &rx->calls); set_bit(RXRPC_CALL_HAS_USERID, &call->flags); @@ -140,9 +138,9 @@ static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, write_unlock(&rx->call_lock); rxnet = call->rxnet; - spin_lock_bh(&rxnet->call_lock); + spin_lock(&rxnet->call_lock); list_add_tail_rcu(&call->link, &rxnet->calls); - spin_unlock_bh(&rxnet->call_lock); + spin_unlock(&rxnet->call_lock); b->call_backlog[call_head] = call; smp_store_release(&b->call_backlog_head, (call_head + 1) & (size - 1)); @@ -190,14 +188,14 @@ void rxrpc_discard_prealloc(struct rxrpc_sock *rx) /* Make sure that there aren't any incoming calls in progress before we * clear the preallocation buffers. */ - spin_lock_bh(&rx->incoming_lock); - spin_unlock_bh(&rx->incoming_lock); + spin_lock(&rx->incoming_lock); + spin_unlock(&rx->incoming_lock); head = b->peer_backlog_head; tail = b->peer_backlog_tail; while (CIRC_CNT(head, tail, size) > 0) { struct rxrpc_peer *peer = b->peer_backlog[tail]; - rxrpc_put_local(peer->local); + rxrpc_put_local(peer->local, rxrpc_local_put_prealloc_conn); kfree(peer); tail = (tail + 1) & (size - 1); } @@ -230,7 +228,7 @@ void rxrpc_discard_prealloc(struct rxrpc_sock *rx) } rxrpc_call_completed(call); rxrpc_release_call(rx, call); - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_discard_prealloc); tail = (tail + 1) & (size - 1); } @@ -238,22 +236,6 @@ void rxrpc_discard_prealloc(struct rxrpc_sock *rx) } /* - * Ping the other end to fill our RTT cache and to retrieve the rwind - * and MTU parameters. - */ -static void rxrpc_send_ping(struct rxrpc_call *call, struct sk_buff *skb) -{ - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - ktime_t now = skb->tstamp; - - if (call->peer->rtt_count < 3 || - ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), now)) - rxrpc_propose_ACK(call, RXRPC_ACK_PING, sp->hdr.serial, - true, true, - rxrpc_propose_ack_ping_for_params); -} - -/* * Allocate a new incoming call from the prealloc pool, along with a connection * and a peer as necessary. */ @@ -262,6 +244,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, struct rxrpc_peer *peer, struct rxrpc_connection *conn, const struct rxrpc_security *sec, + struct sockaddr_rxrpc *peer_srx, struct sk_buff *skb) { struct rxrpc_backlog *b = rx->backlog; @@ -287,18 +270,17 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, return NULL; if (!conn) { - if (peer && !rxrpc_get_peer_maybe(peer)) + if (peer && !rxrpc_get_peer_maybe(peer, rxrpc_peer_get_service_conn)) peer = NULL; if (!peer) { peer = b->peer_backlog[peer_tail]; - if (rxrpc_extract_addr_from_skb(&peer->srx, skb) < 0) - return NULL; + peer->srx = *peer_srx; b->peer_backlog[peer_tail] = NULL; smp_store_release(&b->peer_backlog_tail, (peer_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); - rxrpc_new_incoming_peer(rx, local, peer); + rxrpc_new_incoming_peer(local, peer); } /* Now allocate and set up the connection */ @@ -306,12 +288,13 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, b->conn_backlog[conn_tail] = NULL; smp_store_release(&b->conn_backlog_tail, (conn_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); - conn->params.local = rxrpc_get_local(local); - conn->params.peer = peer; - rxrpc_see_connection(conn); + conn->local = rxrpc_get_local(local, rxrpc_local_get_prealloc_conn); + conn->peer = peer; + rxrpc_see_connection(conn, rxrpc_conn_see_new_service_conn); rxrpc_new_incoming_connection(rx, conn, sec, skb); } else { - rxrpc_get_connection(conn); + rxrpc_get_connection(conn, rxrpc_conn_get_service_conn); + atomic_inc(&conn->active); } /* And now we can allocate and set up a new call */ @@ -320,66 +303,78 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, smp_store_release(&b->call_backlog_tail, (call_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); - rxrpc_see_call(call); + rxrpc_see_call(call, rxrpc_call_see_accept); + call->local = rxrpc_get_local(conn->local, rxrpc_local_get_call); call->conn = conn; call->security = conn->security; call->security_ix = conn->security_ix; - call->peer = rxrpc_get_peer(conn->params.peer); - call->cong_cwnd = call->peer->cong_cwnd; + call->peer = rxrpc_get_peer(conn->peer, rxrpc_peer_get_accept); + call->dest_srx = peer->srx; + call->cong_ssthresh = call->peer->cong_ssthresh; + call->tx_last_sent = ktime_get_real(); return call; } /* - * Set up a new incoming call. Called in BH context with the RCU read lock - * held. + * Set up a new incoming call. Called from the I/O thread. * * If this is for a kernel service, when we allocate the call, it will have * three refs on it: (1) the kernel service, (2) the user_call_ID tree, (3) the * retainer ref obtained from the backlog buffer. Prealloc calls for userspace - * services only have the ref from the backlog buffer. We want to pass this - * ref to non-BH context to dispose of. + * services only have the ref from the backlog buffer. * * If we want to report an error, we mark the skb with the packet type and - * abort code and return NULL. - * - * The call is returned with the user access mutex held. + * abort code and return false. */ -struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, - struct rxrpc_sock *rx, - struct sk_buff *skb) +bool rxrpc_new_incoming_call(struct rxrpc_local *local, + struct rxrpc_peer *peer, + struct rxrpc_connection *conn, + struct sockaddr_rxrpc *peer_srx, + struct sk_buff *skb) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); const struct rxrpc_security *sec = NULL; - struct rxrpc_connection *conn; - struct rxrpc_peer *peer = NULL; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_call *call = NULL; + struct rxrpc_sock *rx; _enter(""); - spin_lock(&rx->incoming_lock); - if (rx->sk.sk_state == RXRPC_SERVER_LISTEN_DISABLED || - rx->sk.sk_state == RXRPC_CLOSE) { - trace_rxrpc_abort(0, "CLS", sp->hdr.cid, sp->hdr.callNumber, - sp->hdr.seq, RX_INVALID_OPERATION, ESHUTDOWN); - skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; - skb->priority = RX_INVALID_OPERATION; - goto no_call; - } + /* Don't set up a call for anything other than a DATA packet. */ + if (sp->hdr.type != RXRPC_PACKET_TYPE_DATA) + return rxrpc_protocol_error(skb, rxrpc_eproto_no_service_call); - /* The peer, connection and call may all have sprung into existence due - * to a duplicate packet being handled on another CPU in parallel, so - * we have to recheck the routing. However, we're now holding - * rx->incoming_lock, so the values should remain stable. + read_lock(&local->services_lock); + + /* Weed out packets to services we're not offering. Packets that would + * begin a call are explicitly rejected and the rest are just + * discarded. */ - conn = rxrpc_find_connection_rcu(local, skb, &peer); + rx = local->service; + if (!rx || (sp->hdr.serviceId != rx->srx.srx_service && + sp->hdr.serviceId != rx->second_service) + ) { + if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA && + sp->hdr.seq == 1) + goto unsupported_service; + goto discard; + } if (!conn) { sec = rxrpc_get_incoming_security(rx, skb); if (!sec) - goto no_call; + goto unsupported_security; + } + + spin_lock(&rx->incoming_lock); + if (rx->sk.sk_state == RXRPC_SERVER_LISTEN_DISABLED || + rx->sk.sk_state == RXRPC_CLOSE) { + rxrpc_direct_abort(skb, rxrpc_abort_shut_down, + RX_INVALID_OPERATION, -ESHUTDOWN); + goto no_call; } - call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, skb); + call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, peer_srx, + skb); if (!call) { skb->mark = RXRPC_SKB_MARK_REJECT_BUSY; goto no_call; @@ -396,50 +391,43 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, rx->notify_new_call(&rx->sk, call, call->user_call_ID); spin_lock(&conn->state_lock); - switch (conn->state) { - case RXRPC_CONN_SERVICE_UNSECURED: + if (conn->state == RXRPC_CONN_SERVICE_UNSECURED) { conn->state = RXRPC_CONN_SERVICE_CHALLENGING; set_bit(RXRPC_CONN_EV_CHALLENGE, &call->conn->events); - rxrpc_queue_conn(call->conn); - break; - - case RXRPC_CONN_SERVICE: - write_lock(&call->state_lock); - if (call->state < RXRPC_CALL_COMPLETE) - call->state = RXRPC_CALL_SERVER_RECV_REQUEST; - write_unlock(&call->state_lock); - break; - - case RXRPC_CONN_REMOTELY_ABORTED: - rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, - conn->abort_code, conn->error); - break; - case RXRPC_CONN_LOCALLY_ABORTED: - rxrpc_abort_call("CON", call, sp->hdr.seq, - conn->abort_code, conn->error); - break; - default: - BUG(); + rxrpc_queue_conn(call->conn, rxrpc_conn_queue_challenge); } spin_unlock(&conn->state_lock); - spin_unlock(&rx->incoming_lock); - rxrpc_send_ping(call, skb); + spin_unlock(&rx->incoming_lock); + read_unlock(&local->services_lock); - /* We have to discard the prealloc queue's ref here and rely on a - * combination of the RCU read lock and refs held either by the socket - * (recvmsg queue, to-be-accepted queue or user ID tree) or the kernel - * service to prevent the call from being deallocated too early. - */ - rxrpc_put_call(call, rxrpc_call_put); + if (hlist_unhashed(&call->error_link)) { + spin_lock(&call->peer->lock); + hlist_add_head(&call->error_link, &call->peer->error_targets); + spin_unlock(&call->peer->lock); + } _leave(" = %p{%d}", call, call->debug_id); - return call; - + rxrpc_input_call_event(call, skb); + rxrpc_put_call(call, rxrpc_call_put_input); + return true; + +unsupported_service: + read_unlock(&local->services_lock); + return rxrpc_direct_abort(skb, rxrpc_abort_service_not_offered, + RX_INVALID_OPERATION, -EOPNOTSUPP); +unsupported_security: + read_unlock(&local->services_lock); + return rxrpc_direct_abort(skb, rxrpc_abort_service_not_offered, + RX_INVALID_OPERATION, -EKEYREJECTED); no_call: spin_unlock(&rx->incoming_lock); - _leave(" = NULL [%u]", skb->mark); - return NULL; + read_unlock(&local->services_lock); + _leave(" = f [%u]", skb->mark); + return false; +discard: + read_unlock(&local->services_lock); + return true; } /* diff --git a/net/rxrpc/call_event.c b/net/rxrpc/call_event.c index 2a93e7b5fbd0..1abdef15debc 100644 --- a/net/rxrpc/call_event.c +++ b/net/rxrpc/call_event.c @@ -20,127 +20,84 @@ /* * Propose a PING ACK be sent. */ -static void rxrpc_propose_ping(struct rxrpc_call *call, - bool immediate, bool background) +void rxrpc_propose_ping(struct rxrpc_call *call, u32 serial, + enum rxrpc_propose_ack_trace why) { - if (immediate) { - if (background && - !test_and_set_bit(RXRPC_CALL_EV_PING, &call->events)) - rxrpc_queue_call(call); - } else { - unsigned long now = jiffies; - unsigned long ping_at = now + rxrpc_idle_ack_delay; - - if (time_before(ping_at, call->ping_at)) { - WRITE_ONCE(call->ping_at, ping_at); - rxrpc_reduce_call_timer(call, ping_at, now, - rxrpc_timer_set_for_ping); - } + unsigned long now = jiffies; + unsigned long ping_at = now + rxrpc_idle_ack_delay; + + if (time_before(ping_at, call->ping_at)) { + WRITE_ONCE(call->ping_at, ping_at); + rxrpc_reduce_call_timer(call, ping_at, now, + rxrpc_timer_set_for_ping); + trace_rxrpc_propose_ack(call, why, RXRPC_ACK_PING, serial); } } /* - * propose an ACK be sent + * Propose a DELAY ACK be sent in the future. */ -static void __rxrpc_propose_ACK(struct rxrpc_call *call, u8 ack_reason, - u32 serial, bool immediate, bool background, - enum rxrpc_propose_ack_trace why) +void rxrpc_propose_delay_ACK(struct rxrpc_call *call, rxrpc_serial_t serial, + enum rxrpc_propose_ack_trace why) { - enum rxrpc_propose_ack_outcome outcome = rxrpc_propose_ack_use; unsigned long expiry = rxrpc_soft_ack_delay; - s8 prior = rxrpc_ack_priority[ack_reason]; - - /* Pings are handled specially because we don't want to accidentally - * lose a ping response by subsuming it into a ping. - */ - if (ack_reason == RXRPC_ACK_PING) { - rxrpc_propose_ping(call, immediate, background); - goto trace; + unsigned long now = jiffies, ack_at; + + call->ackr_serial = serial; + + if (rxrpc_soft_ack_delay < expiry) + expiry = rxrpc_soft_ack_delay; + if (call->peer->srtt_us != 0) + ack_at = usecs_to_jiffies(call->peer->srtt_us >> 3); + else + ack_at = expiry; + + ack_at += READ_ONCE(call->tx_backoff); + ack_at += now; + if (time_before(ack_at, call->delay_ack_at)) { + WRITE_ONCE(call->delay_ack_at, ack_at); + rxrpc_reduce_call_timer(call, ack_at, now, + rxrpc_timer_set_for_ack); } - /* Update DELAY, IDLE, REQUESTED and PING_RESPONSE ACK serial - * numbers, but we don't alter the timeout. - */ - _debug("prior %u %u vs %u %u", - ack_reason, prior, - call->ackr_reason, rxrpc_ack_priority[call->ackr_reason]); - if (ack_reason == call->ackr_reason) { - if (RXRPC_ACK_UPDATEABLE & (1 << ack_reason)) { - outcome = rxrpc_propose_ack_update; - call->ackr_serial = serial; - } - if (!immediate) - goto trace; - } else if (prior > rxrpc_ack_priority[call->ackr_reason]) { - call->ackr_reason = ack_reason; - call->ackr_serial = serial; - } else { - outcome = rxrpc_propose_ack_subsume; - } + trace_rxrpc_propose_ack(call, why, RXRPC_ACK_DELAY, serial); +} - switch (ack_reason) { - case RXRPC_ACK_REQUESTED: - if (rxrpc_requested_ack_delay < expiry) - expiry = rxrpc_requested_ack_delay; - if (serial == 1) - immediate = false; - break; +/* + * Queue an ACK for immediate transmission. + */ +void rxrpc_send_ACK(struct rxrpc_call *call, u8 ack_reason, + rxrpc_serial_t serial, enum rxrpc_propose_ack_trace why) +{ + struct rxrpc_txbuf *txb; - case RXRPC_ACK_DELAY: - if (rxrpc_soft_ack_delay < expiry) - expiry = rxrpc_soft_ack_delay; - break; + if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) + return; - case RXRPC_ACK_IDLE: - if (rxrpc_idle_ack_delay < expiry) - expiry = rxrpc_idle_ack_delay; - break; + rxrpc_inc_stat(call->rxnet, stat_tx_acks[ack_reason]); - default: - immediate = true; - break; + txb = rxrpc_alloc_txbuf(call, RXRPC_PACKET_TYPE_ACK, + rcu_read_lock_held() ? GFP_ATOMIC | __GFP_NOWARN : GFP_NOFS); + if (!txb) { + kleave(" = -ENOMEM"); + return; } - if (test_bit(RXRPC_CALL_EV_ACK, &call->events)) { - _debug("already scheduled"); - } else if (immediate || expiry == 0) { - _debug("immediate ACK %lx", call->events); - if (!test_and_set_bit(RXRPC_CALL_EV_ACK, &call->events) && - background) - rxrpc_queue_call(call); - } else { - unsigned long now = jiffies, ack_at; - - if (call->peer->srtt_us != 0) - ack_at = usecs_to_jiffies(call->peer->srtt_us >> 3); - else - ack_at = expiry; - - ack_at += READ_ONCE(call->tx_backoff); - ack_at += now; - if (time_before(ack_at, call->ack_at)) { - WRITE_ONCE(call->ack_at, ack_at); - rxrpc_reduce_call_timer(call, ack_at, now, - rxrpc_timer_set_for_ack); - } - } - -trace: - trace_rxrpc_propose_ack(call, why, ack_reason, serial, immediate, - background, outcome); -} - -/* - * propose an ACK be sent, locking the call structure - */ -void rxrpc_propose_ACK(struct rxrpc_call *call, u8 ack_reason, - u32 serial, bool immediate, bool background, - enum rxrpc_propose_ack_trace why) -{ - spin_lock_bh(&call->lock); - __rxrpc_propose_ACK(call, ack_reason, serial, - immediate, background, why); - spin_unlock_bh(&call->lock); + txb->ack_why = why; + txb->wire.seq = 0; + txb->wire.type = RXRPC_PACKET_TYPE_ACK; + txb->wire.flags |= RXRPC_SLOW_START_OK; + txb->ack.bufferSpace = 0; + txb->ack.maxSkew = 0; + txb->ack.firstPacket = 0; + txb->ack.previousPacket = 0; + txb->ack.serial = htonl(serial); + txb->ack.reason = ack_reason; + txb->ack.nAcks = 0; + + trace_rxrpc_send_ack(call, why, ack_reason, serial); + rxrpc_send_ack_packet(call, txb); + rxrpc_put_txbuf(txb, rxrpc_txbuf_put_ack_tx); } /* @@ -154,64 +111,115 @@ static void rxrpc_congestion_timeout(struct rxrpc_call *call) /* * Perform retransmission of NAK'd and unack'd packets. */ -static void rxrpc_resend(struct rxrpc_call *call, unsigned long now_j) +void rxrpc_resend(struct rxrpc_call *call, struct sk_buff *ack_skb) { - struct sk_buff *skb; + struct rxrpc_ackpacket *ack = NULL; + struct rxrpc_txbuf *txb; unsigned long resend_at; - rxrpc_seq_t cursor, seq, top; + rxrpc_seq_t transmitted = READ_ONCE(call->tx_transmitted); ktime_t now, max_age, oldest, ack_ts; - int ix; - u8 annotation, anno_type, retrans = 0, unacked = 0; + bool unacked = false; + unsigned int i; + LIST_HEAD(retrans_queue); - _enter("{%d,%d}", call->tx_hard_ack, call->tx_top); + _enter("{%d,%d}", call->acks_hard_ack, call->tx_top); now = ktime_get_real(); max_age = ktime_sub_us(now, jiffies_to_usecs(call->peer->rto_j)); + oldest = now; - spin_lock_bh(&call->lock); + if (list_empty(&call->tx_buffer)) + goto no_resend; - cursor = call->tx_hard_ack; - top = call->tx_top; - ASSERT(before_eq(cursor, top)); - if (cursor == top) - goto out_unlock; + if (list_empty(&call->tx_buffer)) + goto no_further_resend; - /* Scan the packet list without dropping the lock and decide which of - * the packets in the Tx buffer we're going to resend and what the new - * resend timeout will be. + trace_rxrpc_resend(call, ack_skb); + txb = list_first_entry(&call->tx_buffer, struct rxrpc_txbuf, call_link); + + /* Scan the soft ACK table without dropping the lock and resend any + * explicitly NAK'd packets. */ - trace_rxrpc_resend(call, (cursor + 1) & RXRPC_RXTX_BUFF_MASK); - oldest = now; - for (seq = cursor + 1; before_eq(seq, top); seq++) { - ix = seq & RXRPC_RXTX_BUFF_MASK; - annotation = call->rxtx_annotations[ix]; - anno_type = annotation & RXRPC_TX_ANNO_MASK; - annotation &= ~RXRPC_TX_ANNO_MASK; - if (anno_type == RXRPC_TX_ANNO_ACK) - continue; + if (ack_skb) { + ack = (void *)ack_skb->data + sizeof(struct rxrpc_wire_header); - skb = call->rxtx_buffer[ix]; - rxrpc_see_skb(skb, rxrpc_skb_seen); + for (i = 0; i < ack->nAcks; i++) { + rxrpc_seq_t seq; - if (anno_type == RXRPC_TX_ANNO_UNACK) { - if (ktime_after(skb->tstamp, max_age)) { - if (ktime_before(skb->tstamp, oldest)) - oldest = skb->tstamp; + if (ack->acks[i] & 1) continue; + seq = ntohl(ack->firstPacket) + i; + if (after(txb->seq, transmitted)) + break; + if (after(txb->seq, seq)) + continue; /* A new hard ACK probably came in */ + list_for_each_entry_from(txb, &call->tx_buffer, call_link) { + if (txb->seq == seq) + goto found_txb; + } + goto no_further_resend; + + found_txb: + if (after(ntohl(txb->wire.serial), call->acks_highest_serial)) + continue; /* Ack point not yet reached */ + + rxrpc_see_txbuf(txb, rxrpc_txbuf_see_unacked); + + if (list_empty(&txb->tx_link)) { + list_add_tail(&txb->tx_link, &retrans_queue); + set_bit(RXRPC_TXBUF_RESENT, &txb->flags); } - if (!(annotation & RXRPC_TX_ANNO_RESENT)) - unacked++; + + trace_rxrpc_retransmit(call, txb->seq, + ktime_to_ns(ktime_sub(txb->last_sent, + max_age))); + + if (list_is_last(&txb->call_link, &call->tx_buffer)) + goto no_further_resend; + txb = list_next_entry(txb, call_link); } + } + + /* Fast-forward through the Tx queue to the point the peer says it has + * seen. Anything between the soft-ACK table and that point will get + * ACK'd or NACK'd in due course, so don't worry about it here; here we + * need to consider retransmitting anything beyond that point. + * + * Note that ACK for a packet can beat the update of tx_transmitted. + */ + if (after_eq(READ_ONCE(call->acks_prev_seq), READ_ONCE(call->tx_transmitted))) + goto no_further_resend; - /* Okay, we need to retransmit a packet. */ - call->rxtx_annotations[ix] = RXRPC_TX_ANNO_RETRANS | annotation; - retrans++; - trace_rxrpc_retransmit(call, seq, annotation | anno_type, - ktime_to_ns(ktime_sub(skb->tstamp, max_age))); + list_for_each_entry_from(txb, &call->tx_buffer, call_link) { + if (before_eq(txb->seq, READ_ONCE(call->acks_prev_seq))) + continue; + if (after(txb->seq, READ_ONCE(call->tx_transmitted))) + break; /* Not transmitted yet */ + + if (ack && ack->reason == RXRPC_ACK_PING_RESPONSE && + before(ntohl(txb->wire.serial), ntohl(ack->serial))) + goto do_resend; /* Wasn't accounted for by a more recent ping. */ + + if (ktime_after(txb->last_sent, max_age)) { + if (ktime_before(txb->last_sent, oldest)) + oldest = txb->last_sent; + continue; + } + + do_resend: + unacked = true; + if (list_empty(&txb->tx_link)) { + list_add_tail(&txb->tx_link, &retrans_queue); + set_bit(RXRPC_TXBUF_RESENT, &txb->flags); + rxrpc_inc_stat(call->rxnet, stat_tx_data_retrans); + } } +no_further_resend: +no_resend: resend_at = nsecs_to_jiffies(ktime_to_ns(ktime_sub(now, oldest))); - resend_at += jiffies + rxrpc_get_rto_backoff(call->peer, retrans); + resend_at += jiffies + rxrpc_get_rto_backoff(call->peer, + !list_empty(&retrans_queue)); WRITE_ONCE(call->resend_at, resend_at); if (unacked) @@ -221,125 +229,203 @@ static void rxrpc_resend(struct rxrpc_call *call, unsigned long now_j) * that an ACK got lost somewhere. Send a ping to find out instead of * retransmitting data. */ - if (!retrans) { - rxrpc_reduce_call_timer(call, resend_at, now_j, + if (list_empty(&retrans_queue)) { + rxrpc_reduce_call_timer(call, resend_at, jiffies, rxrpc_timer_set_for_resend); - spin_unlock_bh(&call->lock); ack_ts = ktime_sub(now, call->acks_latest_ts); if (ktime_to_us(ack_ts) < (call->peer->srtt_us >> 3)) goto out; - rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, false, - rxrpc_propose_ack_ping_for_lost_ack); - rxrpc_send_ack_packet(call, true, NULL); + rxrpc_send_ACK(call, RXRPC_ACK_PING, 0, + rxrpc_propose_ack_ping_for_lost_ack); goto out; } - /* Now go through the Tx window and perform the retransmissions. We - * have to drop the lock for each send. If an ACK comes in whilst the - * lock is dropped, it may clear some of the retransmission markers for - * packets that it soft-ACKs. - */ - for (seq = cursor + 1; before_eq(seq, top); seq++) { - ix = seq & RXRPC_RXTX_BUFF_MASK; - annotation = call->rxtx_annotations[ix]; - anno_type = annotation & RXRPC_TX_ANNO_MASK; - if (anno_type != RXRPC_TX_ANNO_RETRANS) - continue; + /* Retransmit the queue */ + while ((txb = list_first_entry_or_null(&retrans_queue, + struct rxrpc_txbuf, tx_link))) { + list_del_init(&txb->tx_link); + rxrpc_transmit_one(call, txb); + } - /* We need to reset the retransmission state, but we need to do - * so before we drop the lock as a new ACK/NAK may come in and - * confuse things - */ - annotation &= ~RXRPC_TX_ANNO_MASK; - annotation |= RXRPC_TX_ANNO_UNACK | RXRPC_TX_ANNO_RESENT; - call->rxtx_annotations[ix] = annotation; +out: + _leave(""); +} - skb = call->rxtx_buffer[ix]; - if (!skb) - continue; +/* + * Start transmitting the reply to a service. This cancels the need to ACK the + * request if we haven't yet done so. + */ +static void rxrpc_begin_service_reply(struct rxrpc_call *call) +{ + unsigned long now = jiffies; + + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_SEND_REPLY); + WRITE_ONCE(call->delay_ack_at, now + MAX_JIFFY_OFFSET); + if (call->ackr_reason == RXRPC_ACK_DELAY) + call->ackr_reason = 0; + trace_rxrpc_timer(call, rxrpc_timer_init_for_send_reply, now); +} - rxrpc_get_skb(skb, rxrpc_skb_got); - spin_unlock_bh(&call->lock); +/* + * Close the transmission phase. After this point there is no more data to be + * transmitted in the call. + */ +static void rxrpc_close_tx_phase(struct rxrpc_call *call) +{ + _debug("________awaiting reply/ACK__________"); - if (rxrpc_send_data_packet(call, skb, true) < 0) { - rxrpc_free_skb(skb, rxrpc_skb_freed); + switch (__rxrpc_call_state(call)) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + rxrpc_set_call_state(call, RXRPC_CALL_CLIENT_AWAIT_REPLY); + break; + case RXRPC_CALL_SERVER_SEND_REPLY: + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_AWAIT_ACK); + break; + default: + break; + } +} + +static bool rxrpc_tx_window_has_space(struct rxrpc_call *call) +{ + unsigned int winsize = min_t(unsigned int, call->tx_winsize, + call->cong_cwnd + call->cong_extra); + rxrpc_seq_t window = call->acks_hard_ack, wtop = window + winsize; + rxrpc_seq_t tx_top = call->tx_top; + int space; + + space = wtop - tx_top; + return space > 0; +} + +/* + * Decant some if the sendmsg prepared queue into the transmission buffer. + */ +static void rxrpc_decant_prepared_tx(struct rxrpc_call *call) +{ + struct rxrpc_txbuf *txb; + + if (!test_bit(RXRPC_CALL_EXPOSED, &call->flags)) { + if (list_empty(&call->tx_sendmsg)) return; - } + rxrpc_expose_client_call(call); + } + + while ((txb = list_first_entry_or_null(&call->tx_sendmsg, + struct rxrpc_txbuf, call_link))) { + spin_lock(&call->tx_lock); + list_del(&txb->call_link); + spin_unlock(&call->tx_lock); + + call->tx_top = txb->seq; + list_add_tail(&txb->call_link, &call->tx_buffer); + + if (txb->wire.flags & RXRPC_LAST_PACKET) + rxrpc_close_tx_phase(call); - if (rxrpc_is_client_call(call)) - rxrpc_expose_client_call(call); + rxrpc_transmit_one(call, txb); - rxrpc_free_skb(skb, rxrpc_skb_freed); - spin_lock_bh(&call->lock); - if (after(call->tx_hard_ack, seq)) - seq = call->tx_hard_ack; + if (!rxrpc_tx_window_has_space(call)) + break; } +} -out_unlock: - spin_unlock_bh(&call->lock); -out: - _leave(""); +static void rxrpc_transmit_some_data(struct rxrpc_call *call) +{ + switch (__rxrpc_call_state(call)) { + case RXRPC_CALL_SERVER_ACK_REQUEST: + if (list_empty(&call->tx_sendmsg)) + return; + rxrpc_begin_service_reply(call); + fallthrough; + + case RXRPC_CALL_SERVER_SEND_REPLY: + case RXRPC_CALL_CLIENT_SEND_REQUEST: + if (!rxrpc_tx_window_has_space(call)) + return; + if (list_empty(&call->tx_sendmsg)) { + rxrpc_inc_stat(call->rxnet, stat_tx_data_underflow); + return; + } + rxrpc_decant_prepared_tx(call); + break; + default: + return; + } +} + +/* + * Ping the other end to fill our RTT cache and to retrieve the rwind + * and MTU parameters. + */ +static void rxrpc_send_initial_ping(struct rxrpc_call *call) +{ + if (call->peer->rtt_count < 3 || + ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), + ktime_get_real())) + rxrpc_send_ACK(call, RXRPC_ACK_PING, 0, + rxrpc_propose_ack_ping_for_params); } /* * Handle retransmission and deferred ACK/abort generation. */ -void rxrpc_process_call(struct work_struct *work) +bool rxrpc_input_call_event(struct rxrpc_call *call, struct sk_buff *skb) { - struct rxrpc_call *call = - container_of(work, struct rxrpc_call, processor); - rxrpc_serial_t *send_ack; unsigned long now, next, t; - unsigned int iterations = 0; + rxrpc_serial_t ackr_serial; + bool resend = false, expired = false; + s32 abort_code; - rxrpc_see_call(call); + rxrpc_see_call(call, rxrpc_call_see_input); //printk("\n--------------------\n"); _enter("{%d,%s,%lx}", - call->debug_id, rxrpc_call_states[call->state], call->events); + call->debug_id, rxrpc_call_states[__rxrpc_call_state(call)], + call->events); -recheck_state: - /* Limit the number of times we do this before returning to the manager */ - iterations++; - if (iterations > 5) - goto requeue; + if (__rxrpc_call_is_complete(call)) + goto out; - if (test_and_clear_bit(RXRPC_CALL_EV_ABORT, &call->events)) { - rxrpc_send_abort_packet(call); - goto recheck_state; + /* Handle abort request locklessly, vs rxrpc_propose_abort(). */ + abort_code = smp_load_acquire(&call->send_abort); + if (abort_code) { + rxrpc_abort_call(call, 0, call->send_abort, call->send_abort_err, + call->send_abort_why); + goto out; } - if (call->state == RXRPC_CALL_COMPLETE) { - rxrpc_delete_call_timer(call); - goto out_put; - } + if (skb && skb->mark == RXRPC_SKB_MARK_ERROR) + goto out; - /* Work out if any timeouts tripped */ + /* If we see our async-event poke, check for timeout trippage. */ now = jiffies; t = READ_ONCE(call->expect_rx_by); if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_normal, now); - set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + expired = true; } t = READ_ONCE(call->expect_req_by); - if (call->state == RXRPC_CALL_SERVER_RECV_REQUEST && + if (__rxrpc_call_state(call) == RXRPC_CALL_SERVER_RECV_REQUEST && time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_idle, now); - set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + expired = true; } t = READ_ONCE(call->expect_term_by); if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_hard, now); - set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + expired = true; } - t = READ_ONCE(call->ack_at); + t = READ_ONCE(call->delay_ack_at); if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_ack, now); - cmpxchg(&call->ack_at, t, now + MAX_JIFFY_OFFSET); - set_bit(RXRPC_CALL_EV_ACK, &call->events); + cmpxchg(&call->delay_ack_at, t, now + MAX_JIFFY_OFFSET); + ackr_serial = xchg(&call->ackr_serial, 0); + rxrpc_send_ACK(call, RXRPC_ACK_DELAY, ackr_serial, + rxrpc_propose_ack_ping_for_lost_ack); } t = READ_ONCE(call->ack_lost_at); @@ -353,95 +439,100 @@ recheck_state: if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_keepalive, now); cmpxchg(&call->keepalive_at, t, now + MAX_JIFFY_OFFSET); - rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, true, - rxrpc_propose_ack_ping_for_keepalive); - set_bit(RXRPC_CALL_EV_PING, &call->events); + rxrpc_send_ACK(call, RXRPC_ACK_PING, 0, + rxrpc_propose_ack_ping_for_keepalive); } t = READ_ONCE(call->ping_at); if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_ping, now); cmpxchg(&call->ping_at, t, now + MAX_JIFFY_OFFSET); - set_bit(RXRPC_CALL_EV_PING, &call->events); + rxrpc_send_ACK(call, RXRPC_ACK_PING, 0, + rxrpc_propose_ack_ping_for_keepalive); } t = READ_ONCE(call->resend_at); if (time_after_eq(now, t)) { trace_rxrpc_timer(call, rxrpc_timer_exp_resend, now); cmpxchg(&call->resend_at, t, now + MAX_JIFFY_OFFSET); - set_bit(RXRPC_CALL_EV_RESEND, &call->events); + resend = true; + } + + if (skb) + rxrpc_input_call_packet(call, skb); + + rxrpc_transmit_some_data(call); + + if (skb) { + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + if (sp->hdr.type == RXRPC_PACKET_TYPE_ACK) + rxrpc_congestion_degrade(call); } + if (test_and_clear_bit(RXRPC_CALL_EV_INITIAL_PING, &call->events)) + rxrpc_send_initial_ping(call); + /* Process events */ - if (test_and_clear_bit(RXRPC_CALL_EV_EXPIRED, &call->events)) { + if (expired) { if (test_bit(RXRPC_CALL_RX_HEARD, &call->flags) && (int)call->conn->hi_serial - (int)call->rx_serial > 0) { trace_rxrpc_call_reset(call); - rxrpc_abort_call("EXP", call, 0, RX_CALL_DEAD, -ECONNRESET); + rxrpc_abort_call(call, 0, RX_CALL_DEAD, -ECONNRESET, + rxrpc_abort_call_reset); } else { - rxrpc_abort_call("EXP", call, 0, RX_CALL_TIMEOUT, -ETIME); + rxrpc_abort_call(call, 0, RX_CALL_TIMEOUT, -ETIME, + rxrpc_abort_call_timeout); } - set_bit(RXRPC_CALL_EV_ABORT, &call->events); - goto recheck_state; + goto out; } - send_ack = NULL; - if (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events)) { - call->acks_lost_top = call->tx_top; - rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, false, - rxrpc_propose_ack_ping_for_lost_ack); - send_ack = &call->acks_lost_ping; - } + if (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events)) + rxrpc_send_ACK(call, RXRPC_ACK_PING, 0, + rxrpc_propose_ack_ping_for_lost_ack); - if (test_and_clear_bit(RXRPC_CALL_EV_ACK, &call->events) || - send_ack) { - if (call->ackr_reason) { - rxrpc_send_ack_packet(call, false, send_ack); - goto recheck_state; - } - } + if (resend && __rxrpc_call_state(call) != RXRPC_CALL_CLIENT_RECV_REPLY) + rxrpc_resend(call, NULL); - if (test_and_clear_bit(RXRPC_CALL_EV_PING, &call->events)) { - rxrpc_send_ack_packet(call, true, NULL); - goto recheck_state; - } + if (test_and_clear_bit(RXRPC_CALL_RX_IS_IDLE, &call->flags)) + rxrpc_send_ACK(call, RXRPC_ACK_IDLE, 0, + rxrpc_propose_ack_rx_idle); - if (test_and_clear_bit(RXRPC_CALL_EV_RESEND, &call->events) && - call->state != RXRPC_CALL_CLIENT_RECV_REPLY) { - rxrpc_resend(call, now); - goto recheck_state; - } + if (atomic_read(&call->ackr_nr_unacked) > 2) + rxrpc_send_ACK(call, RXRPC_ACK_IDLE, 0, + rxrpc_propose_ack_input_data); /* Make sure the timer is restarted */ - next = call->expect_rx_by; + if (!__rxrpc_call_is_complete(call)) { + next = call->expect_rx_by; #define set(T) { t = READ_ONCE(T); if (time_before(t, next)) next = t; } - set(call->expect_req_by); - set(call->expect_term_by); - set(call->ack_at); - set(call->ack_lost_at); - set(call->resend_at); - set(call->keepalive_at); - set(call->ping_at); - - now = jiffies; - if (time_after_eq(now, next)) - goto recheck_state; + set(call->expect_req_by); + set(call->expect_term_by); + set(call->delay_ack_at); + set(call->ack_lost_at); + set(call->resend_at); + set(call->keepalive_at); + set(call->ping_at); - rxrpc_reduce_call_timer(call, next, now, rxrpc_timer_restart); + now = jiffies; + if (time_after_eq(now, next)) + rxrpc_poke_call(call, rxrpc_call_poke_timer_now); - /* other events may have been raised since we started checking */ - if (call->events && call->state < RXRPC_CALL_COMPLETE) - goto requeue; + rxrpc_reduce_call_timer(call, next, now, rxrpc_timer_restart); + } -out_put: - rxrpc_put_call(call, rxrpc_call_put); out: + if (__rxrpc_call_is_complete(call)) { + del_timer_sync(&call->timer); + if (!test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) + rxrpc_disconnect_call(call); + if (call->security) + call->security->free_call_crypto(call); + } + if (call->acks_hard_ack != call->tx_bottom) + rxrpc_shrink_call_tx_buffer(call); _leave(""); - return; - -requeue: - __rxrpc_queue_call(call); - goto out; + return true; } diff --git a/net/rxrpc/call_object.c b/net/rxrpc/call_object.c index 6401cdf7a624..f3c9f0201c15 100644 --- a/net/rxrpc/call_object.c +++ b/net/rxrpc/call_object.c @@ -45,17 +45,33 @@ static struct semaphore rxrpc_call_limiter = static struct semaphore rxrpc_kernel_call_limiter = __SEMAPHORE_INITIALIZER(rxrpc_kernel_call_limiter, 1000); +void rxrpc_poke_call(struct rxrpc_call *call, enum rxrpc_call_poke_trace what) +{ + struct rxrpc_local *local = call->local; + bool busy; + + if (!test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) { + spin_lock_bh(&local->lock); + busy = !list_empty(&call->attend_link); + trace_rxrpc_poke_call(call, busy, what); + if (!busy) { + rxrpc_get_call(call, rxrpc_call_get_poke); + list_add_tail(&call->attend_link, &local->call_attend_q); + } + spin_unlock_bh(&local->lock); + rxrpc_wake_up_io_thread(local); + } +} + static void rxrpc_call_timer_expired(struct timer_list *t) { struct rxrpc_call *call = from_timer(call, t, timer); _enter("%d", call->debug_id); - if (call->state < RXRPC_CALL_COMPLETE) { - trace_rxrpc_timer(call, rxrpc_timer_expired, jiffies); - __rxrpc_queue_call(call); - } else { - rxrpc_put_call(call, rxrpc_call_put); + if (!__rxrpc_call_is_complete(call)) { + trace_rxrpc_timer_expired(call, jiffies); + rxrpc_poke_call(call, rxrpc_call_poke_timer); } } @@ -64,21 +80,14 @@ void rxrpc_reduce_call_timer(struct rxrpc_call *call, unsigned long now, enum rxrpc_timer_trace why) { - if (rxrpc_try_get_call(call, rxrpc_call_got_timer)) { - trace_rxrpc_timer(call, why, now); - if (timer_reduce(&call->timer, expire_at)) - rxrpc_put_call(call, rxrpc_call_put_notimer); - } -} - -void rxrpc_delete_call_timer(struct rxrpc_call *call) -{ - if (del_timer_sync(&call->timer)) - rxrpc_put_call(call, rxrpc_call_put_timer); + trace_rxrpc_timer(call, why, now); + timer_reduce(&call->timer, expire_at); } static struct lock_class_key rxrpc_call_user_mutex_lock_class_key; +static void rxrpc_destroy_call(struct work_struct *); + /* * find an extant server call * - called in process context with IRQs enabled @@ -110,7 +119,7 @@ struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *rx, return NULL; found_extant_call: - rxrpc_get_call(call, rxrpc_call_got); + rxrpc_get_call(call, rxrpc_call_get_sendmsg); read_unlock(&rx->call_lock); _leave(" = %p [%d]", call, refcount_read(&call->ref)); return call; @@ -129,16 +138,6 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp, if (!call) return NULL; - call->rxtx_buffer = kcalloc(RXRPC_RXTX_BUFF_SIZE, - sizeof(struct sk_buff *), - gfp); - if (!call->rxtx_buffer) - goto nomem; - - call->rxtx_annotations = kcalloc(RXRPC_RXTX_BUFF_SIZE, sizeof(u8), gfp); - if (!call->rxtx_annotations) - goto nomem_2; - mutex_init(&call->user_mutex); /* Prevent lockdep reporting a deadlock false positive between the afs @@ -149,43 +148,44 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp, &rxrpc_call_user_mutex_lock_class_key); timer_setup(&call->timer, rxrpc_call_timer_expired, 0); - INIT_WORK(&call->processor, &rxrpc_process_call); + INIT_WORK(&call->destroyer, rxrpc_destroy_call); INIT_LIST_HEAD(&call->link); - INIT_LIST_HEAD(&call->chan_wait_link); + INIT_LIST_HEAD(&call->wait_link); INIT_LIST_HEAD(&call->accept_link); INIT_LIST_HEAD(&call->recvmsg_link); INIT_LIST_HEAD(&call->sock_link); + INIT_LIST_HEAD(&call->attend_link); + INIT_LIST_HEAD(&call->tx_sendmsg); + INIT_LIST_HEAD(&call->tx_buffer); + skb_queue_head_init(&call->recvmsg_queue); + skb_queue_head_init(&call->rx_oos_queue); init_waitqueue_head(&call->waitq); - spin_lock_init(&call->lock); spin_lock_init(&call->notify_lock); - spin_lock_init(&call->input_lock); - rwlock_init(&call->state_lock); + spin_lock_init(&call->tx_lock); refcount_set(&call->ref, 1); call->debug_id = debug_id; call->tx_total_len = -1; call->next_rx_timo = 20 * HZ; call->next_req_timo = 1 * HZ; + atomic64_set(&call->ackr_window, 0x100000001ULL); memset(&call->sock_node, 0xed, sizeof(call->sock_node)); - /* Leave space in the ring to handle a maxed-out jumbo packet */ call->rx_winsize = rxrpc_rx_window_size; call->tx_winsize = 16; - call->rx_expect_next = 1; - call->cong_cwnd = 2; - call->cong_ssthresh = RXRPC_RXTX_BUFF_SIZE - 1; + if (RXRPC_TX_SMSS > 2190) + call->cong_cwnd = 2; + else if (RXRPC_TX_SMSS > 1095) + call->cong_cwnd = 3; + else + call->cong_cwnd = 4; + call->cong_ssthresh = RXRPC_TX_MAX_WINDOW; call->rxnet = rxnet; call->rtt_avail = RXRPC_CALL_RTT_AVAIL_MASK; atomic_inc(&rxnet->nr_calls); return call; - -nomem_2: - kfree(call->rxtx_buffer); -nomem: - kmem_cache_free(rxrpc_call_jar, call); - return NULL; } /* @@ -193,23 +193,47 @@ nomem: */ static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx, struct sockaddr_rxrpc *srx, + struct rxrpc_conn_parameters *cp, + struct rxrpc_call_params *p, gfp_t gfp, unsigned int debug_id) { struct rxrpc_call *call; ktime_t now; + int ret; _enter(""); call = rxrpc_alloc_call(rx, gfp, debug_id); if (!call) return ERR_PTR(-ENOMEM); - call->state = RXRPC_CALL_CLIENT_AWAIT_CONN; - call->service_id = srx->srx_service; - call->tx_phase = true; now = ktime_get_real(); - call->acks_latest_ts = now; - call->cong_tstamp = now; + call->acks_latest_ts = now; + call->cong_tstamp = now; + call->dest_srx = *srx; + call->interruptibility = p->interruptibility; + call->tx_total_len = p->tx_total_len; + call->key = key_get(cp->key); + call->local = rxrpc_get_local(cp->local, rxrpc_local_get_call); + call->security_level = cp->security_level; + if (p->kernel) + __set_bit(RXRPC_CALL_KERNEL, &call->flags); + if (cp->upgrade) + __set_bit(RXRPC_CALL_UPGRADE, &call->flags); + if (cp->exclusive) + __set_bit(RXRPC_CALL_EXCLUSIVE, &call->flags); + + ret = rxrpc_init_client_call_security(call); + if (ret < 0) { + rxrpc_prefail_call(call, RXRPC_CALL_LOCAL_ERROR, ret); + rxrpc_put_call(call, rxrpc_call_put_discard_error); + return ERR_PTR(ret); + } + + rxrpc_set_call_state(call, RXRPC_CALL_CLIENT_AWAIT_CONN); + + trace_rxrpc_call(call->debug_id, refcount_read(&call->ref), + p->user_call_ID, rxrpc_call_new_client); _leave(" = %p", call); return call; @@ -218,15 +242,16 @@ static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx, /* * Initiate the call ack/resend/expiry timer. */ -static void rxrpc_start_call_timer(struct rxrpc_call *call) +void rxrpc_start_call_timer(struct rxrpc_call *call) { unsigned long now = jiffies; unsigned long j = now + MAX_JIFFY_OFFSET; - call->ack_at = j; + call->delay_ack_at = j; call->ack_lost_at = j; call->resend_at = j; call->ping_at = j; + call->keepalive_at = j; call->expect_rx_by = j; call->expect_req_by = j; call->expect_term_by = j; @@ -262,6 +287,39 @@ static void rxrpc_put_call_slot(struct rxrpc_call *call) } /* + * Start the process of connecting a call. We obtain a peer and a connection + * bundle, but the actual association of a call with a connection is offloaded + * to the I/O thread to simplify locking. + */ +static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp) +{ + struct rxrpc_local *local = call->local; + int ret = -ENOMEM; + + _enter("{%d,%lx},", call->debug_id, call->user_call_ID); + + call->peer = rxrpc_lookup_peer(local, &call->dest_srx, gfp); + if (!call->peer) + goto error; + + ret = rxrpc_look_up_bundle(call, gfp); + if (ret < 0) + goto error; + + trace_rxrpc_client(NULL, -1, rxrpc_client_queue_new_call); + rxrpc_get_call(call, rxrpc_call_get_io_thread); + spin_lock(&local->client_call_lock); + list_add_tail(&call->wait_link, &local->new_client_calls); + spin_unlock(&local->client_call_lock); + rxrpc_wake_up_io_thread(local); + return 0; + +error: + __set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); + return ret; +} + +/* * Set up a call for the given parameters. * - Called with the socket lock held, which it must release. * - If it returns a call, the call's lock will need releasing by the caller. @@ -279,7 +337,6 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, struct rxrpc_net *rxnet; struct semaphore *limiter; struct rb_node *parent, **pp; - const void *here = __builtin_return_address(0); int ret; _enter("%p,%lx", rx, p->user_call_ID); @@ -290,7 +347,7 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, return ERR_PTR(-ERESTARTSYS); } - call = rxrpc_alloc_client_call(rx, srx, gfp, debug_id); + call = rxrpc_alloc_client_call(rx, srx, cp, p, gfp, debug_id); if (IS_ERR(call)) { release_sock(&rx->sk); up(limiter); @@ -298,14 +355,6 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, return call; } - call->interruptibility = p->interruptibility; - call->tx_total_len = p->tx_total_len; - trace_rxrpc_call(call->debug_id, rxrpc_call_new_client, - refcount_read(&call->ref), - here, (const void *)p->user_call_ID); - if (p->kernel) - __set_bit(RXRPC_CALL_KERNEL, &call->flags); - /* We need to protect a partially set up call against the user as we * will be acting outside the socket lock. */ @@ -331,7 +380,7 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, rcu_assign_pointer(call->socket, rx); call->user_call_ID = p->user_call_ID; __set_bit(RXRPC_CALL_HAS_USERID, &call->flags); - rxrpc_get_call(call, rxrpc_call_got_userid); + rxrpc_get_call(call, rxrpc_call_get_userid); rb_link_node(&call->sock_node, parent, pp); rb_insert_color(&call->sock_node, &rx->calls); list_add(&call->sock_link, &rx->sock_calls); @@ -339,9 +388,9 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, write_unlock(&rx->call_lock); rxnet = call->rxnet; - spin_lock_bh(&rxnet->call_lock); + spin_lock(&rxnet->call_lock); list_add_tail_rcu(&call->link, &rxnet->calls); - spin_unlock_bh(&rxnet->call_lock); + spin_unlock(&rxnet->call_lock); /* From this point on, the call is protected by its own lock. */ release_sock(&rx->sk); @@ -349,17 +398,10 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, /* Set up or get a connection record and set the protocol parameters, * including channel number and call ID. */ - ret = rxrpc_connect_call(rx, call, cp, srx, gfp); + ret = rxrpc_connect_call(call, gfp); if (ret < 0) goto error_attached_to_socket; - trace_rxrpc_call(call->debug_id, rxrpc_call_connected, - refcount_read(&call->ref), here, NULL); - - rxrpc_start_call_timer(call); - - _net("CALL new %d on CONN %d", call->debug_id, call->conn->debug_id); - _leave(" = %p [new]", call); return call; @@ -371,27 +413,23 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, error_dup_user_ID: write_unlock(&rx->call_lock); release_sock(&rx->sk); - __rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, - RX_CALL_DEAD, -EEXIST); - trace_rxrpc_call(call->debug_id, rxrpc_call_error, - refcount_read(&call->ref), here, ERR_PTR(-EEXIST)); - rxrpc_release_call(rx, call); + rxrpc_prefail_call(call, RXRPC_CALL_LOCAL_ERROR, -EEXIST); + trace_rxrpc_call(call->debug_id, refcount_read(&call->ref), 0, + rxrpc_call_see_userid_exists); mutex_unlock(&call->user_mutex); - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_userid_exists); _leave(" = -EEXIST"); return ERR_PTR(-EEXIST); /* We got an error, but the call is attached to the socket and is in - * need of release. However, we might now race with recvmsg() when - * completing the call queues it. Return 0 from sys_sendmsg() and + * need of release. However, we might now race with recvmsg() when it + * completion notifies the socket. Return 0 from sys_sendmsg() and * leave the error to recvmsg() to deal with. */ error_attached_to_socket: - trace_rxrpc_call(call->debug_id, rxrpc_call_error, - refcount_read(&call->ref), here, ERR_PTR(ret)); - set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); - __rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, - RX_CALL_DEAD, ret); + trace_rxrpc_call(call->debug_id, refcount_read(&call->ref), ret, + rxrpc_call_see_connect_failed); + rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, 0, ret); _leave(" = c=%08x [err]", call->debug_id); return call; } @@ -412,11 +450,34 @@ void rxrpc_incoming_call(struct rxrpc_sock *rx, rcu_assign_pointer(call->socket, rx); call->call_id = sp->hdr.callNumber; - call->service_id = sp->hdr.serviceId; + call->dest_srx.srx_service = sp->hdr.serviceId; call->cid = sp->hdr.cid; - call->state = RXRPC_CALL_SERVER_SECURING; call->cong_tstamp = skb->tstamp; + __set_bit(RXRPC_CALL_EXPOSED, &call->flags); + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_SECURING); + + spin_lock(&conn->state_lock); + + switch (conn->state) { + case RXRPC_CONN_SERVICE_UNSECURED: + case RXRPC_CONN_SERVICE_CHALLENGING: + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_SECURING); + break; + case RXRPC_CONN_SERVICE: + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_RECV_REQUEST); + break; + + case RXRPC_CONN_ABORTED: + rxrpc_set_call_completion(call, conn->completion, + conn->abort_code, conn->error); + break; + default: + BUG(); + } + + rxrpc_get_call(call, rxrpc_call_get_io_thread); + /* Set the channel for this call. We don't get channel_lock as we're * only defending against the data_ready handler (which we're called * from) and the RESPONSE packet parser (which is only really @@ -426,100 +487,58 @@ void rxrpc_incoming_call(struct rxrpc_sock *rx, chan = sp->hdr.cid & RXRPC_CHANNELMASK; conn->channels[chan].call_counter = call->call_id; conn->channels[chan].call_id = call->call_id; - rcu_assign_pointer(conn->channels[chan].call, call); + conn->channels[chan].call = call; + spin_unlock(&conn->state_lock); - spin_lock(&conn->params.peer->lock); - hlist_add_head_rcu(&call->error_link, &conn->params.peer->error_targets); - spin_unlock(&conn->params.peer->lock); - - _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id); + spin_lock(&conn->peer->lock); + hlist_add_head(&call->error_link, &conn->peer->error_targets); + spin_unlock(&conn->peer->lock); rxrpc_start_call_timer(call); _leave(""); } /* - * Queue a call's work processor, getting a ref to pass to the work queue. - */ -bool rxrpc_queue_call(struct rxrpc_call *call) -{ - const void *here = __builtin_return_address(0); - int n; - - if (!__refcount_inc_not_zero(&call->ref, &n)) - return false; - if (rxrpc_queue_work(&call->processor)) - trace_rxrpc_call(call->debug_id, rxrpc_call_queued, n + 1, - here, NULL); - else - rxrpc_put_call(call, rxrpc_call_put_noqueue); - return true; -} - -/* - * Queue a call's work processor, passing the callers ref to the work queue. - */ -bool __rxrpc_queue_call(struct rxrpc_call *call) -{ - const void *here = __builtin_return_address(0); - int n = refcount_read(&call->ref); - ASSERTCMP(n, >=, 1); - if (rxrpc_queue_work(&call->processor)) - trace_rxrpc_call(call->debug_id, rxrpc_call_queued_ref, n, - here, NULL); - else - rxrpc_put_call(call, rxrpc_call_put_noqueue); - return true; -} - -/* * Note the re-emergence of a call. */ -void rxrpc_see_call(struct rxrpc_call *call) +void rxrpc_see_call(struct rxrpc_call *call, enum rxrpc_call_trace why) { - const void *here = __builtin_return_address(0); if (call) { - int n = refcount_read(&call->ref); + int r = refcount_read(&call->ref); - trace_rxrpc_call(call->debug_id, rxrpc_call_seen, n, - here, NULL); + trace_rxrpc_call(call->debug_id, r, 0, why); } } -bool rxrpc_try_get_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +struct rxrpc_call *rxrpc_try_get_call(struct rxrpc_call *call, + enum rxrpc_call_trace why) { - const void *here = __builtin_return_address(0); - int n; + int r; - if (!__refcount_inc_not_zero(&call->ref, &n)) - return false; - trace_rxrpc_call(call->debug_id, op, n + 1, here, NULL); - return true; + if (!call || !__refcount_inc_not_zero(&call->ref, &r)) + return NULL; + trace_rxrpc_call(call->debug_id, r + 1, 0, why); + return call; } /* * Note the addition of a ref on a call. */ -void rxrpc_get_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +void rxrpc_get_call(struct rxrpc_call *call, enum rxrpc_call_trace why) { - const void *here = __builtin_return_address(0); - int n; + int r; - __refcount_inc(&call->ref, &n); - trace_rxrpc_call(call->debug_id, op, n + 1, here, NULL); + __refcount_inc(&call->ref, &r); + trace_rxrpc_call(call->debug_id, r + 1, 0, why); } /* - * Clean up the RxTx skb ring. + * Clean up the Rx skb ring. */ static void rxrpc_cleanup_ring(struct rxrpc_call *call) { - int i; - - for (i = 0; i < RXRPC_RXTX_BUFF_SIZE; i++) { - rxrpc_free_skb(call->rxtx_buffer[i], rxrpc_skb_cleaned); - call->rxtx_buffer[i] = NULL; - } + skb_queue_purge(&call->recvmsg_queue); + skb_queue_purge(&call->rx_oos_queue); } /* @@ -527,28 +546,21 @@ static void rxrpc_cleanup_ring(struct rxrpc_call *call) */ void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call) { - const void *here = __builtin_return_address(0); struct rxrpc_connection *conn = call->conn; - bool put = false; + bool put = false, putu = false; _enter("{%d,%d}", call->debug_id, refcount_read(&call->ref)); - trace_rxrpc_call(call->debug_id, rxrpc_call_release, - refcount_read(&call->ref), - here, (const void *)call->flags); + trace_rxrpc_call(call->debug_id, refcount_read(&call->ref), + call->flags, rxrpc_call_see_release); - ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); - - spin_lock_bh(&call->lock); if (test_and_set_bit(RXRPC_CALL_RELEASED, &call->flags)) BUG(); - spin_unlock_bh(&call->lock); rxrpc_put_call_slot(call); - rxrpc_delete_call_timer(call); /* Make sure we don't get any more notifications */ - write_lock_bh(&rx->recvmsg_lock); + write_lock(&rx->recvmsg_lock); if (!list_empty(&call->recvmsg_link)) { _debug("unlinking once-pending call %p { e=%lx f=%lx }", @@ -561,16 +573,16 @@ void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call) call->recvmsg_link.next = NULL; call->recvmsg_link.prev = NULL; - write_unlock_bh(&rx->recvmsg_lock); + write_unlock(&rx->recvmsg_lock); if (put) - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_unnotify); write_lock(&rx->call_lock); if (test_and_clear_bit(RXRPC_CALL_HAS_USERID, &call->flags)) { rb_erase(&call->sock_node, &rx->calls); memset(&call->sock_node, 0xdd, sizeof(call->sock_node)); - rxrpc_put_call(call, rxrpc_call_put_userid); + putu = true; } list_del(&call->sock_link); @@ -578,10 +590,9 @@ void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call) _debug("RELEASE CALL %p (%d CONN %p)", call, call->debug_id, conn); - if (conn && !test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) - rxrpc_disconnect_call(call); - if (call->security) - call->security->free_call_crypto(call); + if (putu) + rxrpc_put_call(call, rxrpc_call_put_userid); + _leave(""); } @@ -598,18 +609,19 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx) call = list_entry(rx->to_be_accepted.next, struct rxrpc_call, accept_link); list_del(&call->accept_link); - rxrpc_abort_call("SKR", call, 0, RX_CALL_DEAD, -ECONNRESET); - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_propose_abort(call, RX_CALL_DEAD, -ECONNRESET, + rxrpc_abort_call_sock_release_tba); + rxrpc_put_call(call, rxrpc_call_put_release_sock_tba); } while (!list_empty(&rx->sock_calls)) { call = list_entry(rx->sock_calls.next, struct rxrpc_call, sock_link); - rxrpc_get_call(call, rxrpc_call_got); - rxrpc_abort_call("SKT", call, 0, RX_CALL_DEAD, -ECONNRESET); - rxrpc_send_abort_packet(call); + rxrpc_get_call(call, rxrpc_call_get_release_sock); + rxrpc_propose_abort(call, RX_CALL_DEAD, -ECONNRESET, + rxrpc_abort_call_sock_release); rxrpc_release_call(rx, call); - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_release_sock); } _leave(""); @@ -618,26 +630,24 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx) /* * release a call */ -void rxrpc_put_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +void rxrpc_put_call(struct rxrpc_call *call, enum rxrpc_call_trace why) { struct rxrpc_net *rxnet = call->rxnet; - const void *here = __builtin_return_address(0); unsigned int debug_id = call->debug_id; bool dead; - int n; + int r; ASSERT(call != NULL); - dead = __refcount_dec_and_test(&call->ref, &n); - trace_rxrpc_call(debug_id, op, n, here, NULL); + dead = __refcount_dec_and_test(&call->ref, &r); + trace_rxrpc_call(debug_id, r - 1, 0, why); if (dead) { - _debug("call %d dead", call->debug_id); - ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); + ASSERTCMP(__rxrpc_call_state(call), ==, RXRPC_CALL_COMPLETE); if (!list_empty(&call->link)) { - spin_lock_bh(&rxnet->call_lock); + spin_lock(&rxnet->call_lock); list_del_init(&call->link); - spin_unlock_bh(&rxnet->call_lock); + spin_unlock(&rxnet->call_lock); } rxrpc_cleanup_call(call); @@ -645,38 +655,47 @@ void rxrpc_put_call(struct rxrpc_call *call, enum rxrpc_call_trace op) } /* - * Final call destruction - but must be done in process context. + * Free up the call under RCU. */ -static void rxrpc_destroy_call(struct work_struct *work) +static void rxrpc_rcu_free_call(struct rcu_head *rcu) { - struct rxrpc_call *call = container_of(work, struct rxrpc_call, processor); - struct rxrpc_net *rxnet = call->rxnet; - - rxrpc_delete_call_timer(call); + struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu); + struct rxrpc_net *rxnet = READ_ONCE(call->rxnet); - rxrpc_put_connection(call->conn); - rxrpc_put_peer(call->peer); - kfree(call->rxtx_buffer); - kfree(call->rxtx_annotations); kmem_cache_free(rxrpc_call_jar, call); if (atomic_dec_and_test(&rxnet->nr_calls)) wake_up_var(&rxnet->nr_calls); } /* - * Final call destruction under RCU. + * Final call destruction - but must be done in process context. */ -static void rxrpc_rcu_destroy_call(struct rcu_head *rcu) +static void rxrpc_destroy_call(struct work_struct *work) { - struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu); + struct rxrpc_call *call = container_of(work, struct rxrpc_call, destroyer); + struct rxrpc_txbuf *txb; + + del_timer_sync(&call->timer); - if (in_softirq()) { - INIT_WORK(&call->processor, rxrpc_destroy_call); - if (!rxrpc_queue_work(&call->processor)) - BUG(); - } else { - rxrpc_destroy_call(&call->processor); + rxrpc_cleanup_ring(call); + while ((txb = list_first_entry_or_null(&call->tx_sendmsg, + struct rxrpc_txbuf, call_link))) { + list_del(&txb->call_link); + rxrpc_put_txbuf(txb, rxrpc_txbuf_put_cleaned); + } + while ((txb = list_first_entry_or_null(&call->tx_buffer, + struct rxrpc_txbuf, call_link))) { + list_del(&txb->call_link); + rxrpc_put_txbuf(txb, rxrpc_txbuf_put_cleaned); } + + rxrpc_put_txbuf(call->tx_pending, rxrpc_txbuf_put_cleaned); + rxrpc_put_connection(call->conn, rxrpc_conn_put_call); + rxrpc_deactivate_bundle(call->bundle); + rxrpc_put_bundle(call->bundle, rxrpc_bundle_put_call); + rxrpc_put_peer(call->peer, rxrpc_peer_put_call); + rxrpc_put_local(call->local, rxrpc_local_put_call); + call_rcu(&call->rcu, rxrpc_rcu_free_call); } /* @@ -684,17 +703,20 @@ static void rxrpc_rcu_destroy_call(struct rcu_head *rcu) */ void rxrpc_cleanup_call(struct rxrpc_call *call) { - _net("DESTROY CALL %d", call->debug_id); - memset(&call->sock_node, 0xcd, sizeof(call->sock_node)); - ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); + ASSERTCMP(__rxrpc_call_state(call), ==, RXRPC_CALL_COMPLETE); ASSERT(test_bit(RXRPC_CALL_RELEASED, &call->flags)); - rxrpc_cleanup_ring(call); - rxrpc_free_skb(call->tx_pending, rxrpc_skb_cleaned); + del_timer(&call->timer); - call_rcu(&call->rcu, rxrpc_rcu_destroy_call); + if (rcu_read_lock_held()) + /* Can't use the rxrpc workqueue as we need to cancel/flush + * something that may be running/waiting there. + */ + schedule_work(&call->destroyer); + else + rxrpc_destroy_call(&call->destroyer); } /* @@ -709,27 +731,27 @@ void rxrpc_destroy_all_calls(struct rxrpc_net *rxnet) _enter(""); if (!list_empty(&rxnet->calls)) { - spin_lock_bh(&rxnet->call_lock); + spin_lock(&rxnet->call_lock); while (!list_empty(&rxnet->calls)) { call = list_entry(rxnet->calls.next, struct rxrpc_call, link); _debug("Zapping call %p", call); - rxrpc_see_call(call); + rxrpc_see_call(call, rxrpc_call_see_zap); list_del_init(&call->link); pr_err("Call %p still in use (%d,%s,%lx,%lx)!\n", call, refcount_read(&call->ref), - rxrpc_call_states[call->state], + rxrpc_call_states[__rxrpc_call_state(call)], call->flags, call->events); - spin_unlock_bh(&rxnet->call_lock); + spin_unlock(&rxnet->call_lock); cond_resched(); - spin_lock_bh(&rxnet->call_lock); + spin_lock(&rxnet->call_lock); } - spin_unlock_bh(&rxnet->call_lock); + spin_unlock(&rxnet->call_lock); } atomic_dec(&rxnet->nr_calls); diff --git a/net/rxrpc/call_state.c b/net/rxrpc/call_state.c new file mode 100644 index 000000000000..6afb54373ebb --- /dev/null +++ b/net/rxrpc/call_state.c @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Call state changing functions. + * + * Copyright (C) 2022 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include "ar-internal.h" + +/* + * Transition a call to the complete state. + */ +bool rxrpc_set_call_completion(struct rxrpc_call *call, + enum rxrpc_call_completion compl, + u32 abort_code, + int error) +{ + if (__rxrpc_call_state(call) == RXRPC_CALL_COMPLETE) + return false; + + call->abort_code = abort_code; + call->error = error; + call->completion = compl; + /* Allow reader of completion state to operate locklessly */ + rxrpc_set_call_state(call, RXRPC_CALL_COMPLETE); + trace_rxrpc_call_complete(call); + wake_up(&call->waitq); + rxrpc_notify_socket(call); + return true; +} + +/* + * Record that a call successfully completed. + */ +bool rxrpc_call_completed(struct rxrpc_call *call) +{ + return rxrpc_set_call_completion(call, RXRPC_CALL_SUCCEEDED, 0, 0); +} + +/* + * Record that a call is locally aborted. + */ +bool rxrpc_abort_call(struct rxrpc_call *call, rxrpc_seq_t seq, + u32 abort_code, int error, enum rxrpc_abort_reason why) +{ + trace_rxrpc_abort(call->debug_id, why, call->cid, call->call_id, seq, + abort_code, error); + if (!rxrpc_set_call_completion(call, RXRPC_CALL_LOCALLY_ABORTED, + abort_code, error)) + return false; + if (test_bit(RXRPC_CALL_EXPOSED, &call->flags)) + rxrpc_send_abort_packet(call); + return true; +} + +/* + * Record that a call errored out before even getting off the ground, thereby + * setting the state to allow it to be destroyed. + */ +void rxrpc_prefail_call(struct rxrpc_call *call, enum rxrpc_call_completion compl, + int error) +{ + call->abort_code = RX_CALL_DEAD; + call->error = error; + call->completion = compl; + call->_state = RXRPC_CALL_COMPLETE; + trace_rxrpc_call_complete(call); + WARN_ON_ONCE(__test_and_set_bit(RXRPC_CALL_RELEASED, &call->flags)); +} diff --git a/net/rxrpc/conn_client.c b/net/rxrpc/conn_client.c index 3c9eeb5b750c..981ca5b98bcb 100644 --- a/net/rxrpc/conn_client.c +++ b/net/rxrpc/conn_client.c @@ -34,184 +34,160 @@ __read_mostly unsigned int rxrpc_reap_client_connections = 900; __read_mostly unsigned long rxrpc_conn_idle_client_expiry = 2 * 60 * HZ; __read_mostly unsigned long rxrpc_conn_idle_client_fast_expiry = 2 * HZ; -/* - * We use machine-unique IDs for our client connections. - */ -DEFINE_IDR(rxrpc_client_conn_ids); -static DEFINE_SPINLOCK(rxrpc_conn_id_lock); - -/* - * Get a connection ID and epoch for a client connection from the global pool. - * The connection struct pointer is then recorded in the idr radix tree. The - * epoch doesn't change until the client is rebooted (or, at least, unless the - * module is unloaded). - */ -static int rxrpc_get_client_connection_id(struct rxrpc_connection *conn, - gfp_t gfp) +static void rxrpc_activate_bundle(struct rxrpc_bundle *bundle) { - struct rxrpc_net *rxnet = conn->params.local->rxnet; - int id; - - _enter(""); - - idr_preload(gfp); - spin_lock(&rxrpc_conn_id_lock); - - id = idr_alloc_cyclic(&rxrpc_client_conn_ids, conn, - 1, 0x40000000, GFP_NOWAIT); - if (id < 0) - goto error; - - spin_unlock(&rxrpc_conn_id_lock); - idr_preload_end(); - - conn->proto.epoch = rxnet->epoch; - conn->proto.cid = id << RXRPC_CIDSHIFT; - set_bit(RXRPC_CONN_HAS_IDR, &conn->flags); - _leave(" [CID %x]", conn->proto.cid); - return 0; - -error: - spin_unlock(&rxrpc_conn_id_lock); - idr_preload_end(); - _leave(" = %d", id); - return id; + atomic_inc(&bundle->active); } /* - * Release a connection ID for a client connection from the global pool. + * Release a connection ID for a client connection. */ -static void rxrpc_put_client_connection_id(struct rxrpc_connection *conn) +static void rxrpc_put_client_connection_id(struct rxrpc_local *local, + struct rxrpc_connection *conn) { - if (test_bit(RXRPC_CONN_HAS_IDR, &conn->flags)) { - spin_lock(&rxrpc_conn_id_lock); - idr_remove(&rxrpc_client_conn_ids, - conn->proto.cid >> RXRPC_CIDSHIFT); - spin_unlock(&rxrpc_conn_id_lock); - } + idr_remove(&local->conn_ids, conn->proto.cid >> RXRPC_CIDSHIFT); } /* * Destroy the client connection ID tree. */ -void rxrpc_destroy_client_conn_ids(void) +static void rxrpc_destroy_client_conn_ids(struct rxrpc_local *local) { struct rxrpc_connection *conn; int id; - if (!idr_is_empty(&rxrpc_client_conn_ids)) { - idr_for_each_entry(&rxrpc_client_conn_ids, conn, id) { + if (!idr_is_empty(&local->conn_ids)) { + idr_for_each_entry(&local->conn_ids, conn, id) { pr_err("AF_RXRPC: Leaked client conn %p {%d}\n", conn, refcount_read(&conn->ref)); } BUG(); } - idr_destroy(&rxrpc_client_conn_ids); + idr_destroy(&local->conn_ids); } /* * Allocate a connection bundle. */ -static struct rxrpc_bundle *rxrpc_alloc_bundle(struct rxrpc_conn_parameters *cp, +static struct rxrpc_bundle *rxrpc_alloc_bundle(struct rxrpc_call *call, gfp_t gfp) { struct rxrpc_bundle *bundle; bundle = kzalloc(sizeof(*bundle), gfp); if (bundle) { - bundle->params = *cp; - rxrpc_get_peer(bundle->params.peer); + bundle->local = call->local; + bundle->peer = rxrpc_get_peer(call->peer, rxrpc_peer_get_bundle); + bundle->key = key_get(call->key); + bundle->security = call->security; + bundle->exclusive = test_bit(RXRPC_CALL_EXCLUSIVE, &call->flags); + bundle->upgrade = test_bit(RXRPC_CALL_UPGRADE, &call->flags); + bundle->service_id = call->dest_srx.srx_service; + bundle->security_level = call->security_level; refcount_set(&bundle->ref, 1); - spin_lock_init(&bundle->channel_lock); + atomic_set(&bundle->active, 1); INIT_LIST_HEAD(&bundle->waiting_calls); + trace_rxrpc_bundle(bundle->debug_id, 1, rxrpc_bundle_new); } return bundle; } -struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *bundle) +struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *bundle, + enum rxrpc_bundle_trace why) { - refcount_inc(&bundle->ref); + int r; + + __refcount_inc(&bundle->ref, &r); + trace_rxrpc_bundle(bundle->debug_id, r + 1, why); return bundle; } static void rxrpc_free_bundle(struct rxrpc_bundle *bundle) { - rxrpc_put_peer(bundle->params.peer); + trace_rxrpc_bundle(bundle->debug_id, 1, rxrpc_bundle_free); + rxrpc_put_peer(bundle->peer, rxrpc_peer_put_bundle); + key_put(bundle->key); kfree(bundle); } -void rxrpc_put_bundle(struct rxrpc_bundle *bundle) +void rxrpc_put_bundle(struct rxrpc_bundle *bundle, enum rxrpc_bundle_trace why) { - unsigned int d = bundle->debug_id; + unsigned int id; bool dead; int r; - dead = __refcount_dec_and_test(&bundle->ref, &r); + if (bundle) { + id = bundle->debug_id; + dead = __refcount_dec_and_test(&bundle->ref, &r); + trace_rxrpc_bundle(id, r - 1, why); + if (dead) + rxrpc_free_bundle(bundle); + } +} - _debug("PUT B=%x %d", d, r); - if (dead) - rxrpc_free_bundle(bundle); +/* + * Get rid of outstanding client connection preallocations when a local + * endpoint is destroyed. + */ +void rxrpc_purge_client_connections(struct rxrpc_local *local) +{ + rxrpc_destroy_client_conn_ids(local); } /* * Allocate a client connection. */ static struct rxrpc_connection * -rxrpc_alloc_client_connection(struct rxrpc_bundle *bundle, gfp_t gfp) +rxrpc_alloc_client_connection(struct rxrpc_bundle *bundle) { struct rxrpc_connection *conn; - struct rxrpc_net *rxnet = bundle->params.local->rxnet; - int ret; + struct rxrpc_local *local = bundle->local; + struct rxrpc_net *rxnet = local->rxnet; + int id; _enter(""); - conn = rxrpc_alloc_connection(gfp); - if (!conn) { - _leave(" = -ENOMEM"); + conn = rxrpc_alloc_connection(rxnet, GFP_ATOMIC | __GFP_NOWARN); + if (!conn) return ERR_PTR(-ENOMEM); + + id = idr_alloc_cyclic(&local->conn_ids, conn, 1, 0x40000000, + GFP_ATOMIC | __GFP_NOWARN); + if (id < 0) { + kfree(conn); + return ERR_PTR(id); } refcount_set(&conn->ref, 1); - conn->bundle = bundle; - conn->params = bundle->params; + conn->proto.cid = id << RXRPC_CIDSHIFT; + conn->proto.epoch = local->rxnet->epoch; conn->out_clientflag = RXRPC_CLIENT_INITIATED; - conn->state = RXRPC_CONN_CLIENT; - conn->service_id = conn->params.service_id; - - ret = rxrpc_get_client_connection_id(conn, gfp); - if (ret < 0) - goto error_0; - - ret = rxrpc_init_client_conn_security(conn); - if (ret < 0) - goto error_1; + conn->bundle = rxrpc_get_bundle(bundle, rxrpc_bundle_get_client_conn); + conn->local = rxrpc_get_local(bundle->local, rxrpc_local_get_client_conn); + conn->peer = rxrpc_get_peer(bundle->peer, rxrpc_peer_get_client_conn); + conn->key = key_get(bundle->key); + conn->security = bundle->security; + conn->exclusive = bundle->exclusive; + conn->upgrade = bundle->upgrade; + conn->orig_service_id = bundle->service_id; + conn->security_level = bundle->security_level; + conn->state = RXRPC_CONN_CLIENT_UNSECURED; + conn->service_id = conn->orig_service_id; + + if (conn->security == &rxrpc_no_security) + conn->state = RXRPC_CONN_CLIENT; atomic_inc(&rxnet->nr_conns); write_lock(&rxnet->conn_lock); list_add_tail(&conn->proc_link, &rxnet->conn_proc_list); write_unlock(&rxnet->conn_lock); - rxrpc_get_bundle(bundle); - rxrpc_get_peer(conn->params.peer); - rxrpc_get_local(conn->params.local); - key_get(conn->params.key); - - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_client, - refcount_read(&conn->ref), - __builtin_return_address(0)); + rxrpc_see_connection(conn, rxrpc_conn_new_client); atomic_inc(&rxnet->nr_client_conns); trace_rxrpc_client(conn, -1, rxrpc_client_alloc); - _leave(" = %p", conn); return conn; - -error_1: - rxrpc_put_client_connection_id(conn); -error_0: - kfree(conn); - _leave(" = %d", ret); - return ERR_PTR(ret); } /* @@ -225,11 +201,12 @@ static bool rxrpc_may_reuse_conn(struct rxrpc_connection *conn) if (!conn) goto dont_reuse; - rxnet = conn->params.local->rxnet; + rxnet = conn->rxnet; if (test_bit(RXRPC_CONN_DONT_REUSE, &conn->flags)) goto dont_reuse; - if (conn->state != RXRPC_CONN_CLIENT || + if ((conn->state != RXRPC_CONN_CLIENT_UNSECURED && + conn->state != RXRPC_CONN_CLIENT) || conn->proto.epoch != rxnet->epoch) goto mark_dont_reuse; @@ -239,7 +216,7 @@ static bool rxrpc_may_reuse_conn(struct rxrpc_connection *conn) * times the maximum number of client conns away from the current * allocation point to try and keep the IDs concentrated. */ - id_cursor = idr_get_cursor(&rxrpc_client_conn_ids); + id_cursor = idr_get_cursor(&conn->local->conn_ids); id = conn->proto.cid >> RXRPC_CIDSHIFT; distance = id - id_cursor; if (distance < 0) @@ -260,20 +237,23 @@ dont_reuse: * Look up the conn bundle that matches the connection parameters, adding it if * it doesn't yet exist. */ -static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *cp, - gfp_t gfp) +int rxrpc_look_up_bundle(struct rxrpc_call *call, gfp_t gfp) { static atomic_t rxrpc_bundle_id; struct rxrpc_bundle *bundle, *candidate; - struct rxrpc_local *local = cp->local; + struct rxrpc_local *local = call->local; struct rb_node *p, **pp, *parent; long diff; + bool upgrade = test_bit(RXRPC_CALL_UPGRADE, &call->flags); _enter("{%px,%x,%u,%u}", - cp->peer, key_serial(cp->key), cp->security_level, cp->upgrade); + call->peer, key_serial(call->key), call->security_level, + upgrade); - if (cp->exclusive) - return rxrpc_alloc_bundle(cp, gfp); + if (test_bit(RXRPC_CALL_EXCLUSIVE, &call->flags)) { + call->bundle = rxrpc_alloc_bundle(call, gfp); + return call->bundle ? 0 : -ENOMEM; + } /* First, see if the bundle is already there. */ _debug("search 1"); @@ -282,11 +262,11 @@ static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *c while (p) { bundle = rb_entry(p, struct rxrpc_bundle, local_node); -#define cmp(X) ((long)bundle->params.X - (long)cp->X) - diff = (cmp(peer) ?: - cmp(key) ?: - cmp(security_level) ?: - cmp(upgrade)); +#define cmp(X, Y) ((long)(X) - (long)(Y)) + diff = (cmp(bundle->peer, call->peer) ?: + cmp(bundle->key, call->key) ?: + cmp(bundle->security_level, call->security_level) ?: + cmp(bundle->upgrade, upgrade)); #undef cmp if (diff < 0) p = p->rb_left; @@ -299,9 +279,9 @@ static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *c _debug("not found"); /* It wasn't. We need to add one. */ - candidate = rxrpc_alloc_bundle(cp, gfp); + candidate = rxrpc_alloc_bundle(call, gfp); if (!candidate) - return NULL; + return -ENOMEM; _debug("search 2"); spin_lock(&local->client_bundles_lock); @@ -311,11 +291,11 @@ static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *c parent = *pp; bundle = rb_entry(parent, struct rxrpc_bundle, local_node); -#define cmp(X) ((long)bundle->params.X - (long)cp->X) - diff = (cmp(peer) ?: - cmp(key) ?: - cmp(security_level) ?: - cmp(upgrade)); +#define cmp(X, Y) ((long)(X) - (long)(Y)) + diff = (cmp(bundle->peer, call->peer) ?: + cmp(bundle->key, call->key) ?: + cmp(bundle->security_level, call->security_level) ?: + cmp(bundle->upgrade, upgrade)); #undef cmp if (diff < 0) pp = &(*pp)->rb_left; @@ -329,175 +309,89 @@ static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *c candidate->debug_id = atomic_inc_return(&rxrpc_bundle_id); rb_link_node(&candidate->local_node, parent, pp); rb_insert_color(&candidate->local_node, &local->client_bundles); - rxrpc_get_bundle(candidate); + call->bundle = rxrpc_get_bundle(candidate, rxrpc_bundle_get_client_call); spin_unlock(&local->client_bundles_lock); - _leave(" = %u [new]", candidate->debug_id); - return candidate; + _leave(" = B=%u [new]", call->bundle->debug_id); + return 0; found_bundle_free: rxrpc_free_bundle(candidate); found_bundle: - rxrpc_get_bundle(bundle); + call->bundle = rxrpc_get_bundle(bundle, rxrpc_bundle_get_client_call); + rxrpc_activate_bundle(bundle); spin_unlock(&local->client_bundles_lock); - _leave(" = %u [found]", bundle->debug_id); - return bundle; -} - -/* - * Create or find a client bundle to use for a call. - * - * If we return with a connection, the call will be on its waiting list. It's - * left to the caller to assign a channel and wake up the call. - */ -static struct rxrpc_bundle *rxrpc_prep_call(struct rxrpc_sock *rx, - struct rxrpc_call *call, - struct rxrpc_conn_parameters *cp, - struct sockaddr_rxrpc *srx, - gfp_t gfp) -{ - struct rxrpc_bundle *bundle; - - _enter("{%d,%lx},", call->debug_id, call->user_call_ID); - - cp->peer = rxrpc_lookup_peer(rx, cp->local, srx, gfp); - if (!cp->peer) - goto error; - - call->cong_cwnd = cp->peer->cong_cwnd; - if (call->cong_cwnd >= call->cong_ssthresh) - call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; - else - call->cong_mode = RXRPC_CALL_SLOW_START; - if (cp->upgrade) - __set_bit(RXRPC_CALL_UPGRADE, &call->flags); - - /* Find the client connection bundle. */ - bundle = rxrpc_look_up_bundle(cp, gfp); - if (!bundle) - goto error; - - /* Get this call queued. Someone else may activate it whilst we're - * lining up a new connection, but that's fine. - */ - spin_lock(&bundle->channel_lock); - list_add_tail(&call->chan_wait_link, &bundle->waiting_calls); - spin_unlock(&bundle->channel_lock); - - _leave(" = [B=%x]", bundle->debug_id); - return bundle; - -error: - _leave(" = -ENOMEM"); - return ERR_PTR(-ENOMEM); + _leave(" = B=%u [found]", call->bundle->debug_id); + return 0; } /* * Allocate a new connection and add it into a bundle. */ -static void rxrpc_add_conn_to_bundle(struct rxrpc_bundle *bundle, gfp_t gfp) - __releases(bundle->channel_lock) +static bool rxrpc_add_conn_to_bundle(struct rxrpc_bundle *bundle, + unsigned int slot) { - struct rxrpc_connection *candidate = NULL, *old = NULL; - bool conflict; - int i; - - _enter(""); + struct rxrpc_connection *conn, *old; + unsigned int shift = slot * RXRPC_MAXCALLS; + unsigned int i; - conflict = bundle->alloc_conn; - if (!conflict) - bundle->alloc_conn = true; - spin_unlock(&bundle->channel_lock); - if (conflict) { - _leave(" [conf]"); - return; + old = bundle->conns[slot]; + if (old) { + bundle->conns[slot] = NULL; + trace_rxrpc_client(old, -1, rxrpc_client_replace); + rxrpc_put_connection(old, rxrpc_conn_put_noreuse); } - candidate = rxrpc_alloc_client_connection(bundle, gfp); - - spin_lock(&bundle->channel_lock); - bundle->alloc_conn = false; - - if (IS_ERR(candidate)) { - bundle->alloc_error = PTR_ERR(candidate); - spin_unlock(&bundle->channel_lock); - _leave(" [err %ld]", PTR_ERR(candidate)); - return; - } - - bundle->alloc_error = 0; - - for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) { - unsigned int shift = i * RXRPC_MAXCALLS; - int j; - - old = bundle->conns[i]; - if (!rxrpc_may_reuse_conn(old)) { - if (old) - trace_rxrpc_client(old, -1, rxrpc_client_replace); - candidate->bundle_shift = shift; - bundle->conns[i] = candidate; - for (j = 0; j < RXRPC_MAXCALLS; j++) - set_bit(shift + j, &bundle->avail_chans); - candidate = NULL; - break; - } - - old = NULL; - } - - spin_unlock(&bundle->channel_lock); - - if (candidate) { - _debug("discard C=%x", candidate->debug_id); - trace_rxrpc_client(candidate, -1, rxrpc_client_duplicate); - rxrpc_put_connection(candidate); + conn = rxrpc_alloc_client_connection(bundle); + if (IS_ERR(conn)) { + bundle->alloc_error = PTR_ERR(conn); + return false; } - rxrpc_put_connection(old); - _leave(""); + rxrpc_activate_bundle(bundle); + conn->bundle_shift = shift; + bundle->conns[slot] = conn; + for (i = 0; i < RXRPC_MAXCALLS; i++) + set_bit(shift + i, &bundle->avail_chans); + return true; } /* * Add a connection to a bundle if there are no usable connections or we have * connections waiting for extra capacity. */ -static void rxrpc_maybe_add_conn(struct rxrpc_bundle *bundle, gfp_t gfp) +static bool rxrpc_bundle_has_space(struct rxrpc_bundle *bundle) { - struct rxrpc_call *call; - int i, usable; + int slot = -1, i, usable; _enter(""); - spin_lock(&bundle->channel_lock); + bundle->alloc_error = 0; /* See if there are any usable connections. */ usable = 0; - for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) + for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) { if (rxrpc_may_reuse_conn(bundle->conns[i])) usable++; - - if (!usable && !list_empty(&bundle->waiting_calls)) { - call = list_first_entry(&bundle->waiting_calls, - struct rxrpc_call, chan_wait_link); - if (test_bit(RXRPC_CALL_UPGRADE, &call->flags)) - bundle->try_upgrade = true; + else if (slot == -1) + slot = i; } + if (!usable && bundle->upgrade) + bundle->try_upgrade = true; + if (!usable) goto alloc_conn; if (!bundle->avail_chans && !bundle->try_upgrade && - !list_empty(&bundle->waiting_calls) && usable < ARRAY_SIZE(bundle->conns)) goto alloc_conn; - spin_unlock(&bundle->channel_lock); _leave(""); - return; + return usable; alloc_conn: - return rxrpc_add_conn_to_bundle(bundle, gfp); + return slot >= 0 ? rxrpc_add_conn_to_bundle(bundle, slot) : false; } /* @@ -511,11 +405,13 @@ static void rxrpc_activate_one_channel(struct rxrpc_connection *conn, struct rxrpc_channel *chan = &conn->channels[channel]; struct rxrpc_bundle *bundle = conn->bundle; struct rxrpc_call *call = list_entry(bundle->waiting_calls.next, - struct rxrpc_call, chan_wait_link); + struct rxrpc_call, wait_link); u32 call_id = chan->call_counter + 1; _enter("C=%x,%u", conn->debug_id, channel); + list_del_init(&call->wait_link); + trace_rxrpc_client(conn, channel, rxrpc_client_chan_activate); /* Cancel the final ACK on the previous call if it hasn't been sent yet @@ -524,73 +420,51 @@ static void rxrpc_activate_one_channel(struct rxrpc_connection *conn, clear_bit(RXRPC_CONN_FINAL_ACK_0 + channel, &conn->flags); clear_bit(conn->bundle_shift + channel, &bundle->avail_chans); - rxrpc_see_call(call); - list_del_init(&call->chan_wait_link); - call->peer = rxrpc_get_peer(conn->params.peer); - call->conn = rxrpc_get_connection(conn); + rxrpc_see_call(call, rxrpc_call_see_activate_client); + call->conn = rxrpc_get_connection(conn, rxrpc_conn_get_activate_call); call->cid = conn->proto.cid | channel; call->call_id = call_id; - call->security = conn->security; - call->security_ix = conn->security_ix; - call->service_id = conn->service_id; - - trace_rxrpc_connect_call(call); - _net("CONNECT call %08x:%08x as call %d on conn %d", - call->cid, call->call_id, call->debug_id, conn->debug_id); - - write_lock_bh(&call->state_lock); - call->state = RXRPC_CALL_CLIENT_SEND_REQUEST; - write_unlock_bh(&call->state_lock); - - /* Paired with the read barrier in rxrpc_connect_call(). This orders - * cid and epoch in the connection wrt to call_id without the need to - * take the channel_lock. - * - * We provisionally assign a callNumber at this point, but we don't - * confirm it until the call is about to be exposed. - * - * TODO: Pair with a barrier in the data_ready handler when that looks - * at the call ID through a connection channel. - */ - smp_wmb(); + call->dest_srx.srx_service = conn->service_id; + call->cong_ssthresh = call->peer->cong_ssthresh; + if (call->cong_cwnd >= call->cong_ssthresh) + call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; + else + call->cong_mode = RXRPC_CALL_SLOW_START; chan->call_id = call_id; chan->call_debug_id = call->debug_id; - rcu_assign_pointer(chan->call, call); + chan->call = call; + + rxrpc_see_call(call, rxrpc_call_see_connected); + trace_rxrpc_connect_call(call); + call->tx_last_sent = ktime_get_real(); + rxrpc_start_call_timer(call); + rxrpc_set_call_state(call, RXRPC_CALL_CLIENT_SEND_REQUEST); wake_up(&call->waitq); } /* * Remove a connection from the idle list if it's on it. */ -static void rxrpc_unidle_conn(struct rxrpc_bundle *bundle, struct rxrpc_connection *conn) +static void rxrpc_unidle_conn(struct rxrpc_connection *conn) { - struct rxrpc_net *rxnet = bundle->params.local->rxnet; - bool drop_ref; - if (!list_empty(&conn->cache_link)) { - drop_ref = false; - spin_lock(&rxnet->client_conn_cache_lock); - if (!list_empty(&conn->cache_link)) { - list_del_init(&conn->cache_link); - drop_ref = true; - } - spin_unlock(&rxnet->client_conn_cache_lock); - if (drop_ref) - rxrpc_put_connection(conn); + list_del_init(&conn->cache_link); + rxrpc_put_connection(conn, rxrpc_conn_put_unidle); } } /* - * Assign channels and callNumbers to waiting calls with channel_lock - * held by caller. + * Assign channels and callNumbers to waiting calls. */ -static void rxrpc_activate_channels_locked(struct rxrpc_bundle *bundle) +static void rxrpc_activate_channels(struct rxrpc_bundle *bundle) { struct rxrpc_connection *conn; unsigned long avail, mask; unsigned int channel, slot; + trace_rxrpc_client(NULL, -1, rxrpc_client_activate_chans); + if (bundle->try_upgrade) mask = 1; else @@ -610,7 +484,7 @@ static void rxrpc_activate_channels_locked(struct rxrpc_bundle *bundle) if (bundle->try_upgrade) set_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags); - rxrpc_unidle_conn(bundle, conn); + rxrpc_unidle_conn(conn); channel &= (RXRPC_MAXCALLS - 1); conn->act_chans |= 1 << channel; @@ -619,131 +493,24 @@ static void rxrpc_activate_channels_locked(struct rxrpc_bundle *bundle) } /* - * Assign channels and callNumbers to waiting calls. - */ -static void rxrpc_activate_channels(struct rxrpc_bundle *bundle) -{ - _enter("B=%x", bundle->debug_id); - - trace_rxrpc_client(NULL, -1, rxrpc_client_activate_chans); - - if (!bundle->avail_chans) - return; - - spin_lock(&bundle->channel_lock); - rxrpc_activate_channels_locked(bundle); - spin_unlock(&bundle->channel_lock); - _leave(""); -} - -/* - * Wait for a callNumber and a channel to be granted to a call. - */ -static int rxrpc_wait_for_channel(struct rxrpc_bundle *bundle, - struct rxrpc_call *call, gfp_t gfp) -{ - DECLARE_WAITQUEUE(myself, current); - int ret = 0; - - _enter("%d", call->debug_id); - - if (!gfpflags_allow_blocking(gfp)) { - rxrpc_maybe_add_conn(bundle, gfp); - rxrpc_activate_channels(bundle); - ret = bundle->alloc_error ?: -EAGAIN; - goto out; - } - - add_wait_queue_exclusive(&call->waitq, &myself); - for (;;) { - rxrpc_maybe_add_conn(bundle, gfp); - rxrpc_activate_channels(bundle); - ret = bundle->alloc_error; - if (ret < 0) - break; - - switch (call->interruptibility) { - case RXRPC_INTERRUPTIBLE: - case RXRPC_PREINTERRUPTIBLE: - set_current_state(TASK_INTERRUPTIBLE); - break; - case RXRPC_UNINTERRUPTIBLE: - default: - set_current_state(TASK_UNINTERRUPTIBLE); - break; - } - if (READ_ONCE(call->state) != RXRPC_CALL_CLIENT_AWAIT_CONN) - break; - if ((call->interruptibility == RXRPC_INTERRUPTIBLE || - call->interruptibility == RXRPC_PREINTERRUPTIBLE) && - signal_pending(current)) { - ret = -ERESTARTSYS; - break; - } - schedule(); - } - remove_wait_queue(&call->waitq, &myself); - __set_current_state(TASK_RUNNING); - -out: - _leave(" = %d", ret); - return ret; -} - -/* - * find a connection for a call - * - called in process context with IRQs enabled + * Connect waiting channels (called from the I/O thread). */ -int rxrpc_connect_call(struct rxrpc_sock *rx, - struct rxrpc_call *call, - struct rxrpc_conn_parameters *cp, - struct sockaddr_rxrpc *srx, - gfp_t gfp) +void rxrpc_connect_client_calls(struct rxrpc_local *local) { - struct rxrpc_bundle *bundle; - struct rxrpc_net *rxnet = cp->local->rxnet; - int ret = 0; - - _enter("{%d,%lx},", call->debug_id, call->user_call_ID); - - rxrpc_discard_expired_client_conns(&rxnet->client_conn_reaper); - - bundle = rxrpc_prep_call(rx, call, cp, srx, gfp); - if (IS_ERR(bundle)) { - ret = PTR_ERR(bundle); - goto out; - } - - if (call->state == RXRPC_CALL_CLIENT_AWAIT_CONN) { - ret = rxrpc_wait_for_channel(bundle, call, gfp); - if (ret < 0) - goto wait_failed; - } - -granted_channel: - /* Paired with the write barrier in rxrpc_activate_one_channel(). */ - smp_rmb(); + struct rxrpc_call *call; -out_put_bundle: - rxrpc_put_bundle(bundle); -out: - _leave(" = %d", ret); - return ret; + while ((call = list_first_entry_or_null(&local->new_client_calls, + struct rxrpc_call, wait_link)) + ) { + struct rxrpc_bundle *bundle = call->bundle; -wait_failed: - spin_lock(&bundle->channel_lock); - list_del_init(&call->chan_wait_link); - spin_unlock(&bundle->channel_lock); + spin_lock(&local->client_call_lock); + list_move_tail(&call->wait_link, &bundle->waiting_calls); + spin_unlock(&local->client_call_lock); - if (call->state != RXRPC_CALL_CLIENT_AWAIT_CONN) { - ret = 0; - goto granted_channel; + if (rxrpc_bundle_has_space(bundle)) + rxrpc_activate_channels(bundle); } - - trace_rxrpc_client(call->conn, ret, rxrpc_client_chan_wait_failed); - rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, 0, ret); - rxrpc_disconnect_client_call(bundle, call); - goto out_put_bundle; } /* @@ -766,20 +533,24 @@ void rxrpc_expose_client_call(struct rxrpc_call *call) if (chan->call_counter >= INT_MAX) set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); trace_rxrpc_client(conn, channel, rxrpc_client_exposed); + + spin_lock(&call->peer->lock); + hlist_add_head(&call->error_link, &call->peer->error_targets); + spin_unlock(&call->peer->lock); } } /* * Set the reap timer. */ -static void rxrpc_set_client_reap_timer(struct rxrpc_net *rxnet) +static void rxrpc_set_client_reap_timer(struct rxrpc_local *local) { - if (!rxnet->kill_all_client_conns) { + if (!local->kill_all_client_conns) { unsigned long now = jiffies; unsigned long reap_at = now + rxrpc_conn_idle_client_expiry; - if (rxnet->live) - timer_reduce(&rxnet->client_conn_reap_timer, reap_at); + if (local->rxnet->live) + timer_reduce(&local->client_conn_reap_timer, reap_at); } } @@ -790,16 +561,13 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call { struct rxrpc_connection *conn; struct rxrpc_channel *chan = NULL; - struct rxrpc_net *rxnet = bundle->params.local->rxnet; + struct rxrpc_local *local = bundle->local; unsigned int channel; bool may_reuse; u32 cid; _enter("c=%x", call->debug_id); - spin_lock(&bundle->channel_lock); - set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); - /* Calls that have never actually been assigned a channel can simply be * discarded. */ @@ -808,8 +576,8 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call _debug("call is waiting"); ASSERTCMP(call->call_id, ==, 0); ASSERT(!test_bit(RXRPC_CALL_EXPOSED, &call->flags)); - list_del_init(&call->chan_wait_link); - goto out; + list_del_init(&call->wait_link); + return; } cid = call->cid; @@ -817,10 +585,8 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call chan = &conn->channels[channel]; trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect); - if (rcu_access_pointer(chan->call) != call) { - spin_unlock(&bundle->channel_lock); - BUG(); - } + if (WARN_ON(chan->call != call)) + return; may_reuse = rxrpc_may_reuse_conn(conn); @@ -841,16 +607,15 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call trace_rxrpc_client(conn, channel, rxrpc_client_to_active); bundle->try_upgrade = false; if (may_reuse) - rxrpc_activate_channels_locked(bundle); + rxrpc_activate_channels(bundle); } - } /* See if we can pass the channel directly to another call. */ if (may_reuse && !list_empty(&bundle->waiting_calls)) { trace_rxrpc_client(conn, channel, rxrpc_client_chan_pass); rxrpc_activate_one_channel(conn, channel); - goto out; + return; } /* Schedule the final ACK to be transmitted in a short while so that it @@ -868,7 +633,7 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call } /* Deactivate the channel. */ - rcu_assign_pointer(chan->call, NULL); + chan->call = NULL; set_bit(conn->bundle_shift + channel, &conn->bundle->avail_chans); conn->act_chans &= ~(1 << channel); @@ -880,18 +645,11 @@ void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call trace_rxrpc_client(conn, channel, rxrpc_client_to_idle); conn->idle_timestamp = jiffies; - rxrpc_get_connection(conn); - spin_lock(&rxnet->client_conn_cache_lock); - list_move_tail(&conn->cache_link, &rxnet->idle_client_conns); - spin_unlock(&rxnet->client_conn_cache_lock); + rxrpc_get_connection(conn, rxrpc_conn_get_idle); + list_move_tail(&conn->cache_link, &local->idle_client_conns); - rxrpc_set_client_reap_timer(rxnet); + rxrpc_set_client_reap_timer(local); } - -out: - spin_unlock(&bundle->channel_lock); - _leave(""); - return; } /* @@ -900,9 +658,7 @@ out: static void rxrpc_unbundle_conn(struct rxrpc_connection *conn) { struct rxrpc_bundle *bundle = conn->bundle; - struct rxrpc_local *local = bundle->params.local; unsigned int bindex; - bool need_drop = false, need_put = false; int i; _enter("C=%x", conn->debug_id); @@ -910,26 +666,32 @@ static void rxrpc_unbundle_conn(struct rxrpc_connection *conn) if (conn->flags & RXRPC_CONN_FINAL_ACK_MASK) rxrpc_process_delayed_final_acks(conn, true); - spin_lock(&bundle->channel_lock); bindex = conn->bundle_shift / RXRPC_MAXCALLS; if (bundle->conns[bindex] == conn) { _debug("clear slot %u", bindex); bundle->conns[bindex] = NULL; for (i = 0; i < RXRPC_MAXCALLS; i++) clear_bit(conn->bundle_shift + i, &bundle->avail_chans); - need_drop = true; + rxrpc_put_client_connection_id(bundle->local, conn); + rxrpc_deactivate_bundle(bundle); + rxrpc_put_connection(conn, rxrpc_conn_put_unbundle); } - spin_unlock(&bundle->channel_lock); +} - /* If there are no more connections, remove the bundle */ - if (!bundle->avail_chans) { - _debug("maybe unbundle"); - spin_lock(&local->client_bundles_lock); +/* + * Drop the active count on a bundle. + */ +void rxrpc_deactivate_bundle(struct rxrpc_bundle *bundle) +{ + struct rxrpc_local *local; + bool need_put = false; + + if (!bundle) + return; - for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) - if (bundle->conns[i]) - break; - if (i == ARRAY_SIZE(bundle->conns) && !bundle->params.exclusive) { + local = bundle->local; + if (atomic_dec_and_lock(&bundle->active, &local->client_bundles_lock)) { + if (!bundle->exclusive) { _debug("erase bundle"); rb_erase(&bundle->local_node, &local->client_bundles); need_put = true; @@ -937,20 +699,16 @@ static void rxrpc_unbundle_conn(struct rxrpc_connection *conn) spin_unlock(&local->client_bundles_lock); if (need_put) - rxrpc_put_bundle(bundle); + rxrpc_put_bundle(bundle, rxrpc_bundle_put_discard); } - - if (need_drop) - rxrpc_put_connection(conn); - _leave(""); } /* * Clean up a dead client connection. */ -static void rxrpc_kill_client_conn(struct rxrpc_connection *conn) +void rxrpc_kill_client_conn(struct rxrpc_connection *conn) { - struct rxrpc_local *local = conn->params.local; + struct rxrpc_local *local = conn->local; struct rxrpc_net *rxnet = local->rxnet; _enter("C=%x", conn->debug_id); @@ -958,24 +716,7 @@ static void rxrpc_kill_client_conn(struct rxrpc_connection *conn) trace_rxrpc_client(conn, -1, rxrpc_client_cleanup); atomic_dec(&rxnet->nr_client_conns); - rxrpc_put_client_connection_id(conn); - rxrpc_kill_connection(conn); -} - -/* - * Clean up a dead client connections. - */ -void rxrpc_put_client_conn(struct rxrpc_connection *conn) -{ - const void *here = __builtin_return_address(0); - unsigned int debug_id = conn->debug_id; - bool dead; - int r; - - dead = __refcount_dec_and_test(&conn->ref, &r); - trace_rxrpc_conn(debug_id, rxrpc_conn_put_client, r - 1, here); - if (dead) - rxrpc_kill_client_conn(conn); + rxrpc_put_client_connection_id(local, conn); } /* @@ -985,42 +726,26 @@ void rxrpc_put_client_conn(struct rxrpc_connection *conn) * This may be called from conn setup or from a work item so cannot be * considered non-reentrant. */ -void rxrpc_discard_expired_client_conns(struct work_struct *work) +void rxrpc_discard_expired_client_conns(struct rxrpc_local *local) { struct rxrpc_connection *conn; - struct rxrpc_net *rxnet = - container_of(work, struct rxrpc_net, client_conn_reaper); unsigned long expiry, conn_expires_at, now; unsigned int nr_conns; _enter(""); - if (list_empty(&rxnet->idle_client_conns)) { - _leave(" [empty]"); - return; - } - - /* Don't double up on the discarding */ - if (!spin_trylock(&rxnet->client_conn_discard_lock)) { - _leave(" [already]"); - return; - } - /* We keep an estimate of what the number of conns ought to be after * we've discarded some so that we don't overdo the discarding. */ - nr_conns = atomic_read(&rxnet->nr_client_conns); + nr_conns = atomic_read(&local->rxnet->nr_client_conns); next: - spin_lock(&rxnet->client_conn_cache_lock); - - if (list_empty(&rxnet->idle_client_conns)) - goto out; - - conn = list_entry(rxnet->idle_client_conns.next, - struct rxrpc_connection, cache_link); + conn = list_first_entry_or_null(&local->idle_client_conns, + struct rxrpc_connection, cache_link); + if (!conn) + return; - if (!rxnet->kill_all_client_conns) { + if (!local->kill_all_client_conns) { /* If the number of connections is over the reap limit, we * expedite discard by reducing the expiry timeout. We must, * however, have at least a short grace period to be able to do @@ -1029,7 +754,7 @@ next: expiry = rxrpc_conn_idle_client_expiry; if (nr_conns > rxrpc_reap_client_connections) expiry = rxrpc_conn_idle_client_fast_expiry; - if (conn->params.local->service_closed) + if (conn->local->service_closed) expiry = rxrpc_closed_conn_expiry * HZ; conn_expires_at = conn->idle_timestamp + expiry; @@ -1039,13 +764,13 @@ next: goto not_yet_expired; } + atomic_dec(&conn->active); trace_rxrpc_client(conn, -1, rxrpc_client_discard); list_del_init(&conn->cache_link); - spin_unlock(&rxnet->client_conn_cache_lock); - rxrpc_unbundle_conn(conn); - rxrpc_put_connection(conn); /* Drop the ->cache_link ref */ + /* Drop the ->cache_link ref */ + rxrpc_put_connection(conn, rxrpc_conn_put_discard_idle); nr_conns--; goto next; @@ -1059,31 +784,8 @@ not_yet_expired: * then things get messier. */ _debug("not yet"); - if (!rxnet->kill_all_client_conns) - timer_reduce(&rxnet->client_conn_reap_timer, conn_expires_at); - -out: - spin_unlock(&rxnet->client_conn_cache_lock); - spin_unlock(&rxnet->client_conn_discard_lock); - _leave(""); -} - -/* - * Preemptively destroy all the client connection records rather than waiting - * for them to time out - */ -void rxrpc_destroy_all_client_connections(struct rxrpc_net *rxnet) -{ - _enter(""); - - spin_lock(&rxnet->client_conn_cache_lock); - rxnet->kill_all_client_conns = true; - spin_unlock(&rxnet->client_conn_cache_lock); - - del_timer_sync(&rxnet->client_conn_reap_timer); - - if (!rxrpc_queue_work(&rxnet->client_conn_reaper)) - _debug("destroy: queue failed"); + if (!local->kill_all_client_conns) + timer_reduce(&local->client_conn_reap_timer, conn_expires_at); _leave(""); } @@ -1093,30 +795,21 @@ void rxrpc_destroy_all_client_connections(struct rxrpc_net *rxnet) */ void rxrpc_clean_up_local_conns(struct rxrpc_local *local) { - struct rxrpc_connection *conn, *tmp; - struct rxrpc_net *rxnet = local->rxnet; - LIST_HEAD(graveyard); + struct rxrpc_connection *conn; _enter(""); - spin_lock(&rxnet->client_conn_cache_lock); - - list_for_each_entry_safe(conn, tmp, &rxnet->idle_client_conns, - cache_link) { - if (conn->params.local == local) { - trace_rxrpc_client(conn, -1, rxrpc_client_discard); - list_move(&conn->cache_link, &graveyard); - } - } + local->kill_all_client_conns = true; - spin_unlock(&rxnet->client_conn_cache_lock); + del_timer_sync(&local->client_conn_reap_timer); - while (!list_empty(&graveyard)) { - conn = list_entry(graveyard.next, - struct rxrpc_connection, cache_link); + while ((conn = list_first_entry_or_null(&local->idle_client_conns, + struct rxrpc_connection, cache_link))) { list_del_init(&conn->cache_link); + atomic_dec(&conn->active); + trace_rxrpc_client(conn, -1, rxrpc_client_discard); rxrpc_unbundle_conn(conn); - rxrpc_put_connection(conn); + rxrpc_put_connection(conn, rxrpc_conn_put_local_dead); } _leave(" [culled]"); diff --git a/net/rxrpc/conn_event.c b/net/rxrpc/conn_event.c index aab069701398..44414e724415 100644 --- a/net/rxrpc/conn_event.c +++ b/net/rxrpc/conn_event.c @@ -17,11 +17,65 @@ #include "ar-internal.h" /* + * Set the completion state on an aborted connection. + */ +static bool rxrpc_set_conn_aborted(struct rxrpc_connection *conn, struct sk_buff *skb, + s32 abort_code, int err, + enum rxrpc_call_completion compl) +{ + bool aborted = false; + + if (conn->state != RXRPC_CONN_ABORTED) { + spin_lock(&conn->state_lock); + if (conn->state != RXRPC_CONN_ABORTED) { + conn->abort_code = abort_code; + conn->error = err; + conn->completion = compl; + /* Order the abort info before the state change. */ + smp_store_release(&conn->state, RXRPC_CONN_ABORTED); + set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); + set_bit(RXRPC_CONN_EV_ABORT_CALLS, &conn->events); + aborted = true; + } + spin_unlock(&conn->state_lock); + } + + return aborted; +} + +/* + * Mark a socket buffer to indicate that the connection it's on should be aborted. + */ +int rxrpc_abort_conn(struct rxrpc_connection *conn, struct sk_buff *skb, + s32 abort_code, int err, enum rxrpc_abort_reason why) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + if (rxrpc_set_conn_aborted(conn, skb, abort_code, err, + RXRPC_CALL_LOCALLY_ABORTED)) { + trace_rxrpc_abort(0, why, sp->hdr.cid, sp->hdr.callNumber, + sp->hdr.seq, abort_code, err); + rxrpc_poke_conn(conn, rxrpc_conn_get_poke_abort); + } + return -EPROTO; +} + +/* + * Mark a connection as being remotely aborted. + */ +static bool rxrpc_input_conn_abort(struct rxrpc_connection *conn, + struct sk_buff *skb) +{ + return rxrpc_set_conn_aborted(conn, skb, skb->priority, -ECONNABORTED, + RXRPC_CALL_REMOTELY_ABORTED); +} + +/* * Retransmit terminal ACK or ABORT of the previous call. */ -static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, - struct sk_buff *skb, - unsigned int channel) +void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, + struct sk_buff *skb, + unsigned int channel) { struct rxrpc_skb_priv *sp = skb ? rxrpc_skb(skb) : NULL; struct rxrpc_channel *chan; @@ -46,14 +100,12 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, /* If the last call got moved on whilst we were waiting to run, just * ignore this packet. */ - call_id = READ_ONCE(chan->last_call); - /* Sync with __rxrpc_disconnect_call() */ - smp_rmb(); + call_id = chan->last_call; if (skb && call_id != sp->hdr.callNumber) return; - msg.msg_name = &conn->params.peer->srx.transport; - msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_name = &conn->peer->srx.transport; + msg.msg_namelen = conn->peer->srx.transport_len; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; @@ -65,9 +117,12 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, iov[2].iov_base = &ack_info; iov[2].iov_len = sizeof(ack_info); + serial = atomic_inc_return(&conn->serial); + pkt.whdr.epoch = htonl(conn->proto.epoch); pkt.whdr.cid = htonl(conn->proto.cid | channel); pkt.whdr.callNumber = htonl(call_id); + pkt.whdr.serial = htonl(serial); pkt.whdr.seq = 0; pkt.whdr.type = chan->last_type; pkt.whdr.flags = conn->out_clientflag; @@ -86,8 +141,8 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, break; case RXRPC_PACKET_TYPE_ACK: - mtu = conn->params.peer->if_mtu; - mtu -= conn->params.peer->hdrsize; + mtu = conn->peer->if_mtu; + mtu -= conn->peer->hdrsize; pkt.ack.bufferSpace = 0; pkt.ack.maxSkew = htons(skb ? skb->priority : 0); pkt.ack.firstPacket = htonl(chan->last_seq + 1); @@ -104,37 +159,19 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, iov[0].iov_len += sizeof(pkt.ack); len += sizeof(pkt.ack) + 3 + sizeof(ack_info); ioc = 3; - break; - - default: - return; - } - /* Resync with __rxrpc_disconnect_call() and check that the last call - * didn't get advanced whilst we were filling out the packets. - */ - smp_rmb(); - if (READ_ONCE(chan->last_call) != call_id) - return; - - serial = atomic_inc_return(&conn->serial); - pkt.whdr.serial = htonl(serial); - - switch (chan->last_type) { - case RXRPC_PACKET_TYPE_ABORT: - _proto("Tx ABORT %%%u { %d } [re]", serial, conn->abort_code); - break; - case RXRPC_PACKET_TYPE_ACK: trace_rxrpc_tx_ack(chan->call_debug_id, serial, ntohl(pkt.ack.firstPacket), ntohl(pkt.ack.serial), pkt.ack.reason, 0); - _proto("Tx ACK %%%u [re]", serial); break; + + default: + return; } - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, ioc, len); - conn->params.peer->last_tx_at = ktime_get_seconds(); + ret = kernel_sendmsg(conn->local->socket, &msg, iov, ioc, len); + conn->peer->last_tx_at = ktime_get_seconds(); if (ret < 0) trace_rxrpc_tx_fail(chan->call_debug_id, serial, ret, rxrpc_tx_point_call_final_resend); @@ -148,132 +185,34 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, /* * pass a connection-level abort onto all calls on that connection */ -static void rxrpc_abort_calls(struct rxrpc_connection *conn, - enum rxrpc_call_completion compl, - rxrpc_serial_t serial) +static void rxrpc_abort_calls(struct rxrpc_connection *conn) { struct rxrpc_call *call; int i; _enter("{%d},%x", conn->debug_id, conn->abort_code); - spin_lock(&conn->bundle->channel_lock); - for (i = 0; i < RXRPC_MAXCALLS; i++) { - call = rcu_dereference_protected( - conn->channels[i].call, - lockdep_is_held(&conn->bundle->channel_lock)); - if (call) { - if (compl == RXRPC_CALL_LOCALLY_ABORTED) - trace_rxrpc_abort(call->debug_id, - "CON", call->cid, - call->call_id, 0, + call = conn->channels[i].call; + if (call) + rxrpc_set_call_completion(call, + conn->completion, conn->abort_code, conn->error); - else - trace_rxrpc_rx_abort(call, serial, - conn->abort_code); - rxrpc_set_call_completion(call, compl, - conn->abort_code, - conn->error); - } } - spin_unlock(&conn->bundle->channel_lock); _leave(""); } /* - * generate a connection-level abort - */ -static int rxrpc_abort_connection(struct rxrpc_connection *conn, - int error, u32 abort_code) -{ - struct rxrpc_wire_header whdr; - struct msghdr msg; - struct kvec iov[2]; - __be32 word; - size_t len; - u32 serial; - int ret; - - _enter("%d,,%u,%u", conn->debug_id, error, abort_code); - - /* generate a connection-level abort */ - spin_lock_bh(&conn->state_lock); - if (conn->state >= RXRPC_CONN_REMOTELY_ABORTED) { - spin_unlock_bh(&conn->state_lock); - _leave(" = 0 [already dead]"); - return 0; - } - - conn->error = error; - conn->abort_code = abort_code; - conn->state = RXRPC_CONN_LOCALLY_ABORTED; - set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); - spin_unlock_bh(&conn->state_lock); - - msg.msg_name = &conn->params.peer->srx.transport; - msg.msg_namelen = conn->params.peer->srx.transport_len; - msg.msg_control = NULL; - msg.msg_controllen = 0; - msg.msg_flags = 0; - - whdr.epoch = htonl(conn->proto.epoch); - whdr.cid = htonl(conn->proto.cid); - whdr.callNumber = 0; - whdr.seq = 0; - whdr.type = RXRPC_PACKET_TYPE_ABORT; - whdr.flags = conn->out_clientflag; - whdr.userStatus = 0; - whdr.securityIndex = conn->security_ix; - whdr._rsvd = 0; - whdr.serviceId = htons(conn->service_id); - - word = htonl(conn->abort_code); - - iov[0].iov_base = &whdr; - iov[0].iov_len = sizeof(whdr); - iov[1].iov_base = &word; - iov[1].iov_len = sizeof(word); - - len = iov[0].iov_len + iov[1].iov_len; - - serial = atomic_inc_return(&conn->serial); - rxrpc_abort_calls(conn, RXRPC_CALL_LOCALLY_ABORTED, serial); - whdr.serial = htonl(serial); - _proto("Tx CONN ABORT %%%u { %d }", serial, conn->abort_code); - - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); - if (ret < 0) { - trace_rxrpc_tx_fail(conn->debug_id, serial, ret, - rxrpc_tx_point_conn_abort); - _debug("sendmsg failed: %d", ret); - return -EAGAIN; - } - - trace_rxrpc_tx_packet(conn->debug_id, &whdr, rxrpc_tx_point_conn_abort); - - conn->params.peer->last_tx_at = ktime_get_seconds(); - - _leave(" = 0"); - return 0; -} - -/* * mark a call as being on a now-secured channel * - must be called with BH's disabled. */ static void rxrpc_call_is_secure(struct rxrpc_call *call) { - _enter("%p", call); - if (call) { - write_lock_bh(&call->state_lock); - if (call->state == RXRPC_CALL_SERVER_SECURING) { - call->state = RXRPC_CALL_SERVER_RECV_REQUEST; - rxrpc_notify_socket(call); - } - write_unlock_bh(&call->state_lock); + if (call && __rxrpc_call_state(call) == RXRPC_CALL_SERVER_SECURING) { + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_RECV_REQUEST); + rxrpc_notify_socket(call); } } @@ -281,84 +220,49 @@ static void rxrpc_call_is_secure(struct rxrpc_call *call) * connection-level Rx packet processor */ static int rxrpc_process_event(struct rxrpc_connection *conn, - struct sk_buff *skb, - u32 *_abort_code) + struct sk_buff *skb) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - __be32 wtmp; - u32 abort_code; - int loop, ret; + int ret; - if (conn->state >= RXRPC_CONN_REMOTELY_ABORTED) { - _leave(" = -ECONNABORTED [%u]", conn->state); + if (conn->state == RXRPC_CONN_ABORTED) return -ECONNABORTED; - } _enter("{%d},{%u,%%%u},", conn->debug_id, sp->hdr.type, sp->hdr.serial); switch (sp->hdr.type) { - case RXRPC_PACKET_TYPE_DATA: - case RXRPC_PACKET_TYPE_ACK: - rxrpc_conn_retransmit_call(conn, skb, - sp->hdr.cid & RXRPC_CHANNELMASK); - return 0; - - case RXRPC_PACKET_TYPE_BUSY: - /* Just ignore BUSY packets for now. */ - return 0; - - case RXRPC_PACKET_TYPE_ABORT: - if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), - &wtmp, sizeof(wtmp)) < 0) { - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, - tracepoint_string("bad_abort")); - return -EPROTO; - } - abort_code = ntohl(wtmp); - _proto("Rx ABORT %%%u { ac=%d }", sp->hdr.serial, abort_code); - - conn->error = -ECONNABORTED; - conn->abort_code = abort_code; - conn->state = RXRPC_CONN_REMOTELY_ABORTED; - set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); - rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED, sp->hdr.serial); - return -ECONNABORTED; - case RXRPC_PACKET_TYPE_CHALLENGE: - return conn->security->respond_to_challenge(conn, skb, - _abort_code); + return conn->security->respond_to_challenge(conn, skb); case RXRPC_PACKET_TYPE_RESPONSE: - ret = conn->security->verify_response(conn, skb, _abort_code); + ret = conn->security->verify_response(conn, skb); if (ret < 0) return ret; ret = conn->security->init_connection_security( - conn, conn->params.key->payload.data[0]); + conn, conn->key->payload.data[0]); if (ret < 0) return ret; - spin_lock(&conn->bundle->channel_lock); - spin_lock_bh(&conn->state_lock); - - if (conn->state == RXRPC_CONN_SERVICE_CHALLENGING) { + spin_lock(&conn->state_lock); + if (conn->state == RXRPC_CONN_SERVICE_CHALLENGING) conn->state = RXRPC_CONN_SERVICE; - spin_unlock_bh(&conn->state_lock); - for (loop = 0; loop < RXRPC_MAXCALLS; loop++) - rxrpc_call_is_secure( - rcu_dereference_protected( - conn->channels[loop].call, - lockdep_is_held(&conn->bundle->channel_lock))); - } else { - spin_unlock_bh(&conn->state_lock); + spin_unlock(&conn->state_lock); + + if (conn->state == RXRPC_CONN_SERVICE) { + /* Offload call state flipping to the I/O thread. As + * we've already received the packet, put it on the + * front of the queue. + */ + skb->mark = RXRPC_SKB_MARK_SERVICE_CONN_SECURED; + rxrpc_get_skb(skb, rxrpc_skb_get_conn_secured); + skb_queue_head(&conn->local->rx_queue, skb); + rxrpc_wake_up_io_thread(conn->local); } - - spin_unlock(&conn->bundle->channel_lock); return 0; default: - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, - tracepoint_string("bad_conn_pkt")); + WARN_ON_ONCE(1); return -EPROTO; } } @@ -368,26 +272,9 @@ static int rxrpc_process_event(struct rxrpc_connection *conn, */ static void rxrpc_secure_connection(struct rxrpc_connection *conn) { - u32 abort_code; - int ret; - - _enter("{%d}", conn->debug_id); - - ASSERT(conn->security_ix != 0); - - if (conn->security->issue_challenge(conn) < 0) { - abort_code = RX_CALL_DEAD; - ret = -ENOMEM; - goto abort; - } - - _leave(""); - return; - -abort: - _debug("abort %d, %d", ret, abort_code); - rxrpc_abort_connection(conn, ret, abort_code); - _leave(" [aborted]"); + if (conn->security->issue_challenge(conn) < 0) + rxrpc_abort_conn(conn, NULL, RX_CALL_DEAD, -ENOMEM, + rxrpc_abort_nomem); } /* @@ -409,9 +296,7 @@ again: if (!test_bit(RXRPC_CONN_FINAL_ACK_0 + channel, &conn->flags)) continue; - smp_rmb(); /* vs rxrpc_disconnect_client_call */ - ack_at = READ_ONCE(chan->final_ack_at); - + ack_at = chan->final_ack_at; if (time_before(j, ack_at) && !force) { if (time_before(ack_at, next_j)) { next_j = ack_at; @@ -438,47 +323,27 @@ again: static void rxrpc_do_process_connection(struct rxrpc_connection *conn) { struct sk_buff *skb; - u32 abort_code = RX_PROTOCOL_ERROR; int ret; if (test_and_clear_bit(RXRPC_CONN_EV_CHALLENGE, &conn->events)) rxrpc_secure_connection(conn); - /* Process delayed ACKs whose time has come. */ - if (conn->flags & RXRPC_CONN_FINAL_ACK_MASK) - rxrpc_process_delayed_final_acks(conn, false); - /* go through the conn-level event packets, releasing the ref on this * connection that each one has when we've finished with it */ while ((skb = skb_dequeue(&conn->rx_queue))) { - rxrpc_see_skb(skb, rxrpc_skb_seen); - ret = rxrpc_process_event(conn, skb, &abort_code); + rxrpc_see_skb(skb, rxrpc_skb_see_conn_work); + ret = rxrpc_process_event(conn, skb); switch (ret) { - case -EPROTO: - case -EKEYEXPIRED: - case -EKEYREJECTED: - goto protocol_error; case -ENOMEM: case -EAGAIN: - goto requeue_and_leave; - case -ECONNABORTED: + skb_queue_head(&conn->rx_queue, skb); + rxrpc_queue_conn(conn, rxrpc_conn_queue_retry_work); + break; default: - rxrpc_free_skb(skb, rxrpc_skb_freed); + rxrpc_free_skb(skb, rxrpc_skb_put_conn_work); break; } } - - return; - -requeue_and_leave: - skb_queue_head(&conn->rx_queue, skb); - return; - -protocol_error: - if (rxrpc_abort_connection(conn, ret, abort_code) < 0) - goto requeue_and_leave; - rxrpc_free_skb(skb, rxrpc_skb_freed); - return; } void rxrpc_process_connection(struct work_struct *work) @@ -486,14 +351,85 @@ void rxrpc_process_connection(struct work_struct *work) struct rxrpc_connection *conn = container_of(work, struct rxrpc_connection, processor); - rxrpc_see_connection(conn); + rxrpc_see_connection(conn, rxrpc_conn_see_work); - if (__rxrpc_use_local(conn->params.local)) { + if (__rxrpc_use_local(conn->local, rxrpc_local_use_conn_work)) { rxrpc_do_process_connection(conn); - rxrpc_unuse_local(conn->params.local); + rxrpc_unuse_local(conn->local, rxrpc_local_unuse_conn_work); } +} - rxrpc_put_connection(conn); - _leave(""); - return; +/* + * post connection-level events to the connection + * - this includes challenges, responses, some aborts and call terminal packet + * retransmission. + */ +static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn, + struct sk_buff *skb) +{ + _enter("%p,%p", conn, skb); + + rxrpc_get_skb(skb, rxrpc_skb_get_conn_work); + skb_queue_tail(&conn->rx_queue, skb); + rxrpc_queue_conn(conn, rxrpc_conn_queue_rx_work); +} + +/* + * Input a connection-level packet. + */ +bool rxrpc_input_conn_packet(struct rxrpc_connection *conn, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_BUSY: + /* Just ignore BUSY packets for now. */ + return true; + + case RXRPC_PACKET_TYPE_ABORT: + if (rxrpc_is_conn_aborted(conn)) + return true; + rxrpc_input_conn_abort(conn, skb); + rxrpc_abort_calls(conn); + return true; + + case RXRPC_PACKET_TYPE_CHALLENGE: + case RXRPC_PACKET_TYPE_RESPONSE: + if (rxrpc_is_conn_aborted(conn)) { + if (conn->completion == RXRPC_CALL_LOCALLY_ABORTED) + rxrpc_send_conn_abort(conn); + return true; + } + rxrpc_post_packet_to_conn(conn, skb); + return true; + + default: + WARN_ON_ONCE(1); + return true; + } +} + +/* + * Input a connection event. + */ +void rxrpc_input_conn_event(struct rxrpc_connection *conn, struct sk_buff *skb) +{ + unsigned int loop; + + if (test_and_clear_bit(RXRPC_CONN_EV_ABORT_CALLS, &conn->events)) + rxrpc_abort_calls(conn); + + switch (skb->mark) { + case RXRPC_SKB_MARK_SERVICE_CONN_SECURED: + if (conn->state != RXRPC_CONN_SERVICE) + break; + + for (loop = 0; loop < RXRPC_MAXCALLS; loop++) + rxrpc_call_is_secure(conn->channels[loop].call); + break; + } + + /* Process delayed ACKs whose time has come. */ + if (conn->flags & RXRPC_CONN_FINAL_ACK_MASK) + rxrpc_process_delayed_final_acks(conn, false); } diff --git a/net/rxrpc/conn_object.c b/net/rxrpc/conn_object.c index 22089e37e97f..ac85d4644a3c 100644 --- a/net/rxrpc/conn_object.c +++ b/net/rxrpc/conn_object.c @@ -19,20 +19,41 @@ unsigned int __read_mostly rxrpc_connection_expiry = 10 * 60; unsigned int __read_mostly rxrpc_closed_conn_expiry = 10; -static void rxrpc_destroy_connection(struct rcu_head *); +static void rxrpc_clean_up_connection(struct work_struct *work); +static void rxrpc_set_service_reap_timer(struct rxrpc_net *rxnet, + unsigned long reap_at); + +void rxrpc_poke_conn(struct rxrpc_connection *conn, enum rxrpc_conn_trace why) +{ + struct rxrpc_local *local = conn->local; + bool busy; + + if (WARN_ON_ONCE(!local)) + return; + + spin_lock_bh(&local->lock); + busy = !list_empty(&conn->attend_link); + if (!busy) { + rxrpc_get_connection(conn, why); + list_add_tail(&conn->attend_link, &local->conn_attend_q); + } + spin_unlock_bh(&local->lock); + rxrpc_wake_up_io_thread(local); +} static void rxrpc_connection_timer(struct timer_list *timer) { struct rxrpc_connection *conn = container_of(timer, struct rxrpc_connection, timer); - rxrpc_queue_conn(conn); + rxrpc_poke_conn(conn, rxrpc_conn_get_poke_timer); } /* * allocate a new connection */ -struct rxrpc_connection *rxrpc_alloc_connection(gfp_t gfp) +struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *rxnet, + gfp_t gfp) { struct rxrpc_connection *conn; @@ -42,10 +63,13 @@ struct rxrpc_connection *rxrpc_alloc_connection(gfp_t gfp) if (conn) { INIT_LIST_HEAD(&conn->cache_link); timer_setup(&conn->timer, &rxrpc_connection_timer, 0); - INIT_WORK(&conn->processor, &rxrpc_process_connection); + INIT_WORK(&conn->processor, rxrpc_process_connection); + INIT_WORK(&conn->destructor, rxrpc_clean_up_connection); INIT_LIST_HEAD(&conn->proc_link); INIT_LIST_HEAD(&conn->link); + mutex_init(&conn->security_lock); skb_queue_head_init(&conn->rx_queue); + conn->rxnet = rxnet; conn->security = &rxrpc_no_security; spin_lock_init(&conn->state_lock); conn->debug_id = atomic_inc_return(&rxrpc_debug_id); @@ -67,89 +91,55 @@ struct rxrpc_connection *rxrpc_alloc_connection(gfp_t gfp) * * The caller must be holding the RCU read lock. */ -struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, - struct sk_buff *skb, - struct rxrpc_peer **_peer) +struct rxrpc_connection *rxrpc_find_client_connection_rcu(struct rxrpc_local *local, + struct sockaddr_rxrpc *srx, + struct sk_buff *skb) { struct rxrpc_connection *conn; - struct rxrpc_conn_proto k; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - struct sockaddr_rxrpc srx; struct rxrpc_peer *peer; _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK); - if (rxrpc_extract_addr_from_skb(&srx, skb) < 0) - goto not_found; - - if (srx.transport.family != local->srx.transport.family && - (srx.transport.family == AF_INET && - local->srx.transport.family != AF_INET6)) { - pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n", - srx.transport.family, - local->srx.transport.family); + /* Look up client connections by connection ID alone as their + * IDs are unique for this machine. + */ + conn = idr_find(&local->conn_ids, sp->hdr.cid >> RXRPC_CIDSHIFT); + if (!conn || refcount_read(&conn->ref) == 0) { + _debug("no conn"); goto not_found; } - k.epoch = sp->hdr.epoch; - k.cid = sp->hdr.cid & RXRPC_CIDMASK; - - if (rxrpc_to_server(sp)) { - /* We need to look up service connections by the full protocol - * parameter set. We look up the peer first as an intermediate - * step and then the connection from the peer's tree. - */ - peer = rxrpc_lookup_peer_rcu(local, &srx); - if (!peer) - goto not_found; - *_peer = peer; - conn = rxrpc_find_service_conn_rcu(peer, skb); - if (!conn || refcount_read(&conn->ref) == 0) - goto not_found; - _leave(" = %p", conn); - return conn; - } else { - /* Look up client connections by connection ID alone as their - * IDs are unique for this machine. - */ - conn = idr_find(&rxrpc_client_conn_ids, - sp->hdr.cid >> RXRPC_CIDSHIFT); - if (!conn || refcount_read(&conn->ref) == 0) { - _debug("no conn"); - goto not_found; - } + if (conn->proto.epoch != sp->hdr.epoch || + conn->local != local) + goto not_found; - if (conn->proto.epoch != k.epoch || - conn->params.local != local) + peer = conn->peer; + switch (srx->transport.family) { + case AF_INET: + if (peer->srx.transport.sin.sin_port != + srx->transport.sin.sin_port || + peer->srx.transport.sin.sin_addr.s_addr != + srx->transport.sin.sin_addr.s_addr) goto not_found; - - peer = conn->params.peer; - switch (srx.transport.family) { - case AF_INET: - if (peer->srx.transport.sin.sin_port != - srx.transport.sin.sin_port || - peer->srx.transport.sin.sin_addr.s_addr != - srx.transport.sin.sin_addr.s_addr) - goto not_found; - break; + break; #ifdef CONFIG_AF_RXRPC_IPV6 - case AF_INET6: - if (peer->srx.transport.sin6.sin6_port != - srx.transport.sin6.sin6_port || - memcmp(&peer->srx.transport.sin6.sin6_addr, - &srx.transport.sin6.sin6_addr, - sizeof(struct in6_addr)) != 0) - goto not_found; - break; + case AF_INET6: + if (peer->srx.transport.sin6.sin6_port != + srx->transport.sin6.sin6_port || + memcmp(&peer->srx.transport.sin6.sin6_addr, + &srx->transport.sin6.sin6_addr, + sizeof(struct in6_addr)) != 0) + goto not_found; + break; #endif - default: - BUG(); - } - - _leave(" = %p", conn); - return conn; + default: + BUG(); } + _leave(" = %p", conn); + return conn; + not_found: _leave(" = NULL"); return NULL; @@ -168,14 +158,14 @@ void __rxrpc_disconnect_call(struct rxrpc_connection *conn, _enter("%d,%x", conn->debug_id, call->cid); - if (rcu_access_pointer(chan->call) == call) { + if (chan->call == call) { /* Save the result of the call so that we can repeat it if necessary * through the channel, whilst disposing of the actual call record. */ trace_rxrpc_disconnect_call(call); switch (call->completion) { case RXRPC_CALL_SUCCEEDED: - chan->last_seq = call->rx_hard_ack; + chan->last_seq = call->rx_highest_seq; chan->last_type = RXRPC_PACKET_TYPE_ACK; break; case RXRPC_CALL_LOCALLY_ABORTED: @@ -188,12 +178,9 @@ void __rxrpc_disconnect_call(struct rxrpc_connection *conn, break; } - /* Sync with rxrpc_conn_retransmit(). */ - smp_wmb(); chan->last_call = chan->call_id; chan->call_id = chan->call_counter; - - rcu_assign_pointer(chan->call, NULL); + chan->call = NULL; } _leave(""); @@ -207,96 +194,64 @@ void rxrpc_disconnect_call(struct rxrpc_call *call) { struct rxrpc_connection *conn = call->conn; - call->peer->cong_cwnd = call->cong_cwnd; - - if (!hlist_unhashed(&call->error_link)) { - spin_lock_bh(&call->peer->lock); - hlist_del_rcu(&call->error_link); - spin_unlock_bh(&call->peer->lock); - } - - if (rxrpc_is_client_call(call)) - return rxrpc_disconnect_client_call(conn->bundle, call); - - spin_lock(&conn->bundle->channel_lock); - __rxrpc_disconnect_call(conn, call); - spin_unlock(&conn->bundle->channel_lock); - set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); - conn->idle_timestamp = jiffies; -} - -/* - * Kill off a connection. - */ -void rxrpc_kill_connection(struct rxrpc_connection *conn) -{ - struct rxrpc_net *rxnet = conn->params.local->rxnet; + rxrpc_see_call(call, rxrpc_call_see_disconnected); - ASSERT(!rcu_access_pointer(conn->channels[0].call) && - !rcu_access_pointer(conn->channels[1].call) && - !rcu_access_pointer(conn->channels[2].call) && - !rcu_access_pointer(conn->channels[3].call)); - ASSERT(list_empty(&conn->cache_link)); + call->peer->cong_ssthresh = call->cong_ssthresh; - write_lock(&rxnet->conn_lock); - list_del_init(&conn->proc_link); - write_unlock(&rxnet->conn_lock); + if (!hlist_unhashed(&call->error_link)) { + spin_lock(&call->peer->lock); + hlist_del_init(&call->error_link); + spin_unlock(&call->peer->lock); + } - /* Drain the Rx queue. Note that even though we've unpublished, an - * incoming packet could still be being added to our Rx queue, so we - * will need to drain it again in the RCU cleanup handler. - */ - rxrpc_purge_queue(&conn->rx_queue); + if (rxrpc_is_client_call(call)) { + rxrpc_disconnect_client_call(call->bundle, call); + } else { + __rxrpc_disconnect_call(conn, call); + conn->idle_timestamp = jiffies; + if (atomic_dec_and_test(&conn->active)) + rxrpc_set_service_reap_timer(conn->rxnet, + jiffies + rxrpc_connection_expiry); + } - /* Leave final destruction to RCU. The connection processor work item - * must carry a ref on the connection to prevent us getting here whilst - * it is queued or running. - */ - call_rcu(&conn->rcu, rxrpc_destroy_connection); + rxrpc_put_call(call, rxrpc_call_put_io_thread); } /* * Queue a connection's work processor, getting a ref to pass to the work * queue. */ -bool rxrpc_queue_conn(struct rxrpc_connection *conn) +void rxrpc_queue_conn(struct rxrpc_connection *conn, enum rxrpc_conn_trace why) { - const void *here = __builtin_return_address(0); - int r; - - if (!__refcount_inc_not_zero(&conn->ref, &r)) - return false; - if (rxrpc_queue_work(&conn->processor)) - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_queued, r + 1, here); - else - rxrpc_put_connection(conn); - return true; + if (atomic_read(&conn->active) >= 0 && + rxrpc_queue_work(&conn->processor)) + rxrpc_see_connection(conn, why); } /* * Note the re-emergence of a connection. */ -void rxrpc_see_connection(struct rxrpc_connection *conn) +void rxrpc_see_connection(struct rxrpc_connection *conn, + enum rxrpc_conn_trace why) { - const void *here = __builtin_return_address(0); if (conn) { - int n = refcount_read(&conn->ref); + int r = refcount_read(&conn->ref); - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_seen, n, here); + trace_rxrpc_conn(conn->debug_id, r, why); } } /* * Get a ref on a connection. */ -struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *conn) +struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *conn, + enum rxrpc_conn_trace why) { - const void *here = __builtin_return_address(0); int r; __refcount_inc(&conn->ref, &r); - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_got, r, here); + trace_rxrpc_conn(conn->debug_id, r + 1, why); return conn; } @@ -304,14 +259,14 @@ struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *conn) * Try to get a ref on a connection. */ struct rxrpc_connection * -rxrpc_get_connection_maybe(struct rxrpc_connection *conn) +rxrpc_get_connection_maybe(struct rxrpc_connection *conn, + enum rxrpc_conn_trace why) { - const void *here = __builtin_return_address(0); int r; if (conn) { if (__refcount_inc_not_zero(&conn->ref, &r)) - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_got, r + 1, here); + trace_rxrpc_conn(conn->debug_id, r + 1, why); else conn = NULL; } @@ -329,49 +284,95 @@ static void rxrpc_set_service_reap_timer(struct rxrpc_net *rxnet, } /* - * Release a service connection + * destroy a virtual connection */ -void rxrpc_put_service_conn(struct rxrpc_connection *conn) +static void rxrpc_rcu_free_connection(struct rcu_head *rcu) { - const void *here = __builtin_return_address(0); - unsigned int debug_id = conn->debug_id; - int r; + struct rxrpc_connection *conn = + container_of(rcu, struct rxrpc_connection, rcu); + struct rxrpc_net *rxnet = conn->rxnet; - __refcount_dec(&conn->ref, &r); - trace_rxrpc_conn(debug_id, rxrpc_conn_put_service, r - 1, here); - if (r - 1 == 1) - rxrpc_set_service_reap_timer(conn->params.local->rxnet, - jiffies + rxrpc_connection_expiry); + _enter("{%d,u=%d}", conn->debug_id, refcount_read(&conn->ref)); + + trace_rxrpc_conn(conn->debug_id, refcount_read(&conn->ref), + rxrpc_conn_free); + kfree(conn); + + if (atomic_dec_and_test(&rxnet->nr_conns)) + wake_up_var(&rxnet->nr_conns); } /* - * destroy a virtual connection + * Clean up a dead connection. */ -static void rxrpc_destroy_connection(struct rcu_head *rcu) +static void rxrpc_clean_up_connection(struct work_struct *work) { struct rxrpc_connection *conn = - container_of(rcu, struct rxrpc_connection, rcu); + container_of(work, struct rxrpc_connection, destructor); + struct rxrpc_net *rxnet = conn->rxnet; - _enter("{%d,u=%d}", conn->debug_id, refcount_read(&conn->ref)); + ASSERT(!conn->channels[0].call && + !conn->channels[1].call && + !conn->channels[2].call && + !conn->channels[3].call); + ASSERT(list_empty(&conn->cache_link)); - ASSERTCMP(refcount_read(&conn->ref), ==, 0); + del_timer_sync(&conn->timer); + cancel_work_sync(&conn->processor); /* Processing may restart the timer */ + del_timer_sync(&conn->timer); - _net("DESTROY CONN %d", conn->debug_id); + write_lock(&rxnet->conn_lock); + list_del_init(&conn->proc_link); + write_unlock(&rxnet->conn_lock); - del_timer_sync(&conn->timer); rxrpc_purge_queue(&conn->rx_queue); + rxrpc_kill_client_conn(conn); + conn->security->clear(conn); - key_put(conn->params.key); - rxrpc_put_bundle(conn->bundle); - rxrpc_put_peer(conn->params.peer); + key_put(conn->key); + rxrpc_put_bundle(conn->bundle, rxrpc_bundle_put_conn); + rxrpc_put_peer(conn->peer, rxrpc_peer_put_conn); + rxrpc_put_local(conn->local, rxrpc_local_put_kill_conn); - if (atomic_dec_and_test(&conn->params.local->rxnet->nr_conns)) - wake_up_var(&conn->params.local->rxnet->nr_conns); - rxrpc_put_local(conn->params.local); + /* Drain the Rx queue. Note that even though we've unpublished, an + * incoming packet could still be being added to our Rx queue, so we + * will need to drain it again in the RCU cleanup handler. + */ + rxrpc_purge_queue(&conn->rx_queue); - kfree(conn); - _leave(""); + call_rcu(&conn->rcu, rxrpc_rcu_free_connection); +} + +/* + * Drop a ref on a connection. + */ +void rxrpc_put_connection(struct rxrpc_connection *conn, + enum rxrpc_conn_trace why) +{ + unsigned int debug_id; + bool dead; + int r; + + if (!conn) + return; + + debug_id = conn->debug_id; + dead = __refcount_dec_and_test(&conn->ref, &r); + trace_rxrpc_conn(debug_id, r - 1, why); + if (dead) { + del_timer(&conn->timer); + cancel_work(&conn->processor); + + if (in_softirq() || work_busy(&conn->processor) || + timer_pending(&conn->timer)) + /* Can't use the rxrpc workqueue as we need to cancel/flush + * something that may be running/waiting there. + */ + schedule_work(&conn->destructor); + else + rxrpc_clean_up_connection(&conn->destructor); + } } /* @@ -383,6 +384,7 @@ void rxrpc_service_connection_reaper(struct work_struct *work) struct rxrpc_net *rxnet = container_of(work, struct rxrpc_net, service_conn_reaper); unsigned long expire_at, earliest, idle_timestamp, now; + int active; LIST_HEAD(graveyard); @@ -393,20 +395,20 @@ void rxrpc_service_connection_reaper(struct work_struct *work) write_lock(&rxnet->conn_lock); list_for_each_entry_safe(conn, _p, &rxnet->service_conns, link) { - ASSERTCMP(refcount_read(&conn->ref), >, 0); - if (likely(refcount_read(&conn->ref) > 1)) + ASSERTCMP(atomic_read(&conn->active), >=, 0); + if (likely(atomic_read(&conn->active) > 0)) continue; if (conn->state == RXRPC_CONN_SERVICE_PREALLOC) continue; - if (rxnet->live && !conn->params.local->dead) { + if (rxnet->live && !conn->local->dead) { idle_timestamp = READ_ONCE(conn->idle_timestamp); expire_at = idle_timestamp + rxrpc_connection_expiry * HZ; - if (conn->params.local->service_closed) + if (conn->local->service_closed) expire_at = idle_timestamp + rxrpc_closed_conn_expiry * HZ; - _debug("reap CONN %d { u=%d,t=%ld }", - conn->debug_id, refcount_read(&conn->ref), + _debug("reap CONN %d { a=%d,t=%ld }", + conn->debug_id, atomic_read(&conn->active), (long)expire_at - (long)now); if (time_before(now, expire_at)) { @@ -416,12 +418,13 @@ void rxrpc_service_connection_reaper(struct work_struct *work) } } - /* The usage count sits at 1 whilst the object is unused on the - * list; we reduce that to 0 to make the object unavailable. + /* The activity count sits at 0 whilst the conn is unused on + * the list; we reduce that to -1 to make the conn unavailable. */ - if (!refcount_dec_if_one(&conn->ref)) + active = 0; + if (!atomic_try_cmpxchg(&conn->active, &active, -1)) continue; - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_reap_service, 0, NULL); + rxrpc_see_connection(conn, rxrpc_conn_see_reap_service); if (rxrpc_conn_is_client(conn)) BUG(); @@ -443,8 +446,8 @@ void rxrpc_service_connection_reaper(struct work_struct *work) link); list_del_init(&conn->link); - ASSERTCMP(refcount_read(&conn->ref), ==, 0); - rxrpc_kill_connection(conn); + ASSERTCMP(atomic_read(&conn->active), ==, -1); + rxrpc_put_connection(conn, rxrpc_conn_put_service_reaped); } _leave(""); @@ -462,7 +465,6 @@ void rxrpc_destroy_all_connections(struct rxrpc_net *rxnet) _enter(""); atomic_dec(&rxnet->nr_conns); - rxrpc_destroy_all_client_connections(rxnet); del_timer_sync(&rxnet->service_conn_reap_timer); rxrpc_queue_work(&rxnet->service_conn_reaper); diff --git a/net/rxrpc/conn_service.c b/net/rxrpc/conn_service.c index 6e6aa02c6f9e..f30323de82bd 100644 --- a/net/rxrpc/conn_service.c +++ b/net/rxrpc/conn_service.c @@ -11,7 +11,6 @@ static struct rxrpc_bundle rxrpc_service_dummy_bundle = { .ref = REFCOUNT_INIT(1), .debug_id = UINT_MAX, - .channel_lock = __SPIN_LOCK_UNLOCKED(&rxrpc_service_dummy_bundle.channel_lock), }; /* @@ -73,7 +72,7 @@ static void rxrpc_publish_service_conn(struct rxrpc_peer *peer, struct rxrpc_conn_proto k = conn->proto; struct rb_node **pp, *parent; - write_seqlock_bh(&peer->service_conn_lock); + write_seqlock(&peer->service_conn_lock); pp = &peer->service_conns.rb_node; parent = NULL; @@ -94,14 +93,14 @@ static void rxrpc_publish_service_conn(struct rxrpc_peer *peer, rb_insert_color(&conn->service_node, &peer->service_conns); conn_published: set_bit(RXRPC_CONN_IN_SERVICE_CONNS, &conn->flags); - write_sequnlock_bh(&peer->service_conn_lock); + write_sequnlock(&peer->service_conn_lock); _leave(" = %d [new]", conn->debug_id); return; found_extant_conn: if (refcount_read(&cursor->ref) == 0) goto replace_old_connection; - write_sequnlock_bh(&peer->service_conn_lock); + write_sequnlock(&peer->service_conn_lock); /* We should not be able to get here. rxrpc_incoming_connection() is * called in a non-reentrant context, so there can't be a race to * insert a new connection. @@ -125,7 +124,7 @@ replace_old_connection: struct rxrpc_connection *rxrpc_prealloc_service_connection(struct rxrpc_net *rxnet, gfp_t gfp) { - struct rxrpc_connection *conn = rxrpc_alloc_connection(gfp); + struct rxrpc_connection *conn = rxrpc_alloc_connection(rxnet, gfp); if (conn) { /* We maintain an extra ref on the connection whilst it is on @@ -133,7 +132,8 @@ struct rxrpc_connection *rxrpc_prealloc_service_connection(struct rxrpc_net *rxn */ conn->state = RXRPC_CONN_SERVICE_PREALLOC; refcount_set(&conn->ref, 2); - conn->bundle = rxrpc_get_bundle(&rxrpc_service_dummy_bundle); + conn->bundle = rxrpc_get_bundle(&rxrpc_service_dummy_bundle, + rxrpc_bundle_get_service_conn); atomic_inc(&rxnet->nr_conns); write_lock(&rxnet->conn_lock); @@ -141,9 +141,7 @@ struct rxrpc_connection *rxrpc_prealloc_service_connection(struct rxrpc_net *rxn list_add_tail(&conn->proc_link, &rxnet->conn_proc_list); write_unlock(&rxnet->conn_lock); - trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_service, - refcount_read(&conn->ref), - __builtin_return_address(0)); + rxrpc_see_connection(conn, rxrpc_conn_new_service); } return conn; @@ -164,7 +162,7 @@ void rxrpc_new_incoming_connection(struct rxrpc_sock *rx, conn->proto.epoch = sp->hdr.epoch; conn->proto.cid = sp->hdr.cid & RXRPC_CIDMASK; - conn->params.service_id = sp->hdr.serviceId; + conn->orig_service_id = sp->hdr.serviceId; conn->service_id = sp->hdr.serviceId; conn->security_ix = sp->hdr.securityIndex; conn->out_clientflag = 0; @@ -182,10 +180,10 @@ void rxrpc_new_incoming_connection(struct rxrpc_sock *rx, conn->service_id == rx->service_upgrade.from) conn->service_id = rx->service_upgrade.to; - /* Make the connection a target for incoming packets. */ - rxrpc_publish_service_conn(conn->params.peer, conn); + atomic_set(&conn->active, 1); - _net("CONNECTION new %d {%x}", conn->debug_id, conn->proto.cid); + /* Make the connection a target for incoming packets. */ + rxrpc_publish_service_conn(conn->peer, conn); } /* @@ -194,10 +192,10 @@ void rxrpc_new_incoming_connection(struct rxrpc_sock *rx, */ void rxrpc_unpublish_service_conn(struct rxrpc_connection *conn) { - struct rxrpc_peer *peer = conn->params.peer; + struct rxrpc_peer *peer = conn->peer; - write_seqlock_bh(&peer->service_conn_lock); + write_seqlock(&peer->service_conn_lock); if (test_and_clear_bit(RXRPC_CONN_IN_SERVICE_CONNS, &conn->flags)) rb_erase(&conn->service_node, &peer->service_conns); - write_sequnlock_bh(&peer->service_conn_lock); + write_sequnlock(&peer->service_conn_lock); } diff --git a/net/rxrpc/input.c b/net/rxrpc/input.c index 721d847ba92b..367927a99881 100644 --- a/net/rxrpc/input.c +++ b/net/rxrpc/input.c @@ -1,35 +1,18 @@ // SPDX-License-Identifier: GPL-2.0-or-later -/* RxRPC packet reception +/* Processing of received RxRPC packets * - * Copyright (C) 2007, 2016 Red Hat, Inc. All Rights Reserved. + * Copyright (C) 2020 Red Hat, Inc. All Rights Reserved. * Written by David Howells (dhowells@redhat.com) */ #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt -#include <linux/module.h> -#include <linux/net.h> -#include <linux/skbuff.h> -#include <linux/errqueue.h> -#include <linux/udp.h> -#include <linux/in.h> -#include <linux/in6.h> -#include <linux/icmp.h> -#include <linux/gfp.h> -#include <net/sock.h> -#include <net/af_rxrpc.h> -#include <net/ip.h> -#include <net/udp.h> -#include <net/net_namespace.h> #include "ar-internal.h" -static void rxrpc_proto_abort(const char *why, - struct rxrpc_call *call, rxrpc_seq_t seq) +static void rxrpc_proto_abort(struct rxrpc_call *call, rxrpc_seq_t seq, + enum rxrpc_abort_reason why) { - if (rxrpc_abort_call(why, call, seq, RX_PROTOCOL_ERROR, -EBADMSG)) { - set_bit(RXRPC_CALL_EV_ABORT, &call->events); - rxrpc_queue_call(call); - } + rxrpc_abort_call(call, seq, RX_PROTOCOL_ERROR, -EBADMSG, why); } /* @@ -46,7 +29,7 @@ static void rxrpc_congestion_management(struct rxrpc_call *call, bool resend = false; summary->flight_size = - (call->tx_top - call->tx_hard_ack) - summary->nr_acks; + (call->tx_top - call->acks_hard_ack) - summary->nr_acks; if (test_and_clear_bit(RXRPC_CALL_RETRANS_TIMEOUT, &call->flags)) { summary->retrans_timeo = true; @@ -74,7 +57,7 @@ static void rxrpc_congestion_management(struct rxrpc_call *call, switch (call->cong_mode) { case RXRPC_CALL_SLOW_START: - if (summary->nr_nacks > 0) + if (summary->saw_nacks) goto packet_loss_detected; if (summary->cumulative_acks > 0) cwnd += 1; @@ -85,7 +68,7 @@ static void rxrpc_congestion_management(struct rxrpc_call *call, goto out; case RXRPC_CALL_CONGEST_AVOIDANCE: - if (summary->nr_nacks > 0) + if (summary->saw_nacks) goto packet_loss_detected; /* We analyse the number of packets that get ACK'd per RTT @@ -104,7 +87,7 @@ static void rxrpc_congestion_management(struct rxrpc_call *call, goto out; case RXRPC_CALL_PACKET_LOSS: - if (summary->nr_nacks == 0) + if (!summary->saw_nacks) goto resume_normality; if (summary->new_low_nack) { @@ -142,7 +125,7 @@ static void rxrpc_congestion_management(struct rxrpc_call *call, } else { change = rxrpc_cong_progress; cwnd = call->cong_ssthresh; - if (summary->nr_nacks == 0) + if (!summary->saw_nacks) goto resume_normality; } goto out; @@ -164,13 +147,13 @@ resume_normality: out: cumulative_acks = 0; out_no_clear_ca: - if (cwnd >= RXRPC_RXTX_BUFF_SIZE - 1) - cwnd = RXRPC_RXTX_BUFF_SIZE - 1; + if (cwnd >= RXRPC_TX_MAX_WINDOW) + cwnd = RXRPC_TX_MAX_WINDOW; call->cong_cwnd = cwnd; call->cong_cumul_acks = cumulative_acks; trace_rxrpc_congest(call, summary, acked_serial, change); - if (resend && !test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) - rxrpc_queue_call(call); + if (resend) + rxrpc_resend(call, skb); return; packet_loss_detected: @@ -183,9 +166,8 @@ send_extra_data: /* Send some previously unsent DATA if we have some to advance the ACK * state. */ - if (call->rxtx_annotations[call->tx_top & RXRPC_RXTX_BUFF_MASK] & - RXRPC_TX_ANNO_LAST || - summary->nr_acks != call->tx_top - call->tx_hard_ack) { + if (test_bit(RXRPC_CALL_TX_LAST, &call->flags) || + summary->nr_acks != call->tx_top - call->acks_hard_ack) { call->cong_extra++; wake_up(&call->waitq); } @@ -193,58 +175,71 @@ send_extra_data: } /* + * Degrade the congestion window if we haven't transmitted a packet for >1RTT. + */ +void rxrpc_congestion_degrade(struct rxrpc_call *call) +{ + ktime_t rtt, now; + + if (call->cong_mode != RXRPC_CALL_SLOW_START && + call->cong_mode != RXRPC_CALL_CONGEST_AVOIDANCE) + return; + if (__rxrpc_call_state(call) == RXRPC_CALL_CLIENT_AWAIT_REPLY) + return; + + rtt = ns_to_ktime(call->peer->srtt_us * (1000 / 8)); + now = ktime_get_real(); + if (!ktime_before(ktime_add(call->tx_last_sent, rtt), now)) + return; + + trace_rxrpc_reset_cwnd(call, now); + rxrpc_inc_stat(call->rxnet, stat_tx_data_cwnd_reset); + call->tx_last_sent = now; + call->cong_mode = RXRPC_CALL_SLOW_START; + call->cong_ssthresh = max_t(unsigned int, call->cong_ssthresh, + call->cong_cwnd * 3 / 4); + call->cong_cwnd = max_t(unsigned int, call->cong_cwnd / 2, RXRPC_MIN_CWND); +} + +/* * Apply a hard ACK by advancing the Tx window. */ static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, struct rxrpc_ack_summary *summary) { - struct sk_buff *skb, *list = NULL; + struct rxrpc_txbuf *txb; bool rot_last = false; - int ix; - u8 annotation; - - if (call->acks_lowest_nak == call->tx_hard_ack) { - call->acks_lowest_nak = to; - } else if (before_eq(call->acks_lowest_nak, to)) { - summary->new_low_nack = true; - call->acks_lowest_nak = to; - } - - spin_lock(&call->lock); - - while (before(call->tx_hard_ack, to)) { - call->tx_hard_ack++; - ix = call->tx_hard_ack & RXRPC_RXTX_BUFF_MASK; - skb = call->rxtx_buffer[ix]; - annotation = call->rxtx_annotations[ix]; - rxrpc_see_skb(skb, rxrpc_skb_rotated); - call->rxtx_buffer[ix] = NULL; - call->rxtx_annotations[ix] = 0; - skb->next = list; - list = skb; - if (annotation & RXRPC_TX_ANNO_LAST) { + list_for_each_entry_rcu(txb, &call->tx_buffer, call_link, false) { + if (before_eq(txb->seq, call->acks_hard_ack)) + continue; + summary->nr_rot_new_acks++; + if (test_bit(RXRPC_TXBUF_LAST, &txb->flags)) { set_bit(RXRPC_CALL_TX_LAST, &call->flags); rot_last = true; } - if ((annotation & RXRPC_TX_ANNO_MASK) != RXRPC_TX_ANNO_ACK) - summary->nr_rot_new_acks++; + if (txb->seq == to) + break; } - spin_unlock(&call->lock); + if (rot_last) + set_bit(RXRPC_CALL_TX_ALL_ACKED, &call->flags); - trace_rxrpc_transmit(call, (rot_last ? - rxrpc_transmit_rotate_last : - rxrpc_transmit_rotate)); - wake_up(&call->waitq); + _enter("%x,%x,%x,%d", to, call->acks_hard_ack, call->tx_top, rot_last); - while (list) { - skb = list; - list = skb->next; - skb_mark_not_on_list(skb); - rxrpc_free_skb(skb, rxrpc_skb_freed); + if (call->acks_lowest_nak == call->acks_hard_ack) { + call->acks_lowest_nak = to; + } else if (after(to, call->acks_lowest_nak)) { + summary->new_low_nack = true; + call->acks_lowest_nak = to; } + smp_store_release(&call->acks_hard_ack, to); + + trace_rxrpc_txqueue(call, (rot_last ? + rxrpc_txqueue_rotate_last : + rxrpc_txqueue_rotate)); + wake_up(&call->waitq); return rot_last; } @@ -254,47 +249,34 @@ static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, * This occurs when we get an ACKALL packet, the first DATA packet of a reply, * or a final ACK packet. */ -static bool rxrpc_end_tx_phase(struct rxrpc_call *call, bool reply_begun, - const char *abort_why) +static void rxrpc_end_tx_phase(struct rxrpc_call *call, bool reply_begun, + enum rxrpc_abort_reason abort_why) { - unsigned int state; - ASSERT(test_bit(RXRPC_CALL_TX_LAST, &call->flags)); - write_lock(&call->state_lock); - - state = call->state; - switch (state) { + switch (__rxrpc_call_state(call)) { case RXRPC_CALL_CLIENT_SEND_REQUEST: case RXRPC_CALL_CLIENT_AWAIT_REPLY: - if (reply_begun) - call->state = state = RXRPC_CALL_CLIENT_RECV_REPLY; - else - call->state = state = RXRPC_CALL_CLIENT_AWAIT_REPLY; + if (reply_begun) { + rxrpc_set_call_state(call, RXRPC_CALL_CLIENT_RECV_REPLY); + trace_rxrpc_txqueue(call, rxrpc_txqueue_end); + break; + } + + rxrpc_set_call_state(call, RXRPC_CALL_CLIENT_AWAIT_REPLY); + trace_rxrpc_txqueue(call, rxrpc_txqueue_await_reply); break; case RXRPC_CALL_SERVER_AWAIT_ACK: - __rxrpc_call_completed(call); - state = call->state; + rxrpc_call_completed(call); + trace_rxrpc_txqueue(call, rxrpc_txqueue_end); break; default: - goto bad_state; + kdebug("end_tx %s", rxrpc_call_states[__rxrpc_call_state(call)]); + rxrpc_proto_abort(call, call->tx_top, abort_why); + break; } - - write_unlock(&call->state_lock); - if (state == RXRPC_CALL_CLIENT_AWAIT_REPLY) - trace_rxrpc_transmit(call, rxrpc_transmit_await_reply); - else - trace_rxrpc_transmit(call, rxrpc_transmit_end); - _leave(" = ok"); - return true; - -bad_state: - write_unlock(&call->state_lock); - kdebug("end_tx %s", rxrpc_call_states[call->state]); - rxrpc_proto_abort(abort_why, call, call->tx_top); - return false; } /* @@ -307,312 +289,328 @@ static bool rxrpc_receiving_reply(struct rxrpc_call *call) rxrpc_seq_t top = READ_ONCE(call->tx_top); if (call->ackr_reason) { - spin_lock_bh(&call->lock); - call->ackr_reason = 0; - spin_unlock_bh(&call->lock); now = jiffies; timo = now + MAX_JIFFY_OFFSET; - WRITE_ONCE(call->resend_at, timo); - WRITE_ONCE(call->ack_at, timo); + + WRITE_ONCE(call->delay_ack_at, timo); trace_rxrpc_timer(call, rxrpc_timer_init_for_reply, now); } if (!test_bit(RXRPC_CALL_TX_LAST, &call->flags)) { if (!rxrpc_rotate_tx_window(call, top, &summary)) { - rxrpc_proto_abort("TXL", call, top); + rxrpc_proto_abort(call, top, rxrpc_eproto_early_reply); return false; } } - if (!rxrpc_end_tx_phase(call, true, "ETD")) - return false; - call->tx_phase = false; + + rxrpc_end_tx_phase(call, true, rxrpc_eproto_unexpected_reply); return true; } /* - * Scan a data packet to validate its structure and to work out how many - * subpackets it contains. - * - * A jumbo packet is a collection of consecutive packets glued together with - * little headers between that indicate how to change the initial header for - * each subpacket. - * - * RXRPC_JUMBO_PACKET must be set on all but the last subpacket - and all but - * the last are RXRPC_JUMBO_DATALEN in size. The last subpacket may be of any - * size. + * End the packet reception phase. */ -static bool rxrpc_validate_data(struct sk_buff *skb) +static void rxrpc_end_rx_phase(struct rxrpc_call *call, rxrpc_serial_t serial) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - unsigned int offset = sizeof(struct rxrpc_wire_header); - unsigned int len = skb->len; - u8 flags = sp->hdr.flags; + rxrpc_seq_t whigh = READ_ONCE(call->rx_highest_seq); - for (;;) { - if (flags & RXRPC_REQUEST_ACK) - __set_bit(sp->nr_subpackets, sp->rx_req_ack); - sp->nr_subpackets++; + _enter("%d,%s", call->debug_id, rxrpc_call_states[__rxrpc_call_state(call)]); - if (!(flags & RXRPC_JUMBO_PACKET)) - break; + trace_rxrpc_receive(call, rxrpc_receive_end, 0, whigh); - if (len - offset < RXRPC_JUMBO_SUBPKTLEN) - goto protocol_error; - if (flags & RXRPC_LAST_PACKET) - goto protocol_error; - offset += RXRPC_JUMBO_DATALEN; - if (skb_copy_bits(skb, offset, &flags, 1) < 0) - goto protocol_error; - offset += sizeof(struct rxrpc_jumbo_header); - } + switch (__rxrpc_call_state(call)) { + case RXRPC_CALL_CLIENT_RECV_REPLY: + rxrpc_propose_delay_ACK(call, serial, rxrpc_propose_ack_terminal_ack); + rxrpc_call_completed(call); + break; - if (flags & RXRPC_LAST_PACKET) - sp->rx_flags |= RXRPC_SKB_INCL_LAST; - return true; + case RXRPC_CALL_SERVER_RECV_REQUEST: + rxrpc_set_call_state(call, RXRPC_CALL_SERVER_ACK_REQUEST); + call->expect_req_by = jiffies + MAX_JIFFY_OFFSET; + rxrpc_propose_delay_ACK(call, serial, rxrpc_propose_ack_processing_op); + break; -protocol_error: - return false; + default: + break; + } +} + +static void rxrpc_input_update_ack_window(struct rxrpc_call *call, + rxrpc_seq_t window, rxrpc_seq_t wtop) +{ + atomic64_set_release(&call->ackr_window, ((u64)wtop) << 32 | window); } /* - * Handle reception of a duplicate packet. - * - * We have to take care to avoid an attack here whereby we're given a series of - * jumbograms, each with a sequence number one before the preceding one and - * filled up to maximum UDP size. If they never send us the first packet in - * the sequence, they can cause us to have to hold on to around 2MiB of kernel - * space until the call times out. - * - * We limit the space usage by only accepting three duplicate jumbo packets per - * call. After that, we tell the other side we're no longer accepting jumbos - * (that information is encoded in the ACK packet). + * Push a DATA packet onto the Rx queue. */ -static void rxrpc_input_dup_data(struct rxrpc_call *call, rxrpc_seq_t seq, - bool is_jumbo, bool *_jumbo_bad) +static void rxrpc_input_queue_data(struct rxrpc_call *call, struct sk_buff *skb, + rxrpc_seq_t window, rxrpc_seq_t wtop, + enum rxrpc_receive_trace why) { - /* Discard normal packets that are duplicates. */ - if (is_jumbo) - return; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + bool last = sp->hdr.flags & RXRPC_LAST_PACKET; - /* Skip jumbo subpackets that are duplicates. When we've had three or - * more partially duplicate jumbo packets, we refuse to take any more - * jumbos for this call. - */ - if (!*_jumbo_bad) { - call->nr_jumbo_bad++; - *_jumbo_bad = true; - } + __skb_queue_tail(&call->recvmsg_queue, skb); + rxrpc_input_update_ack_window(call, window, wtop); + trace_rxrpc_receive(call, last ? why + 1 : why, sp->hdr.serial, sp->hdr.seq); + if (last) + rxrpc_end_rx_phase(call, sp->hdr.serial); } /* - * Process a DATA packet, adding the packet to the Rx ring. The caller's - * packet ref must be passed on or discarded. + * Process a DATA packet. */ -static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb) +static void rxrpc_input_data_one(struct rxrpc_call *call, struct sk_buff *skb, + bool *_notify) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - enum rxrpc_call_state state; - unsigned int j, nr_subpackets, nr_unacked = 0; - rxrpc_serial_t serial = sp->hdr.serial, ack_serial = serial; - rxrpc_seq_t seq0 = sp->hdr.seq, hard_ack; - bool immediate_ack = false, jumbo_bad = false; - u8 ack = 0; - - _enter("{%u,%u},{%u,%u}", - call->rx_hard_ack, call->rx_top, skb->len, seq0); - - _proto("Rx DATA %%%u { #%u f=%02x n=%u }", - sp->hdr.serial, seq0, sp->hdr.flags, sp->nr_subpackets); - - state = READ_ONCE(call->state); - if (state >= RXRPC_CALL_COMPLETE) { - rxrpc_free_skb(skb, rxrpc_skb_freed); - return; + struct sk_buff *oos; + rxrpc_serial_t serial = sp->hdr.serial; + u64 win = atomic64_read(&call->ackr_window); + rxrpc_seq_t window = lower_32_bits(win); + rxrpc_seq_t wtop = upper_32_bits(win); + rxrpc_seq_t wlimit = window + call->rx_winsize - 1; + rxrpc_seq_t seq = sp->hdr.seq; + bool last = sp->hdr.flags & RXRPC_LAST_PACKET; + int ack_reason = -1; + + rxrpc_inc_stat(call->rxnet, stat_rx_data); + if (sp->hdr.flags & RXRPC_REQUEST_ACK) + rxrpc_inc_stat(call->rxnet, stat_rx_data_reqack); + if (sp->hdr.flags & RXRPC_JUMBO_PACKET) + rxrpc_inc_stat(call->rxnet, stat_rx_data_jumbo); + + if (last) { + if (test_and_set_bit(RXRPC_CALL_RX_LAST, &call->flags) && + seq + 1 != wtop) + return rxrpc_proto_abort(call, seq, rxrpc_eproto_different_last); + } else { + if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && + after_eq(seq, wtop)) { + pr_warn("Packet beyond last: c=%x q=%x window=%x-%x wlimit=%x\n", + call->debug_id, seq, window, wtop, wlimit); + return rxrpc_proto_abort(call, seq, rxrpc_eproto_data_after_last); + } } - if (state == RXRPC_CALL_SERVER_RECV_REQUEST) { - unsigned long timo = READ_ONCE(call->next_req_timo); - unsigned long now, expect_req_by; + if (after(seq, call->rx_highest_seq)) + call->rx_highest_seq = seq; - if (timo) { - now = jiffies; - expect_req_by = now + timo; - WRITE_ONCE(call->expect_req_by, expect_req_by); - rxrpc_reduce_call_timer(call, expect_req_by, now, - rxrpc_timer_set_for_idle); - } + trace_rxrpc_rx_data(call->debug_id, seq, serial, sp->hdr.flags); + + if (before(seq, window)) { + ack_reason = RXRPC_ACK_DUPLICATE; + goto send_ack; + } + if (after(seq, wlimit)) { + ack_reason = RXRPC_ACK_EXCEEDS_WINDOW; + goto send_ack; } - spin_lock(&call->input_lock); + /* Queue the packet. */ + if (seq == window) { + rxrpc_seq_t reset_from; + bool reset_sack = false; - /* Received data implicitly ACKs all of the request packets we sent - * when we're acting as a client. - */ - if ((state == RXRPC_CALL_CLIENT_SEND_REQUEST || - state == RXRPC_CALL_CLIENT_AWAIT_REPLY) && - !rxrpc_receiving_reply(call)) - goto unlock; - - hard_ack = READ_ONCE(call->rx_hard_ack); - - nr_subpackets = sp->nr_subpackets; - if (nr_subpackets > 1) { - if (call->nr_jumbo_bad > 3) { - ack = RXRPC_ACK_NOSPACE; - ack_serial = serial; - goto ack; - } - } + if (sp->hdr.flags & RXRPC_REQUEST_ACK) + ack_reason = RXRPC_ACK_REQUESTED; + /* Send an immediate ACK if we fill in a hole */ + else if (!skb_queue_empty(&call->rx_oos_queue)) + ack_reason = RXRPC_ACK_DELAY; + else + atomic_inc_return(&call->ackr_nr_unacked); - for (j = 0; j < nr_subpackets; j++) { - rxrpc_serial_t serial = sp->hdr.serial + j; - rxrpc_seq_t seq = seq0 + j; - unsigned int ix = seq & RXRPC_RXTX_BUFF_MASK; - bool terminal = (j == nr_subpackets - 1); - bool last = terminal && (sp->rx_flags & RXRPC_SKB_INCL_LAST); - u8 flags, annotation = j; - - _proto("Rx DATA+%u %%%u { #%x t=%u l=%u }", - j, serial, seq, terminal, last); - - if (last) { - if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && - seq != call->rx_top) { - rxrpc_proto_abort("LSN", call, seq); - goto unlock; - } - } else { - if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && - after_eq(seq, call->rx_top)) { - rxrpc_proto_abort("LSA", call, seq); - goto unlock; + window++; + if (after(window, wtop)) + wtop = window; + + rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg); + + spin_lock(&call->recvmsg_queue.lock); + rxrpc_input_queue_data(call, skb, window, wtop, rxrpc_receive_queue); + *_notify = true; + + while ((oos = skb_peek(&call->rx_oos_queue))) { + struct rxrpc_skb_priv *osp = rxrpc_skb(oos); + + if (after(osp->hdr.seq, window)) + break; + + __skb_unlink(oos, &call->rx_oos_queue); + last = osp->hdr.flags & RXRPC_LAST_PACKET; + seq = osp->hdr.seq; + if (!reset_sack) { + reset_from = seq; + reset_sack = true; } - } - flags = 0; - if (last) - flags |= RXRPC_LAST_PACKET; - if (!terminal) - flags |= RXRPC_JUMBO_PACKET; - if (test_bit(j, sp->rx_req_ack)) - flags |= RXRPC_REQUEST_ACK; - trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation); - - if (before_eq(seq, hard_ack)) { - ack = RXRPC_ACK_DUPLICATE; - ack_serial = serial; - continue; + window++; + rxrpc_input_queue_data(call, oos, window, wtop, + rxrpc_receive_queue_oos); } - if (call->rxtx_buffer[ix]) { - rxrpc_input_dup_data(call, seq, nr_subpackets > 1, - &jumbo_bad); - if (ack != RXRPC_ACK_DUPLICATE) { - ack = RXRPC_ACK_DUPLICATE; - ack_serial = serial; - } - immediate_ack = true; - continue; + spin_unlock(&call->recvmsg_queue.lock); + + if (reset_sack) { + do { + call->ackr_sack_table[reset_from % RXRPC_SACK_SIZE] = 0; + } while (reset_from++, before(reset_from, window)); } + } else { + bool keep = false; - if (after(seq, hard_ack + call->rx_winsize)) { - ack = RXRPC_ACK_EXCEEDS_WINDOW; - ack_serial = serial; - if (flags & RXRPC_JUMBO_PACKET) { - if (!jumbo_bad) { - call->nr_jumbo_bad++; - jumbo_bad = true; - } - } + ack_reason = RXRPC_ACK_OUT_OF_SEQUENCE; - goto ack; + if (!call->ackr_sack_table[seq % RXRPC_SACK_SIZE]) { + call->ackr_sack_table[seq % RXRPC_SACK_SIZE] = 1; + keep = 1; } - if (flags & RXRPC_REQUEST_ACK && !ack) { - ack = RXRPC_ACK_REQUESTED; - ack_serial = serial; + if (after(seq + 1, wtop)) { + wtop = seq + 1; + rxrpc_input_update_ack_window(call, window, wtop); } - if (after(seq0, call->ackr_highest_seq)) - call->ackr_highest_seq = seq0; + if (!keep) { + ack_reason = RXRPC_ACK_DUPLICATE; + goto send_ack; + } - /* Queue the packet. We use a couple of memory barriers here as need - * to make sure that rx_top is perceived to be set after the buffer - * pointer and that the buffer pointer is set after the annotation and - * the skb data. - * - * Barriers against rxrpc_recvmsg_data() and rxrpc_rotate_rx_window() - * and also rxrpc_fill_out_ack(). - */ - if (!terminal) - rxrpc_get_skb(skb, rxrpc_skb_got); - call->rxtx_annotations[ix] = annotation; - smp_wmb(); - call->rxtx_buffer[ix] = skb; - if (after(seq, call->rx_top)) { - smp_store_release(&call->rx_top, seq); - } else if (before(seq, call->rx_top)) { - /* Send an immediate ACK if we fill in a hole */ - if (!ack) { - ack = RXRPC_ACK_DELAY; - ack_serial = serial; + skb_queue_walk(&call->rx_oos_queue, oos) { + struct rxrpc_skb_priv *osp = rxrpc_skb(oos); + + if (after(osp->hdr.seq, seq)) { + rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg_oos); + __skb_queue_before(&call->rx_oos_queue, oos, skb); + goto oos_queued; } - immediate_ack = true; } - if (terminal) { - /* From this point on, we're not allowed to touch the - * packet any longer as its ref now belongs to the Rx - * ring. - */ - skb = NULL; - sp = NULL; - } + rxrpc_get_skb(skb, rxrpc_skb_get_to_recvmsg_oos); + __skb_queue_tail(&call->rx_oos_queue, skb); + oos_queued: + trace_rxrpc_receive(call, last ? rxrpc_receive_oos_last : rxrpc_receive_oos, + sp->hdr.serial, sp->hdr.seq); + } - nr_unacked++; +send_ack: + if (ack_reason >= 0) + rxrpc_send_ACK(call, ack_reason, serial, + rxrpc_propose_ack_input_data); + else + rxrpc_propose_delay_ACK(call, serial, + rxrpc_propose_ack_input_data); +} - if (last) { - set_bit(RXRPC_CALL_RX_LAST, &call->flags); - if (!ack) { - ack = RXRPC_ACK_DELAY; - ack_serial = serial; - } - trace_rxrpc_receive(call, rxrpc_receive_queue_last, serial, seq); - } else { - trace_rxrpc_receive(call, rxrpc_receive_queue, serial, seq); +/* + * Split a jumbo packet and file the bits separately. + */ +static bool rxrpc_input_split_jumbo(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_jumbo_header jhdr; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb), *jsp; + struct sk_buff *jskb; + unsigned int offset = sizeof(struct rxrpc_wire_header); + unsigned int len = skb->len - offset; + bool notify = false; + + while (sp->hdr.flags & RXRPC_JUMBO_PACKET) { + if (len < RXRPC_JUMBO_SUBPKTLEN) + goto protocol_error; + if (sp->hdr.flags & RXRPC_LAST_PACKET) + goto protocol_error; + if (skb_copy_bits(skb, offset + RXRPC_JUMBO_DATALEN, + &jhdr, sizeof(jhdr)) < 0) + goto protocol_error; + + jskb = skb_clone(skb, GFP_NOFS); + if (!jskb) { + kdebug("couldn't clone"); + return false; } + rxrpc_new_skb(jskb, rxrpc_skb_new_jumbo_subpacket); + jsp = rxrpc_skb(jskb); + jsp->offset = offset; + jsp->len = RXRPC_JUMBO_DATALEN; + rxrpc_input_data_one(call, jskb, ¬ify); + rxrpc_free_skb(jskb, rxrpc_skb_put_jumbo_subpacket); + + sp->hdr.flags = jhdr.flags; + sp->hdr._rsvd = ntohs(jhdr._rsvd); + sp->hdr.seq++; + sp->hdr.serial++; + offset += RXRPC_JUMBO_SUBPKTLEN; + len -= RXRPC_JUMBO_SUBPKTLEN; + } - if (after_eq(seq, call->rx_expect_next)) { - if (after(seq, call->rx_expect_next)) { - _net("OOS %u > %u", seq, call->rx_expect_next); - ack = RXRPC_ACK_OUT_OF_SEQUENCE; - ack_serial = serial; - } - call->rx_expect_next = seq + 1; + sp->offset = offset; + sp->len = len; + rxrpc_input_data_one(call, skb, ¬ify); + if (notify) { + trace_rxrpc_notify_socket(call->debug_id, sp->hdr.serial); + rxrpc_notify_socket(call); + } + return true; + +protocol_error: + return false; +} + +/* + * Process a DATA packet, adding the packet to the Rx ring. The caller's + * packet ref must be passed on or discarded. + */ +static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + rxrpc_serial_t serial = sp->hdr.serial; + rxrpc_seq_t seq0 = sp->hdr.seq; + + _enter("{%llx,%x},{%u,%x}", + atomic64_read(&call->ackr_window), call->rx_highest_seq, + skb->len, seq0); + + if (__rxrpc_call_is_complete(call)) + return; + + switch (__rxrpc_call_state(call)) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + case RXRPC_CALL_CLIENT_AWAIT_REPLY: + /* Received data implicitly ACKs all of the request + * packets we sent when we're acting as a client. + */ + if (!rxrpc_receiving_reply(call)) + goto out_notify; + break; + + case RXRPC_CALL_SERVER_RECV_REQUEST: { + unsigned long timo = READ_ONCE(call->next_req_timo); + unsigned long now, expect_req_by; + + if (timo) { + now = jiffies; + expect_req_by = now + timo; + WRITE_ONCE(call->expect_req_by, expect_req_by); + rxrpc_reduce_call_timer(call, expect_req_by, now, + rxrpc_timer_set_for_idle); } - if (!ack) - ack_serial = serial; + break; } -ack: - if (atomic_add_return(nr_unacked, &call->ackr_nr_unacked) > 2 && !ack) - ack = RXRPC_ACK_IDLE; + default: + break; + } - if (ack) - rxrpc_propose_ACK(call, ack, ack_serial, - immediate_ack, true, - rxrpc_propose_ack_input_data); - else - rxrpc_propose_ACK(call, RXRPC_ACK_DELAY, serial, - false, true, - rxrpc_propose_ack_input_data); + if (!rxrpc_input_split_jumbo(call, skb)) { + rxrpc_proto_abort(call, sp->hdr.seq, rxrpc_badmsg_bad_jumbo); + goto out_notify; + } + skb = NULL; +out_notify: trace_rxrpc_notify_socket(call->debug_id, serial); rxrpc_notify_socket(call); - -unlock: - spin_unlock(&call->input_lock); - rxrpc_free_skb(skb, rxrpc_skb_freed); _leave(" [queued]"); } @@ -671,55 +669,6 @@ static void rxrpc_complete_rtt_probe(struct rxrpc_call *call, } /* - * Process the response to a ping that we sent to find out if we lost an ACK. - * - * If we got back a ping response that indicates a lower tx_top than what we - * had at the time of the ping transmission, we adjudge all the DATA packets - * sent between the response tx_top and the ping-time tx_top to have been lost. - */ -static void rxrpc_input_check_for_lost_ack(struct rxrpc_call *call) -{ - rxrpc_seq_t top, bottom, seq; - bool resend = false; - - spin_lock_bh(&call->lock); - - bottom = call->tx_hard_ack + 1; - top = call->acks_lost_top; - if (before(bottom, top)) { - for (seq = bottom; before_eq(seq, top); seq++) { - int ix = seq & RXRPC_RXTX_BUFF_MASK; - u8 annotation = call->rxtx_annotations[ix]; - u8 anno_type = annotation & RXRPC_TX_ANNO_MASK; - - if (anno_type != RXRPC_TX_ANNO_UNACK) - continue; - annotation &= ~RXRPC_TX_ANNO_MASK; - annotation |= RXRPC_TX_ANNO_RETRANS; - call->rxtx_annotations[ix] = annotation; - resend = true; - } - } - - spin_unlock_bh(&call->lock); - - if (resend && !test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) - rxrpc_queue_call(call); -} - -/* - * Process a ping response. - */ -static void rxrpc_input_ping_response(struct rxrpc_call *call, - ktime_t resp_time, - rxrpc_serial_t acked_serial, - rxrpc_serial_t ack_serial) -{ - if (acked_serial == call->acks_lost_ping) - rxrpc_input_check_for_lost_ack(call); -} - -/* * Process the extra information that may be appended to an ACK packet */ static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb, @@ -731,13 +680,8 @@ static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb, bool wake = false; u32 rwind = ntohl(ackinfo->rwind); - _proto("Rx ACK %%%u Info { rx=%u max=%u rwin=%u jm=%u }", - sp->hdr.serial, - ntohl(ackinfo->rxMTU), ntohl(ackinfo->maxMTU), - rwind, ntohl(ackinfo->jumbo_max)); - - if (rwind > RXRPC_RXTX_BUFF_SIZE - 1) - rwind = RXRPC_RXTX_BUFF_SIZE - 1; + if (rwind > RXRPC_TX_MAX_WINDOW) + rwind = RXRPC_TX_MAX_WINDOW; if (call->tx_winsize != rwind) { if (rwind > call->tx_winsize) wake = true; @@ -752,11 +696,10 @@ static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb, peer = call->peer; if (mtu < peer->maxdata) { - spin_lock_bh(&peer->lock); + spin_lock(&peer->lock); peer->maxdata = mtu; peer->mtu = mtu + peer->hdrsize; - spin_unlock_bh(&peer->lock); - _net("Net MTU %u (maxdata %u)", peer->mtu, peer->maxdata); + spin_unlock(&peer->lock); } if (wake) @@ -776,40 +719,19 @@ static void rxrpc_input_soft_acks(struct rxrpc_call *call, u8 *acks, rxrpc_seq_t seq, int nr_acks, struct rxrpc_ack_summary *summary) { - int ix; - u8 annotation, anno_type; - - for (; nr_acks > 0; nr_acks--, seq++) { - ix = seq & RXRPC_RXTX_BUFF_MASK; - annotation = call->rxtx_annotations[ix]; - anno_type = annotation & RXRPC_TX_ANNO_MASK; - annotation &= ~RXRPC_TX_ANNO_MASK; - switch (*acks++) { - case RXRPC_ACK_TYPE_ACK: + unsigned int i; + + for (i = 0; i < nr_acks; i++) { + if (acks[i] == RXRPC_ACK_TYPE_ACK) { summary->nr_acks++; - if (anno_type == RXRPC_TX_ANNO_ACK) - continue; summary->nr_new_acks++; - call->rxtx_annotations[ix] = - RXRPC_TX_ANNO_ACK | annotation; - break; - case RXRPC_ACK_TYPE_NACK: - if (!summary->nr_nacks && - call->acks_lowest_nak != seq) { - call->acks_lowest_nak = seq; + } else { + if (!summary->saw_nacks && + call->acks_lowest_nak != seq + i) { + call->acks_lowest_nak = seq + i; summary->new_low_nack = true; } - summary->nr_nacks++; - if (anno_type == RXRPC_TX_ANNO_NAK) - continue; - summary->nr_new_nacks++; - if (anno_type == RXRPC_TX_ANNO_RETRANS) - continue; - call->rxtx_annotations[ix] = - RXRPC_TX_ANNO_NAK | annotation; - break; - default: - return rxrpc_proto_abort("SFT", call, 0); + summary->saw_nacks = true; } } } @@ -851,12 +773,9 @@ static bool rxrpc_is_ack_valid(struct rxrpc_call *call, static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) { struct rxrpc_ack_summary summary = { 0 }; + struct rxrpc_ackpacket ack; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - union { - struct rxrpc_ackpacket ack; - struct rxrpc_ackinfo info; - u8 acks[RXRPC_MAXACKS]; - } buf; + struct rxrpc_ackinfo info; rxrpc_serial_t ack_serial, acked_serial; rxrpc_seq_t first_soft_ack, hard_ack, prev_pkt; int nr_acks, offset, ioffset; @@ -864,29 +783,26 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) _enter(""); offset = sizeof(struct rxrpc_wire_header); - if (skb_copy_bits(skb, offset, &buf.ack, sizeof(buf.ack)) < 0) { - _debug("extraction failure"); - return rxrpc_proto_abort("XAK", call, 0); - } - offset += sizeof(buf.ack); + if (skb_copy_bits(skb, offset, &ack, sizeof(ack)) < 0) + return rxrpc_proto_abort(call, 0, rxrpc_badmsg_short_ack); + offset += sizeof(ack); ack_serial = sp->hdr.serial; - acked_serial = ntohl(buf.ack.serial); - first_soft_ack = ntohl(buf.ack.firstPacket); - prev_pkt = ntohl(buf.ack.previousPacket); + acked_serial = ntohl(ack.serial); + first_soft_ack = ntohl(ack.firstPacket); + prev_pkt = ntohl(ack.previousPacket); hard_ack = first_soft_ack - 1; - nr_acks = buf.ack.nAcks; - summary.ack_reason = (buf.ack.reason < RXRPC_ACK__INVALID ? - buf.ack.reason : RXRPC_ACK__INVALID); + nr_acks = ack.nAcks; + summary.ack_reason = (ack.reason < RXRPC_ACK__INVALID ? + ack.reason : RXRPC_ACK__INVALID); trace_rxrpc_rx_ack(call, ack_serial, acked_serial, first_soft_ack, prev_pkt, summary.ack_reason, nr_acks); + rxrpc_inc_stat(call->rxnet, stat_rx_acks[ack.reason]); - switch (buf.ack.reason) { + switch (ack.reason) { case RXRPC_ACK_PING_RESPONSE: - rxrpc_input_ping_response(call, skb->tstamp, acked_serial, - ack_serial); rxrpc_complete_rtt_probe(call, skb->tstamp, acked_serial, ack_serial, rxrpc_rtt_rx_ping_response); break; @@ -901,22 +817,19 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) break; } - if (buf.ack.reason == RXRPC_ACK_PING) { - _proto("Rx ACK %%%u PING Request", ack_serial); - rxrpc_propose_ACK(call, RXRPC_ACK_PING_RESPONSE, - ack_serial, true, true, - rxrpc_propose_ack_respond_to_ping); + if (ack.reason == RXRPC_ACK_PING) { + rxrpc_send_ACK(call, RXRPC_ACK_PING_RESPONSE, ack_serial, + rxrpc_propose_ack_respond_to_ping); } else if (sp->hdr.flags & RXRPC_REQUEST_ACK) { - rxrpc_propose_ACK(call, RXRPC_ACK_REQUESTED, - ack_serial, true, true, - rxrpc_propose_ack_respond_to_ack); + rxrpc_send_ACK(call, RXRPC_ACK_REQUESTED, ack_serial, + rxrpc_propose_ack_respond_to_ack); } /* If we get an EXCEEDS_WINDOW ACK from the server, it probably * indicates that the client address changed due to NAT. The server * lost the call because it switched to a different peer. */ - if (unlikely(buf.ack.reason == RXRPC_ACK_EXCEEDS_WINDOW) && + if (unlikely(ack.reason == RXRPC_ACK_EXCEEDS_WINDOW) && first_soft_ack == 1 && prev_pkt == 0 && rxrpc_is_client_call(call)) { @@ -929,10 +842,10 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) * indicate a change of address. However, we can retransmit the call * if we still have it buffered to the beginning. */ - if (unlikely(buf.ack.reason == RXRPC_ACK_OUT_OF_SEQUENCE) && + if (unlikely(ack.reason == RXRPC_ACK_OUT_OF_SEQUENCE) && first_soft_ack == 1 && prev_pkt == 0 && - call->tx_hard_ack == 0 && + call->acks_hard_ack == 0 && rxrpc_is_client_call(call)) { rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, 0, -ENETRESET); @@ -947,83 +860,73 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) return; } - buf.info.rxMTU = 0; + info.rxMTU = 0; ioffset = offset + nr_acks + 3; - if (skb->len >= ioffset + sizeof(buf.info) && - skb_copy_bits(skb, ioffset, &buf.info, sizeof(buf.info)) < 0) - return rxrpc_proto_abort("XAI", call, 0); + if (skb->len >= ioffset + sizeof(info) && + skb_copy_bits(skb, ioffset, &info, sizeof(info)) < 0) + return rxrpc_proto_abort(call, 0, rxrpc_badmsg_short_ack_info); - spin_lock(&call->input_lock); + if (nr_acks > 0) + skb_condense(skb); - /* Discard any out-of-order or duplicate ACKs (inside lock). */ - if (!rxrpc_is_ack_valid(call, first_soft_ack, prev_pkt)) { - trace_rxrpc_rx_discard_ack(call->debug_id, ack_serial, - first_soft_ack, call->acks_first_seq, - prev_pkt, call->acks_prev_seq); - goto out; - } call->acks_latest_ts = skb->tstamp; - call->acks_first_seq = first_soft_ack; call->acks_prev_seq = prev_pkt; + switch (ack.reason) { + case RXRPC_ACK_PING: + break; + default: + if (after(acked_serial, call->acks_highest_serial)) + call->acks_highest_serial = acked_serial; + break; + } + /* Parse rwind and mtu sizes if provided. */ - if (buf.info.rxMTU) - rxrpc_input_ackinfo(call, skb, &buf.info); + if (info.rxMTU) + rxrpc_input_ackinfo(call, skb, &info); - if (first_soft_ack == 0) { - rxrpc_proto_abort("AK0", call, 0); - goto out; - } + if (first_soft_ack == 0) + return rxrpc_proto_abort(call, 0, rxrpc_eproto_ackr_zero); /* Ignore ACKs unless we are or have just been transmitting. */ - switch (READ_ONCE(call->state)) { + switch (__rxrpc_call_state(call)) { case RXRPC_CALL_CLIENT_SEND_REQUEST: case RXRPC_CALL_CLIENT_AWAIT_REPLY: case RXRPC_CALL_SERVER_SEND_REPLY: case RXRPC_CALL_SERVER_AWAIT_ACK: break; default: - goto out; + return; } - if (before(hard_ack, call->tx_hard_ack) || - after(hard_ack, call->tx_top)) { - rxrpc_proto_abort("AKW", call, 0); - goto out; - } - if (nr_acks > call->tx_top - hard_ack) { - rxrpc_proto_abort("AKN", call, 0); - goto out; - } + if (before(hard_ack, call->acks_hard_ack) || + after(hard_ack, call->tx_top)) + return rxrpc_proto_abort(call, 0, rxrpc_eproto_ackr_outside_window); + if (nr_acks > call->tx_top - hard_ack) + return rxrpc_proto_abort(call, 0, rxrpc_eproto_ackr_sack_overflow); - if (after(hard_ack, call->tx_hard_ack)) { + if (after(hard_ack, call->acks_hard_ack)) { if (rxrpc_rotate_tx_window(call, hard_ack, &summary)) { - rxrpc_end_tx_phase(call, false, "ETA"); - goto out; + rxrpc_end_tx_phase(call, false, rxrpc_eproto_unexpected_ack); + return; } } if (nr_acks > 0) { - if (skb_copy_bits(skb, offset, buf.acks, nr_acks) < 0) { - rxrpc_proto_abort("XSA", call, 0); - goto out; - } - rxrpc_input_soft_acks(call, buf.acks, first_soft_ack, nr_acks, - &summary); + if (offset > (int)skb->len - nr_acks) + return rxrpc_proto_abort(call, 0, rxrpc_eproto_ackr_short_sack); + rxrpc_input_soft_acks(call, skb->data + offset, first_soft_ack, + nr_acks, &summary); } - if (call->rxtx_annotations[call->tx_top & RXRPC_RXTX_BUFF_MASK] & - RXRPC_TX_ANNO_LAST && + if (test_bit(RXRPC_CALL_TX_LAST, &call->flags) && summary.nr_acks == call->tx_top - hard_ack && rxrpc_is_client_call(call)) - rxrpc_propose_ACK(call, RXRPC_ACK_PING, ack_serial, - false, true, - rxrpc_propose_ack_ping_for_lost_reply); + rxrpc_propose_ping(call, ack_serial, + rxrpc_propose_ack_ping_for_lost_reply); rxrpc_congestion_management(call, skb, &summary, acked_serial); -out: - spin_unlock(&call->input_lock); } /* @@ -1032,16 +935,9 @@ out: static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb) { struct rxrpc_ack_summary summary = { 0 }; - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - - _proto("Rx ACKALL %%%u", sp->hdr.serial); - - spin_lock(&call->input_lock); if (rxrpc_rotate_tx_window(call, call->tx_top, &summary)) - rxrpc_end_tx_phase(call, false, "ETL"); - - spin_unlock(&call->input_lock); + rxrpc_end_tx_phase(call, false, rxrpc_eproto_unexpected_ackall); } /* @@ -1050,35 +946,30 @@ static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb) static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - __be32 wtmp; - u32 abort_code = RX_CALL_DEAD; - _enter(""); - - if (skb->len >= 4 && - skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), - &wtmp, sizeof(wtmp)) >= 0) - abort_code = ntohl(wtmp); - - trace_rxrpc_rx_abort(call, sp->hdr.serial, abort_code); - - _proto("Rx ABORT %%%u { %x }", sp->hdr.serial, abort_code); + trace_rxrpc_rx_abort(call, sp->hdr.serial, skb->priority); rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, - abort_code, -ECONNABORTED); + skb->priority, -ECONNABORTED); } /* * Process an incoming call packet. */ -static void rxrpc_input_call_packet(struct rxrpc_call *call, - struct sk_buff *skb) +void rxrpc_input_call_packet(struct rxrpc_call *call, struct sk_buff *skb) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); unsigned long timo; _enter("%p,%p", call, skb); + if (sp->hdr.serviceId != call->dest_srx.srx_service) + call->dest_srx.srx_service = sp->hdr.serviceId; + if ((int)sp->hdr.serial - (int)call->rx_serial > 0) + call->rx_serial = sp->hdr.serial; + if (!test_bit(RXRPC_CALL_RX_HEARD, &call->flags)) + set_bit(RXRPC_CALL_RX_HEARD, &call->flags); + timo = READ_ONCE(call->next_rx_timo); if (timo) { unsigned long now = jiffies, expect_rx_by; @@ -1091,37 +982,27 @@ static void rxrpc_input_call_packet(struct rxrpc_call *call, switch (sp->hdr.type) { case RXRPC_PACKET_TYPE_DATA: - rxrpc_input_data(call, skb); - goto no_free; + return rxrpc_input_data(call, skb); case RXRPC_PACKET_TYPE_ACK: - rxrpc_input_ack(call, skb); - break; + return rxrpc_input_ack(call, skb); case RXRPC_PACKET_TYPE_BUSY: - _proto("Rx BUSY %%%u", sp->hdr.serial); - /* Just ignore BUSY packets from the server; the retry and * lifespan timers will take care of business. BUSY packets * from the client don't make sense. */ - break; + return; case RXRPC_PACKET_TYPE_ABORT: - rxrpc_input_abort(call, skb); - break; + return rxrpc_input_abort(call, skb); case RXRPC_PACKET_TYPE_ACKALL: - rxrpc_input_ackall(call, skb); - break; + return rxrpc_input_ackall(call, skb); default: break; } - - rxrpc_free_skb(skb, rxrpc_skb_freed); -no_free: - _leave(""); } /* @@ -1130,373 +1011,20 @@ no_free: * * TODO: If callNumber > call_id + 1, renegotiate security. */ -static void rxrpc_input_implicit_end_call(struct rxrpc_sock *rx, - struct rxrpc_connection *conn, - struct rxrpc_call *call) +void rxrpc_implicit_end_call(struct rxrpc_call *call, struct sk_buff *skb) { - switch (READ_ONCE(call->state)) { + switch (__rxrpc_call_state(call)) { case RXRPC_CALL_SERVER_AWAIT_ACK: rxrpc_call_completed(call); fallthrough; case RXRPC_CALL_COMPLETE: break; default: - if (rxrpc_abort_call("IMP", call, 0, RX_CALL_DEAD, -ESHUTDOWN)) { - set_bit(RXRPC_CALL_EV_ABORT, &call->events); - rxrpc_queue_call(call); - } + rxrpc_abort_call(call, 0, RX_CALL_DEAD, -ESHUTDOWN, + rxrpc_eproto_improper_term); trace_rxrpc_improper_term(call); break; } - spin_lock(&rx->incoming_lock); - __rxrpc_disconnect_call(conn, call); - spin_unlock(&rx->incoming_lock); -} - -/* - * post connection-level events to the connection - * - this includes challenges, responses, some aborts and call terminal packet - * retransmission. - */ -static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn, - struct sk_buff *skb) -{ - _enter("%p,%p", conn, skb); - - skb_queue_tail(&conn->rx_queue, skb); - rxrpc_queue_conn(conn); -} - -/* - * post endpoint-level events to the local endpoint - * - this includes debug and version messages - */ -static void rxrpc_post_packet_to_local(struct rxrpc_local *local, - struct sk_buff *skb) -{ - _enter("%p,%p", local, skb); - - if (rxrpc_get_local_maybe(local)) { - skb_queue_tail(&local->event_queue, skb); - rxrpc_queue_local(local); - } else { - rxrpc_free_skb(skb, rxrpc_skb_freed); - } -} - -/* - * put a packet up for transport-level abort - */ -static void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb) -{ - if (rxrpc_get_local_maybe(local)) { - skb_queue_tail(&local->reject_queue, skb); - rxrpc_queue_local(local); - } else { - rxrpc_free_skb(skb, rxrpc_skb_freed); - } -} - -/* - * Extract the wire header from a packet and translate the byte order. - */ -static noinline -int rxrpc_extract_header(struct rxrpc_skb_priv *sp, struct sk_buff *skb) -{ - struct rxrpc_wire_header whdr; - - /* dig out the RxRPC connection details */ - if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0) { - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, - tracepoint_string("bad_hdr")); - return -EBADMSG; - } - - memset(sp, 0, sizeof(*sp)); - sp->hdr.epoch = ntohl(whdr.epoch); - sp->hdr.cid = ntohl(whdr.cid); - sp->hdr.callNumber = ntohl(whdr.callNumber); - sp->hdr.seq = ntohl(whdr.seq); - sp->hdr.serial = ntohl(whdr.serial); - sp->hdr.flags = whdr.flags; - sp->hdr.type = whdr.type; - sp->hdr.userStatus = whdr.userStatus; - sp->hdr.securityIndex = whdr.securityIndex; - sp->hdr._rsvd = ntohs(whdr._rsvd); - sp->hdr.serviceId = ntohs(whdr.serviceId); - return 0; -} - -/* - * handle data received on the local endpoint - * - may be called in interrupt context - * - * [!] Note that as this is called from the encap_rcv hook, the socket is not - * held locked by the caller and nothing prevents sk_user_data on the UDP from - * being cleared in the middle of processing this function. - * - * Called with the RCU read lock held from the IP layer via UDP. - */ -int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb) -{ - struct rxrpc_local *local = rcu_dereference_sk_user_data(udp_sk); - struct rxrpc_connection *conn; - struct rxrpc_channel *chan; - struct rxrpc_call *call = NULL; - struct rxrpc_skb_priv *sp; - struct rxrpc_peer *peer = NULL; - struct rxrpc_sock *rx = NULL; - unsigned int channel; - - _enter("%p", udp_sk); - - if (unlikely(!local)) { - kfree_skb(skb); - return 0; - } - if (skb->tstamp == 0) - skb->tstamp = ktime_get_real(); - - rxrpc_new_skb(skb, rxrpc_skb_received); - - skb_pull(skb, sizeof(struct udphdr)); - - /* The UDP protocol already released all skb resources; - * we are free to add our own data there. - */ - sp = rxrpc_skb(skb); - - /* dig out the RxRPC connection details */ - if (rxrpc_extract_header(sp, skb) < 0) - goto bad_message; - - if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) { - static int lose; - if ((lose++ & 7) == 7) { - trace_rxrpc_rx_lose(sp); - rxrpc_free_skb(skb, rxrpc_skb_lost); - return 0; - } - } - - if (skb->tstamp == 0) - skb->tstamp = ktime_get_real(); - trace_rxrpc_rx_packet(sp); - - switch (sp->hdr.type) { - case RXRPC_PACKET_TYPE_VERSION: - if (rxrpc_to_client(sp)) - goto discard; - rxrpc_post_packet_to_local(local, skb); - goto out; - - case RXRPC_PACKET_TYPE_BUSY: - if (rxrpc_to_server(sp)) - goto discard; - fallthrough; - case RXRPC_PACKET_TYPE_ACK: - case RXRPC_PACKET_TYPE_ACKALL: - if (sp->hdr.callNumber == 0) - goto bad_message; - fallthrough; - case RXRPC_PACKET_TYPE_ABORT: - break; - - case RXRPC_PACKET_TYPE_DATA: - if (sp->hdr.callNumber == 0 || - sp->hdr.seq == 0) - goto bad_message; - if (!rxrpc_validate_data(skb)) - goto bad_message; - - /* Unshare the packet so that it can be modified for in-place - * decryption. - */ - if (sp->hdr.securityIndex != 0) { - struct sk_buff *nskb = skb_unshare(skb, GFP_ATOMIC); - if (!nskb) { - rxrpc_eaten_skb(skb, rxrpc_skb_unshared_nomem); - goto out; - } - - if (nskb != skb) { - rxrpc_eaten_skb(skb, rxrpc_skb_received); - skb = nskb; - rxrpc_new_skb(skb, rxrpc_skb_unshared); - sp = rxrpc_skb(skb); - } - } - break; - - case RXRPC_PACKET_TYPE_CHALLENGE: - if (rxrpc_to_server(sp)) - goto discard; - break; - case RXRPC_PACKET_TYPE_RESPONSE: - if (rxrpc_to_client(sp)) - goto discard; - break; - - /* Packet types 9-11 should just be ignored. */ - case RXRPC_PACKET_TYPE_PARAMS: - case RXRPC_PACKET_TYPE_10: - case RXRPC_PACKET_TYPE_11: - goto discard; - - default: - _proto("Rx Bad Packet Type %u", sp->hdr.type); - goto bad_message; - } - - if (sp->hdr.serviceId == 0) - goto bad_message; - - if (rxrpc_to_server(sp)) { - /* Weed out packets to services we're not offering. Packets - * that would begin a call are explicitly rejected and the rest - * are just discarded. - */ - rx = rcu_dereference(local->service); - if (!rx || (sp->hdr.serviceId != rx->srx.srx_service && - sp->hdr.serviceId != rx->second_service)) { - if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA && - sp->hdr.seq == 1) - goto unsupported_service; - goto discard; - } - } - - conn = rxrpc_find_connection_rcu(local, skb, &peer); - if (conn) { - if (sp->hdr.securityIndex != conn->security_ix) - goto wrong_security; - - if (sp->hdr.serviceId != conn->service_id) { - int old_id; - - if (!test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags)) - goto reupgrade; - old_id = cmpxchg(&conn->service_id, conn->params.service_id, - sp->hdr.serviceId); - - if (old_id != conn->params.service_id && - old_id != sp->hdr.serviceId) - goto reupgrade; - } - - if (sp->hdr.callNumber == 0) { - /* Connection-level packet */ - _debug("CONN %p {%d}", conn, conn->debug_id); - rxrpc_post_packet_to_conn(conn, skb); - goto out; - } - - if ((int)sp->hdr.serial - (int)conn->hi_serial > 0) - conn->hi_serial = sp->hdr.serial; - - /* Call-bound packets are routed by connection channel. */ - channel = sp->hdr.cid & RXRPC_CHANNELMASK; - chan = &conn->channels[channel]; - - /* Ignore really old calls */ - if (sp->hdr.callNumber < chan->last_call) - goto discard; - - if (sp->hdr.callNumber == chan->last_call) { - if (chan->call || - sp->hdr.type == RXRPC_PACKET_TYPE_ABORT) - goto discard; - - /* For the previous service call, if completed - * successfully, we discard all further packets. - */ - if (rxrpc_conn_is_service(conn) && - chan->last_type == RXRPC_PACKET_TYPE_ACK) - goto discard; - - /* But otherwise we need to retransmit the final packet - * from data cached in the connection record. - */ - if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA) - trace_rxrpc_rx_data(chan->call_debug_id, - sp->hdr.seq, - sp->hdr.serial, - sp->hdr.flags, 0); - rxrpc_post_packet_to_conn(conn, skb); - goto out; - } - - call = rcu_dereference(chan->call); - - if (sp->hdr.callNumber > chan->call_id) { - if (rxrpc_to_client(sp)) - goto reject_packet; - if (call) - rxrpc_input_implicit_end_call(rx, conn, call); - call = NULL; - } - - if (call) { - if (sp->hdr.serviceId != call->service_id) - call->service_id = sp->hdr.serviceId; - if ((int)sp->hdr.serial - (int)call->rx_serial > 0) - call->rx_serial = sp->hdr.serial; - if (!test_bit(RXRPC_CALL_RX_HEARD, &call->flags)) - set_bit(RXRPC_CALL_RX_HEARD, &call->flags); - } - } - - if (!call || refcount_read(&call->ref) == 0) { - if (rxrpc_to_client(sp) || - sp->hdr.type != RXRPC_PACKET_TYPE_DATA) - goto bad_message; - if (sp->hdr.seq != 1) - goto discard; - call = rxrpc_new_incoming_call(local, rx, skb); - if (!call) - goto reject_packet; - } - - /* Process a call packet; this either discards or passes on the ref - * elsewhere. - */ - rxrpc_input_call_packet(call, skb); - goto out; - -discard: - rxrpc_free_skb(skb, rxrpc_skb_freed); -out: - trace_rxrpc_rx_done(0, 0); - return 0; - -wrong_security: - trace_rxrpc_abort(0, "SEC", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RXKADINCONSISTENCY, EBADMSG); - skb->priority = RXKADINCONSISTENCY; - goto post_abort; - -unsupported_service: - trace_rxrpc_abort(0, "INV", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_INVALID_OPERATION, EOPNOTSUPP); - skb->priority = RX_INVALID_OPERATION; - goto post_abort; - -reupgrade: - trace_rxrpc_abort(0, "UPG", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_PROTOCOL_ERROR, EBADMSG); - goto protocol_error; - -bad_message: - trace_rxrpc_abort(0, "BAD", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_PROTOCOL_ERROR, EBADMSG); -protocol_error: - skb->priority = RX_PROTOCOL_ERROR; -post_abort: - skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; -reject_packet: - trace_rxrpc_rx_done(skb->mark, skb->priority); - rxrpc_reject_packet(local, skb); - _leave(" [badmsg]"); - return 0; + rxrpc_input_call_event(call, skb); } diff --git a/net/rxrpc/insecure.c b/net/rxrpc/insecure.c index 9aae99d67833..34353b6e584b 100644 --- a/net/rxrpc/insecure.c +++ b/net/rxrpc/insecure.c @@ -25,16 +25,16 @@ static int none_how_much_data(struct rxrpc_call *call, size_t remain, return 0; } -static int none_secure_packet(struct rxrpc_call *call, struct sk_buff *skb, - size_t data_size) +static int none_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) { return 0; } -static int none_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int offset, unsigned int len, - rxrpc_seq_t seq, u16 expected_cksum) +static int none_verify_packet(struct rxrpc_call *call, struct sk_buff *skb) { + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + sp->flags |= RXRPC_RX_VERIFIED; return 0; } @@ -42,31 +42,18 @@ static void none_free_call_crypto(struct rxrpc_call *call) { } -static void none_locate_data(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int *_offset, unsigned int *_len) -{ -} - static int none_respond_to_challenge(struct rxrpc_connection *conn, - struct sk_buff *skb, - u32 *_abort_code) + struct sk_buff *skb) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, - tracepoint_string("chall_none")); - return -EPROTO; + return rxrpc_abort_conn(conn, skb, RX_PROTOCOL_ERROR, -EPROTO, + rxrpc_eproto_rxnull_challenge); } static int none_verify_response(struct rxrpc_connection *conn, - struct sk_buff *skb, - u32 *_abort_code) + struct sk_buff *skb) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, - tracepoint_string("resp_none")); - return -EPROTO; + return rxrpc_abort_conn(conn, skb, RX_PROTOCOL_ERROR, -EPROTO, + rxrpc_eproto_rxnull_response); } static void none_clear(struct rxrpc_connection *conn) @@ -95,7 +82,6 @@ const struct rxrpc_security rxrpc_no_security = { .how_much_data = none_how_much_data, .secure_packet = none_secure_packet, .verify_packet = none_verify_packet, - .locate_data = none_locate_data, .respond_to_challenge = none_respond_to_challenge, .verify_response = none_verify_response, .clear = none_clear, diff --git a/net/rxrpc/io_thread.c b/net/rxrpc/io_thread.c new file mode 100644 index 000000000000..9e9dfb2fc559 --- /dev/null +++ b/net/rxrpc/io_thread.c @@ -0,0 +1,514 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC packet reception + * + * Copyright (C) 2007, 2016, 2022 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include "ar-internal.h" + +static int rxrpc_input_packet_on_conn(struct rxrpc_connection *conn, + struct sockaddr_rxrpc *peer_srx, + struct sk_buff *skb); + +/* + * handle data received on the local endpoint + * - may be called in interrupt context + * + * [!] Note that as this is called from the encap_rcv hook, the socket is not + * held locked by the caller and nothing prevents sk_user_data on the UDP from + * being cleared in the middle of processing this function. + * + * Called with the RCU read lock held from the IP layer via UDP. + */ +int rxrpc_encap_rcv(struct sock *udp_sk, struct sk_buff *skb) +{ + struct rxrpc_local *local = rcu_dereference_sk_user_data(udp_sk); + + if (unlikely(!local)) { + kfree_skb(skb); + return 0; + } + if (skb->tstamp == 0) + skb->tstamp = ktime_get_real(); + + skb->mark = RXRPC_SKB_MARK_PACKET; + rxrpc_new_skb(skb, rxrpc_skb_new_encap_rcv); + skb_queue_tail(&local->rx_queue, skb); + rxrpc_wake_up_io_thread(local); + return 0; +} + +/* + * Handle an error received on the local endpoint. + */ +void rxrpc_error_report(struct sock *sk) +{ + struct rxrpc_local *local; + struct sk_buff *skb; + + rcu_read_lock(); + local = rcu_dereference_sk_user_data(sk); + if (unlikely(!local)) { + rcu_read_unlock(); + return; + } + + while ((skb = skb_dequeue(&sk->sk_error_queue))) { + skb->mark = RXRPC_SKB_MARK_ERROR; + rxrpc_new_skb(skb, rxrpc_skb_new_error_report); + skb_queue_tail(&local->rx_queue, skb); + } + + rxrpc_wake_up_io_thread(local); + rcu_read_unlock(); +} + +/* + * Directly produce an abort from a packet. + */ +bool rxrpc_direct_abort(struct sk_buff *skb, enum rxrpc_abort_reason why, + s32 abort_code, int err) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + trace_rxrpc_abort(0, why, sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + abort_code, err); + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; + skb->priority = abort_code; + return false; +} + +static bool rxrpc_bad_message(struct sk_buff *skb, enum rxrpc_abort_reason why) +{ + return rxrpc_direct_abort(skb, why, RX_PROTOCOL_ERROR, -EBADMSG); +} + +#define just_discard true + +/* + * Process event packets targeted at a local endpoint. + */ +static bool rxrpc_input_version(struct rxrpc_local *local, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + char v; + + _enter(""); + + rxrpc_see_skb(skb, rxrpc_skb_see_version); + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), &v, 1) >= 0) { + if (v == 0) + rxrpc_send_version_request(local, &sp->hdr, skb); + } + + return true; +} + +/* + * Extract the wire header from a packet and translate the byte order. + */ +static bool rxrpc_extract_header(struct rxrpc_skb_priv *sp, + struct sk_buff *skb) +{ + struct rxrpc_wire_header whdr; + + /* dig out the RxRPC connection details */ + if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0) + return rxrpc_bad_message(skb, rxrpc_badmsg_short_hdr); + + memset(sp, 0, sizeof(*sp)); + sp->hdr.epoch = ntohl(whdr.epoch); + sp->hdr.cid = ntohl(whdr.cid); + sp->hdr.callNumber = ntohl(whdr.callNumber); + sp->hdr.seq = ntohl(whdr.seq); + sp->hdr.serial = ntohl(whdr.serial); + sp->hdr.flags = whdr.flags; + sp->hdr.type = whdr.type; + sp->hdr.userStatus = whdr.userStatus; + sp->hdr.securityIndex = whdr.securityIndex; + sp->hdr._rsvd = ntohs(whdr._rsvd); + sp->hdr.serviceId = ntohs(whdr.serviceId); + return true; +} + +/* + * Extract the abort code from an ABORT packet and stash it in skb->priority. + */ +static bool rxrpc_extract_abort(struct sk_buff *skb) +{ + __be32 wtmp; + + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + &wtmp, sizeof(wtmp)) < 0) + return false; + skb->priority = ntohl(wtmp); + return true; +} + +/* + * Process packets received on the local endpoint + */ +static bool rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb) +{ + struct rxrpc_connection *conn; + struct sockaddr_rxrpc peer_srx; + struct rxrpc_skb_priv *sp; + struct rxrpc_peer *peer = NULL; + struct sk_buff *skb = *_skb; + bool ret = false; + + skb_pull(skb, sizeof(struct udphdr)); + + sp = rxrpc_skb(skb); + + /* dig out the RxRPC connection details */ + if (!rxrpc_extract_header(sp, skb)) + return just_discard; + + if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) { + static int lose; + if ((lose++ & 7) == 7) { + trace_rxrpc_rx_lose(sp); + return just_discard; + } + } + + trace_rxrpc_rx_packet(sp); + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_VERSION: + if (rxrpc_to_client(sp)) + return just_discard; + return rxrpc_input_version(local, skb); + + case RXRPC_PACKET_TYPE_BUSY: + if (rxrpc_to_server(sp)) + return just_discard; + fallthrough; + case RXRPC_PACKET_TYPE_ACK: + case RXRPC_PACKET_TYPE_ACKALL: + if (sp->hdr.callNumber == 0) + return rxrpc_bad_message(skb, rxrpc_badmsg_zero_call); + break; + case RXRPC_PACKET_TYPE_ABORT: + if (!rxrpc_extract_abort(skb)) + return just_discard; /* Just discard if malformed */ + break; + + case RXRPC_PACKET_TYPE_DATA: + if (sp->hdr.callNumber == 0) + return rxrpc_bad_message(skb, rxrpc_badmsg_zero_call); + if (sp->hdr.seq == 0) + return rxrpc_bad_message(skb, rxrpc_badmsg_zero_seq); + + /* Unshare the packet so that it can be modified for in-place + * decryption. + */ + if (sp->hdr.securityIndex != 0) { + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) { + rxrpc_eaten_skb(*_skb, rxrpc_skb_eaten_by_unshare_nomem); + *_skb = NULL; + return just_discard; + } + + if (skb != *_skb) { + rxrpc_eaten_skb(*_skb, rxrpc_skb_eaten_by_unshare); + *_skb = skb; + rxrpc_new_skb(skb, rxrpc_skb_new_unshared); + sp = rxrpc_skb(skb); + } + } + break; + + case RXRPC_PACKET_TYPE_CHALLENGE: + if (rxrpc_to_server(sp)) + return just_discard; + break; + case RXRPC_PACKET_TYPE_RESPONSE: + if (rxrpc_to_client(sp)) + return just_discard; + break; + + /* Packet types 9-11 should just be ignored. */ + case RXRPC_PACKET_TYPE_PARAMS: + case RXRPC_PACKET_TYPE_10: + case RXRPC_PACKET_TYPE_11: + return just_discard; + + default: + return rxrpc_bad_message(skb, rxrpc_badmsg_unsupported_packet); + } + + if (sp->hdr.serviceId == 0) + return rxrpc_bad_message(skb, rxrpc_badmsg_zero_service); + + if (WARN_ON_ONCE(rxrpc_extract_addr_from_skb(&peer_srx, skb) < 0)) + return just_discard; /* Unsupported address type. */ + + if (peer_srx.transport.family != local->srx.transport.family && + (peer_srx.transport.family == AF_INET && + local->srx.transport.family != AF_INET6)) { + pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n", + peer_srx.transport.family, + local->srx.transport.family); + return just_discard; /* Wrong address type. */ + } + + if (rxrpc_to_client(sp)) { + rcu_read_lock(); + conn = rxrpc_find_client_connection_rcu(local, &peer_srx, skb); + conn = rxrpc_get_connection_maybe(conn, rxrpc_conn_get_call_input); + rcu_read_unlock(); + if (!conn) + return rxrpc_protocol_error(skb, rxrpc_eproto_no_client_conn); + + ret = rxrpc_input_packet_on_conn(conn, &peer_srx, skb); + rxrpc_put_connection(conn, rxrpc_conn_put_call_input); + return ret; + } + + /* We need to look up service connections by the full protocol + * parameter set. We look up the peer first as an intermediate step + * and then the connection from the peer's tree. + */ + rcu_read_lock(); + + peer = rxrpc_lookup_peer_rcu(local, &peer_srx); + if (!peer) { + rcu_read_unlock(); + return rxrpc_new_incoming_call(local, NULL, NULL, &peer_srx, skb); + } + + conn = rxrpc_find_service_conn_rcu(peer, skb); + conn = rxrpc_get_connection_maybe(conn, rxrpc_conn_get_call_input); + if (conn) { + rcu_read_unlock(); + ret = rxrpc_input_packet_on_conn(conn, &peer_srx, skb); + rxrpc_put_connection(conn, rxrpc_conn_put_call_input); + return ret; + } + + peer = rxrpc_get_peer_maybe(peer, rxrpc_peer_get_input); + rcu_read_unlock(); + + ret = rxrpc_new_incoming_call(local, peer, NULL, &peer_srx, skb); + rxrpc_put_peer(peer, rxrpc_peer_put_input); + return ret; +} + +/* + * Deal with a packet that's associated with an extant connection. + */ +static int rxrpc_input_packet_on_conn(struct rxrpc_connection *conn, + struct sockaddr_rxrpc *peer_srx, + struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxrpc_channel *chan; + struct rxrpc_call *call = NULL; + unsigned int channel; + bool ret; + + if (sp->hdr.securityIndex != conn->security_ix) + return rxrpc_direct_abort(skb, rxrpc_eproto_wrong_security, + RXKADINCONSISTENCY, -EBADMSG); + + if (sp->hdr.serviceId != conn->service_id) { + int old_id; + + if (!test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags)) + return rxrpc_protocol_error(skb, rxrpc_eproto_reupgrade); + + old_id = cmpxchg(&conn->service_id, conn->orig_service_id, + sp->hdr.serviceId); + if (old_id != conn->orig_service_id && + old_id != sp->hdr.serviceId) + return rxrpc_protocol_error(skb, rxrpc_eproto_bad_upgrade); + } + + if (after(sp->hdr.serial, conn->hi_serial)) + conn->hi_serial = sp->hdr.serial; + + /* It's a connection-level packet if the call number is 0. */ + if (sp->hdr.callNumber == 0) + return rxrpc_input_conn_packet(conn, skb); + + /* Call-bound packets are routed by connection channel. */ + channel = sp->hdr.cid & RXRPC_CHANNELMASK; + chan = &conn->channels[channel]; + + /* Ignore really old calls */ + if (sp->hdr.callNumber < chan->last_call) + return just_discard; + + if (sp->hdr.callNumber == chan->last_call) { + if (chan->call || + sp->hdr.type == RXRPC_PACKET_TYPE_ABORT) + return just_discard; + + /* For the previous service call, if completed successfully, we + * discard all further packets. + */ + if (rxrpc_conn_is_service(conn) && + chan->last_type == RXRPC_PACKET_TYPE_ACK) + return just_discard; + + /* But otherwise we need to retransmit the final packet from + * data cached in the connection record. + */ + if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA) + trace_rxrpc_rx_data(chan->call_debug_id, + sp->hdr.seq, + sp->hdr.serial, + sp->hdr.flags); + rxrpc_conn_retransmit_call(conn, skb, channel); + return just_discard; + } + + call = rxrpc_try_get_call(chan->call, rxrpc_call_get_input); + + if (sp->hdr.callNumber > chan->call_id) { + if (rxrpc_to_client(sp)) { + rxrpc_put_call(call, rxrpc_call_put_input); + return rxrpc_protocol_error(skb, + rxrpc_eproto_unexpected_implicit_end); + } + + if (call) { + rxrpc_implicit_end_call(call, skb); + rxrpc_put_call(call, rxrpc_call_put_input); + call = NULL; + } + } + + if (!call) { + if (rxrpc_to_client(sp)) + return rxrpc_protocol_error(skb, rxrpc_eproto_no_client_call); + return rxrpc_new_incoming_call(conn->local, conn->peer, conn, + peer_srx, skb); + } + + ret = rxrpc_input_call_event(call, skb); + rxrpc_put_call(call, rxrpc_call_put_input); + return ret; +} + +/* + * I/O and event handling thread. + */ +int rxrpc_io_thread(void *data) +{ + struct rxrpc_connection *conn; + struct sk_buff_head rx_queue; + struct rxrpc_local *local = data; + struct rxrpc_call *call; + struct sk_buff *skb; + bool should_stop; + + complete(&local->io_thread_ready); + + skb_queue_head_init(&rx_queue); + + set_user_nice(current, MIN_NICE); + + for (;;) { + rxrpc_inc_stat(local->rxnet, stat_io_loop); + + /* Deal with connections that want immediate attention. */ + conn = list_first_entry_or_null(&local->conn_attend_q, + struct rxrpc_connection, + attend_link); + if (conn) { + spin_lock_bh(&local->lock); + list_del_init(&conn->attend_link); + spin_unlock_bh(&local->lock); + + rxrpc_input_conn_event(conn, NULL); + rxrpc_put_connection(conn, rxrpc_conn_put_poke); + continue; + } + + if (test_and_clear_bit(RXRPC_CLIENT_CONN_REAP_TIMER, + &local->client_conn_flags)) + rxrpc_discard_expired_client_conns(local); + + /* Deal with calls that want immediate attention. */ + if ((call = list_first_entry_or_null(&local->call_attend_q, + struct rxrpc_call, + attend_link))) { + spin_lock_bh(&local->lock); + list_del_init(&call->attend_link); + spin_unlock_bh(&local->lock); + + trace_rxrpc_call_poked(call); + rxrpc_input_call_event(call, NULL); + rxrpc_put_call(call, rxrpc_call_put_poke); + continue; + } + + if (!list_empty(&local->new_client_calls)) + rxrpc_connect_client_calls(local); + + /* Process received packets and errors. */ + if ((skb = __skb_dequeue(&rx_queue))) { + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + switch (skb->mark) { + case RXRPC_SKB_MARK_PACKET: + skb->priority = 0; + if (!rxrpc_input_packet(local, &skb)) + rxrpc_reject_packet(local, skb); + trace_rxrpc_rx_done(skb->mark, skb->priority); + rxrpc_free_skb(skb, rxrpc_skb_put_input); + break; + case RXRPC_SKB_MARK_ERROR: + rxrpc_input_error(local, skb); + rxrpc_free_skb(skb, rxrpc_skb_put_error_report); + break; + case RXRPC_SKB_MARK_SERVICE_CONN_SECURED: + rxrpc_input_conn_event(sp->conn, skb); + rxrpc_put_connection(sp->conn, rxrpc_conn_put_poke); + rxrpc_free_skb(skb, rxrpc_skb_put_conn_secured); + break; + default: + WARN_ON_ONCE(1); + rxrpc_free_skb(skb, rxrpc_skb_put_unknown); + break; + } + continue; + } + + if (!skb_queue_empty(&local->rx_queue)) { + spin_lock_irq(&local->rx_queue.lock); + skb_queue_splice_tail_init(&local->rx_queue, &rx_queue); + spin_unlock_irq(&local->rx_queue.lock); + continue; + } + + set_current_state(TASK_INTERRUPTIBLE); + should_stop = kthread_should_stop(); + if (!skb_queue_empty(&local->rx_queue) || + !list_empty(&local->call_attend_q) || + !list_empty(&local->conn_attend_q) || + !list_empty(&local->new_client_calls) || + test_bit(RXRPC_CLIENT_CONN_REAP_TIMER, + &local->client_conn_flags)) { + __set_current_state(TASK_RUNNING); + continue; + } + + if (should_stop) + break; + schedule(); + } + + __set_current_state(TASK_RUNNING); + rxrpc_see_local(local, rxrpc_local_stop); + rxrpc_destroy_local(local); + local->io_thread = NULL; + rxrpc_see_local(local, rxrpc_local_stopped); + return 0; +} diff --git a/net/rxrpc/key.c b/net/rxrpc/key.c index 8d2073e0e3da..8d53aded09c4 100644 --- a/net/rxrpc/key.c +++ b/net/rxrpc/key.c @@ -513,7 +513,7 @@ int rxrpc_get_server_data_key(struct rxrpc_connection *conn, if (ret < 0) goto error; - conn->params.key = key; + conn->key = key; _leave(" = 0 [%d]", key_serial(key)); return 0; @@ -602,7 +602,8 @@ static long rxrpc_read(const struct key *key, } _debug("token[%u]: toksize=%u", ntoks, toksize); - ASSERTCMP(toksize, <=, AFSTOKEN_LENGTH_MAX); + if (WARN_ON(toksize > AFSTOKEN_LENGTH_MAX)) + return -EIO; toksizes[ntoks++] = toksize; size += toksize + 4; /* each token has a length word */ @@ -679,8 +680,9 @@ static long rxrpc_read(const struct key *key, return -ENOPKG; } - ASSERTCMP((unsigned long)xdr - (unsigned long)oldxdr, ==, - toksize); + if (WARN_ON((unsigned long)xdr - (unsigned long)oldxdr == + toksize)) + return -EIO; } #undef ENCODE_STR @@ -688,8 +690,10 @@ static long rxrpc_read(const struct key *key, #undef ENCODE64 #undef ENCODE - ASSERTCMP(tok, ==, ntoks); - ASSERTCMP((char __user *) xdr - buffer, ==, size); + if (WARN_ON(tok != ntoks)) + return -EIO; + if (WARN_ON((unsigned long)xdr - (unsigned long)buffer != size)) + return -EIO; _leave(" = %zu", size); return size; } diff --git a/net/rxrpc/local_event.c b/net/rxrpc/local_event.c index 19e929c7c38b..5e69ea6b233d 100644 --- a/net/rxrpc/local_event.c +++ b/net/rxrpc/local_event.c @@ -21,9 +21,9 @@ static const char rxrpc_version_string[65] = "linux-" UTS_RELEASE " AF_RXRPC"; /* * Reply to a version request */ -static void rxrpc_send_version_request(struct rxrpc_local *local, - struct rxrpc_host_header *hdr, - struct sk_buff *skb) +void rxrpc_send_version_request(struct rxrpc_local *local, + struct rxrpc_host_header *hdr, + struct sk_buff *skb) { struct rxrpc_wire_header whdr; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); @@ -63,8 +63,6 @@ static void rxrpc_send_version_request(struct rxrpc_local *local, len = iov[0].iov_len + iov[1].iov_len; - _proto("Tx VERSION (reply)"); - ret = kernel_sendmsg(local->socket, &msg, iov, 2, len); if (ret < 0) trace_rxrpc_tx_fail(local->debug_id, 0, ret, @@ -75,41 +73,3 @@ static void rxrpc_send_version_request(struct rxrpc_local *local, _leave(""); } - -/* - * Process event packets targeted at a local endpoint. - */ -void rxrpc_process_local_events(struct rxrpc_local *local) -{ - struct sk_buff *skb; - char v; - - _enter(""); - - skb = skb_dequeue(&local->event_queue); - if (skb) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - - rxrpc_see_skb(skb, rxrpc_skb_seen); - _debug("{%d},{%u}", local->debug_id, sp->hdr.type); - - switch (sp->hdr.type) { - case RXRPC_PACKET_TYPE_VERSION: - if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), - &v, 1) < 0) - return; - _proto("Rx VERSION { %02x }", v); - if (v == 0) - rxrpc_send_version_request(local, &sp->hdr, skb); - break; - - default: - /* Just ignore anything we don't understand */ - break; - } - - rxrpc_free_skb(skb, rxrpc_skb_freed); - } - - _leave(""); -} diff --git a/net/rxrpc/local_object.c b/net/rxrpc/local_object.c index 38ea98ff426b..b8eaca5d9f22 100644 --- a/net/rxrpc/local_object.c +++ b/net/rxrpc/local_object.c @@ -20,10 +20,23 @@ #include <net/af_rxrpc.h> #include "ar-internal.h" -static void rxrpc_local_processor(struct work_struct *); static void rxrpc_local_rcu(struct rcu_head *); /* + * Handle an ICMP/ICMP6 error turning up at the tunnel. Push it through the + * usual mechanism so that it gets parsed and presented through the UDP + * socket's error_report(). + */ +static void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, int err, + __be16 port, u32 info, u8 *payload) +{ + if (ip_hdr(skb)->version == IPVERSION) + return ip_icmp_error(sk, skb, err, port, info, payload); + if (IS_ENABLED(CONFIG_AF_RXRPC_IPV6)) + return ipv6_icmp_error(sk, skb, err, port, info, payload); +} + +/* * Compare a local to an address. Return -ve, 0 or +ve to indicate less than, * same or greater than. * @@ -69,32 +82,60 @@ static long rxrpc_local_cmp_key(const struct rxrpc_local *local, } } +static void rxrpc_client_conn_reap_timeout(struct timer_list *timer) +{ + struct rxrpc_local *local = + container_of(timer, struct rxrpc_local, client_conn_reap_timer); + + if (local->kill_all_client_conns && + test_and_set_bit(RXRPC_CLIENT_CONN_REAP_TIMER, &local->client_conn_flags)) + rxrpc_wake_up_io_thread(local); +} + /* * Allocate a new local endpoint. */ -static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet, +static struct rxrpc_local *rxrpc_alloc_local(struct net *net, const struct sockaddr_rxrpc *srx) { struct rxrpc_local *local; + u32 tmp; local = kzalloc(sizeof(struct rxrpc_local), GFP_KERNEL); if (local) { refcount_set(&local->ref, 1); atomic_set(&local->active_users, 1); - local->rxnet = rxnet; + local->net = net; + local->rxnet = rxrpc_net(net); INIT_HLIST_NODE(&local->link); - INIT_WORK(&local->processor, rxrpc_local_processor); init_rwsem(&local->defrag_sem); - skb_queue_head_init(&local->reject_queue); - skb_queue_head_init(&local->event_queue); + init_completion(&local->io_thread_ready); + skb_queue_head_init(&local->rx_queue); + INIT_LIST_HEAD(&local->conn_attend_q); + INIT_LIST_HEAD(&local->call_attend_q); + local->client_bundles = RB_ROOT; spin_lock_init(&local->client_bundles_lock); + local->kill_all_client_conns = false; + INIT_LIST_HEAD(&local->idle_client_conns); + timer_setup(&local->client_conn_reap_timer, + rxrpc_client_conn_reap_timeout, 0); + spin_lock_init(&local->lock); rwlock_init(&local->services_lock); local->debug_id = atomic_inc_return(&rxrpc_debug_id); memcpy(&local->srx, srx, sizeof(*srx)); local->srx.srx_service = 0; - trace_rxrpc_local(local->debug_id, rxrpc_local_new, 1, NULL); + idr_init(&local->conn_ids); + get_random_bytes(&tmp, sizeof(tmp)); + tmp &= 0x3fffffff; + if (tmp == 0) + tmp = 1; + idr_set_cursor(&local->conn_ids, tmp); + INIT_LIST_HEAD(&local->new_client_calls); + spin_lock_init(&local->client_call_lock); + + trace_rxrpc_local(local->debug_id, rxrpc_local_new, 1, 1); } _leave(" = %p", local); @@ -110,6 +151,7 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) struct udp_tunnel_sock_cfg tuncfg = {NULL}; struct sockaddr_rxrpc *srx = &local->srx; struct udp_port_cfg udp_conf = {0}; + struct task_struct *io_thread; struct sock *usk; int ret; @@ -136,7 +178,7 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) } tuncfg.encap_type = UDP_ENCAP_RXRPC; - tuncfg.encap_rcv = rxrpc_input_packet; + tuncfg.encap_rcv = rxrpc_encap_rcv; tuncfg.encap_err_rcv = rxrpc_encap_err_rcv; tuncfg.sk_user_data = local; setup_udp_tunnel_sock(net, local->socket, &tuncfg); @@ -169,8 +211,24 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) BUG(); } + io_thread = kthread_run(rxrpc_io_thread, local, + "krxrpcio/%u", ntohs(udp_conf.local_udp_port)); + if (IS_ERR(io_thread)) { + ret = PTR_ERR(io_thread); + goto error_sock; + } + + wait_for_completion(&local->io_thread_ready); + local->io_thread = io_thread; _leave(" = 0"); return 0; + +error_sock: + kernel_sock_shutdown(local->socket, SHUT_RDWR); + local->socket->sk->sk_user_data = NULL; + sock_release(local->socket); + local->socket = NULL; + return ret; } /* @@ -182,7 +240,6 @@ struct rxrpc_local *rxrpc_lookup_local(struct net *net, struct rxrpc_local *local; struct rxrpc_net *rxnet = rxrpc_net(net); struct hlist_node *cursor; - const char *age; long diff; int ret; @@ -213,14 +270,13 @@ struct rxrpc_local *rxrpc_lookup_local(struct net *net, * we're attempting to use a local address that the dying * object is still using. */ - if (!rxrpc_use_local(local)) + if (!rxrpc_use_local(local, rxrpc_local_use_lookup)) break; - age = "old"; goto found; } - local = rxrpc_alloc_local(rxnet, srx); + local = rxrpc_alloc_local(net, srx); if (!local) goto nomem; @@ -234,14 +290,9 @@ struct rxrpc_local *rxrpc_lookup_local(struct net *net, } else { hlist_add_head_rcu(&local->link, &rxnet->local_endpoints); } - age = "new"; found: mutex_unlock(&rxnet->local_mutex); - - _net("LOCAL %s %d {%pISp}", - age, local->debug_id, &local->srx.transport); - _leave(" = %p", local); return local; @@ -263,64 +314,49 @@ addr_in_use: /* * Get a ref on a local endpoint. */ -struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local) +struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local, + enum rxrpc_local_trace why) { - const void *here = __builtin_return_address(0); - int r; + int r, u; + u = atomic_read(&local->active_users); __refcount_inc(&local->ref, &r); - trace_rxrpc_local(local->debug_id, rxrpc_local_got, r + 1, here); + trace_rxrpc_local(local->debug_id, why, r + 1, u); return local; } /* * Get a ref on a local endpoint unless its usage has already reached 0. */ -struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local) +struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local, + enum rxrpc_local_trace why) { - const void *here = __builtin_return_address(0); - int r; + int r, u; - if (local) { - if (__refcount_inc_not_zero(&local->ref, &r)) - trace_rxrpc_local(local->debug_id, rxrpc_local_got, - r + 1, here); - else - local = NULL; + if (local && __refcount_inc_not_zero(&local->ref, &r)) { + u = atomic_read(&local->active_users); + trace_rxrpc_local(local->debug_id, why, r + 1, u); + return local; } - return local; -} -/* - * Queue a local endpoint and pass the caller's reference to the work item. - */ -void rxrpc_queue_local(struct rxrpc_local *local) -{ - const void *here = __builtin_return_address(0); - unsigned int debug_id = local->debug_id; - int r = refcount_read(&local->ref); - - if (rxrpc_queue_work(&local->processor)) - trace_rxrpc_local(debug_id, rxrpc_local_queued, r + 1, here); - else - rxrpc_put_local(local); + return NULL; } /* * Drop a ref on a local endpoint. */ -void rxrpc_put_local(struct rxrpc_local *local) +void rxrpc_put_local(struct rxrpc_local *local, enum rxrpc_local_trace why) { - const void *here = __builtin_return_address(0); unsigned int debug_id; bool dead; - int r; + int r, u; if (local) { debug_id = local->debug_id; + u = atomic_read(&local->active_users); dead = __refcount_dec_and_test(&local->ref, &r); - trace_rxrpc_local(debug_id, rxrpc_local_put, r, here); + trace_rxrpc_local(debug_id, why, r, u); if (dead) call_rcu(&local->rcu, rxrpc_local_rcu); @@ -330,14 +366,15 @@ void rxrpc_put_local(struct rxrpc_local *local) /* * Start using a local endpoint. */ -struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local) +struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local, + enum rxrpc_local_trace why) { - local = rxrpc_get_local_maybe(local); + local = rxrpc_get_local_maybe(local, rxrpc_local_get_for_use); if (!local) return NULL; - if (!__rxrpc_use_local(local)) { - rxrpc_put_local(local); + if (!__rxrpc_use_local(local, why)) { + rxrpc_put_local(local, rxrpc_local_put_for_use); return NULL; } @@ -346,15 +383,20 @@ struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local) /* * Cease using a local endpoint. Once the number of active users reaches 0, we - * start the closure of the transport in the work processor. + * start the closure of the transport in the I/O thread.. */ -void rxrpc_unuse_local(struct rxrpc_local *local) +void rxrpc_unuse_local(struct rxrpc_local *local, enum rxrpc_local_trace why) { + unsigned int debug_id; + int r, u; + if (local) { - if (__rxrpc_unuse_local(local)) { - rxrpc_get_local(local); - rxrpc_queue_local(local); - } + debug_id = local->debug_id; + r = refcount_read(&local->ref); + u = atomic_dec_return(&local->active_users); + trace_rxrpc_local(debug_id, why, r, u); + if (u == 0) + kthread_stop(local->io_thread); } } @@ -365,7 +407,7 @@ void rxrpc_unuse_local(struct rxrpc_local *local) * Closing the socket cannot be done from bottom half context or RCU callback * context because it might sleep. */ -static void rxrpc_local_destroyer(struct rxrpc_local *local) +void rxrpc_destroy_local(struct rxrpc_local *local) { struct socket *socket = local->socket; struct rxrpc_net *rxnet = local->rxnet; @@ -392,47 +434,8 @@ static void rxrpc_local_destroyer(struct rxrpc_local *local) /* At this point, there should be no more packets coming in to the * local endpoint. */ - rxrpc_purge_queue(&local->reject_queue); - rxrpc_purge_queue(&local->event_queue); -} - -/* - * Process events on an endpoint. The work item carries a ref which - * we must release. - */ -static void rxrpc_local_processor(struct work_struct *work) -{ - struct rxrpc_local *local = - container_of(work, struct rxrpc_local, processor); - bool again; - - if (local->dead) - return; - - trace_rxrpc_local(local->debug_id, rxrpc_local_processing, - refcount_read(&local->ref), NULL); - - do { - again = false; - if (!__rxrpc_use_local(local)) { - rxrpc_local_destroyer(local); - break; - } - - if (!skb_queue_empty(&local->reject_queue)) { - rxrpc_reject_packets(local); - again = true; - } - - if (!skb_queue_empty(&local->event_queue)) { - rxrpc_process_local_events(local); - again = true; - } - - __rxrpc_unuse_local(local); - } while (again); - - rxrpc_put_local(local); + rxrpc_purge_queue(&local->rx_queue); + rxrpc_purge_client_connections(local); } /* @@ -442,13 +445,8 @@ static void rxrpc_local_rcu(struct rcu_head *rcu) { struct rxrpc_local *local = container_of(rcu, struct rxrpc_local, rcu); - _enter("%d", local->debug_id); - - ASSERT(!work_pending(&local->processor)); - - _net("DESTROY LOCAL %d", local->debug_id); + rxrpc_see_local(local, rxrpc_local_free); kfree(local); - _leave(""); } /* diff --git a/net/rxrpc/misc.c b/net/rxrpc/misc.c index d4144fd86f84..056c428d8bf3 100644 --- a/net/rxrpc/misc.c +++ b/net/rxrpc/misc.c @@ -17,12 +17,6 @@ unsigned int rxrpc_max_backlog __read_mostly = 10; /* - * How long to wait before scheduling ACK generation after seeing a - * packet with RXRPC_REQUEST_ACK set (in jiffies). - */ -unsigned long rxrpc_requested_ack_delay = 1; - -/* * How long to wait before scheduling an ACK with subtype DELAY (in jiffies). * * We use this when we've received new data packets. If those packets aren't @@ -46,10 +40,7 @@ unsigned long rxrpc_idle_ack_delay = HZ / 2; * limit is hit, we should generate an EXCEEDS_WINDOW ACK and discard further * packets. */ -unsigned int rxrpc_rx_window_size = RXRPC_INIT_RX_WINDOW_SIZE; -#if (RXRPC_RXTX_BUFF_SIZE - 1) < RXRPC_INIT_RX_WINDOW_SIZE -#error Need to reduce RXRPC_INIT_RX_WINDOW_SIZE -#endif +unsigned int rxrpc_rx_window_size = 255; /* * Maximum Rx MTU size. This indicates to the sender the size of jumbo packet @@ -62,15 +53,3 @@ unsigned int rxrpc_rx_mtu = 5692; * sender that we're willing to handle. */ unsigned int rxrpc_rx_jumbo_max = 4; - -const s8 rxrpc_ack_priority[] = { - [0] = 0, - [RXRPC_ACK_DELAY] = 1, - [RXRPC_ACK_REQUESTED] = 2, - [RXRPC_ACK_IDLE] = 3, - [RXRPC_ACK_DUPLICATE] = 4, - [RXRPC_ACK_OUT_OF_SEQUENCE] = 5, - [RXRPC_ACK_EXCEEDS_WINDOW] = 6, - [RXRPC_ACK_NOSPACE] = 7, - [RXRPC_ACK_PING_RESPONSE] = 8, -}; diff --git a/net/rxrpc/net_ns.c b/net/rxrpc/net_ns.c index bb4c25d6df64..a0319c040c25 100644 --- a/net/rxrpc/net_ns.c +++ b/net/rxrpc/net_ns.c @@ -10,15 +10,6 @@ unsigned int rxrpc_net_id; -static void rxrpc_client_conn_reap_timeout(struct timer_list *timer) -{ - struct rxrpc_net *rxnet = - container_of(timer, struct rxrpc_net, client_conn_reap_timer); - - if (rxnet->live) - rxrpc_queue_work(&rxnet->client_conn_reaper); -} - static void rxrpc_service_conn_reap_timeout(struct timer_list *timer) { struct rxrpc_net *rxnet = @@ -63,14 +54,6 @@ static __net_init int rxrpc_init_net(struct net *net) rxrpc_service_conn_reap_timeout, 0); atomic_set(&rxnet->nr_client_conns, 0); - rxnet->kill_all_client_conns = false; - spin_lock_init(&rxnet->client_conn_cache_lock); - spin_lock_init(&rxnet->client_conn_discard_lock); - INIT_LIST_HEAD(&rxnet->idle_client_conns); - INIT_WORK(&rxnet->client_conn_reaper, - rxrpc_discard_expired_client_conns); - timer_setup(&rxnet->client_conn_reap_timer, - rxrpc_client_conn_reap_timeout, 0); INIT_HLIST_HEAD(&rxnet->local_endpoints); mutex_init(&rxnet->local_mutex); @@ -101,6 +84,8 @@ static __net_init int rxrpc_init_net(struct net *net) proc_create_net("locals", 0444, rxnet->proc_net, &rxrpc_local_seq_ops, sizeof(struct seq_net_private)); + proc_create_net_single_write("stats", S_IFREG | 0644, rxnet->proc_net, + rxrpc_stats_show, rxrpc_stats_clear, NULL); return 0; err_proc: diff --git a/net/rxrpc/output.c b/net/rxrpc/output.c index 9683617db704..a9746be29634 100644 --- a/net/rxrpc/output.c +++ b/net/rxrpc/output.c @@ -13,15 +13,27 @@ #include <linux/export.h> #include <net/sock.h> #include <net/af_rxrpc.h> +#include <net/udp.h> #include "ar-internal.h" -struct rxrpc_ack_buffer { - struct rxrpc_wire_header whdr; - struct rxrpc_ackpacket ack; - u8 acks[255]; - u8 pad[3]; - struct rxrpc_ackinfo ackinfo; -}; +extern int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len); + +static ssize_t do_udp_sendmsg(struct socket *socket, struct msghdr *msg, size_t len) +{ + struct sockaddr *sa = msg->msg_name; + struct sock *sk = socket->sk; + + if (IS_ENABLED(CONFIG_AF_RXRPC_IPV6)) { + if (sa->sa_family == AF_INET6) { + if (sk->sk_family != AF_INET6) { + pr_warn("AF_INET6 address on AF_INET socket\n"); + return -ENOPROTOOPT; + } + return udpv6_sendmsg(sk, msg, len); + } + } + return udp_sendmsg(sk, msg, len); +} struct rxrpc_abort_buffer { struct rxrpc_wire_header whdr; @@ -68,66 +80,83 @@ static void rxrpc_set_keepalive(struct rxrpc_call *call) */ static size_t rxrpc_fill_out_ack(struct rxrpc_connection *conn, struct rxrpc_call *call, - struct rxrpc_ack_buffer *pkt, - rxrpc_seq_t *_hard_ack, - rxrpc_seq_t *_top, - u8 reason) + struct rxrpc_txbuf *txb) { - rxrpc_serial_t serial; - unsigned int tmp; - rxrpc_seq_t hard_ack, top, seq; - int ix; + struct rxrpc_ackinfo ackinfo; + unsigned int qsize; + rxrpc_seq_t window, wtop, wrap_point, ix, first; + int rsize; + u64 wtmp; u32 mtu, jmax; - u8 *ackp = pkt->acks; + u8 *ackp = txb->acks; + u8 sack_buffer[sizeof(call->ackr_sack_table)] __aligned(8); - tmp = atomic_xchg(&call->ackr_nr_unacked, 0); - tmp |= atomic_xchg(&call->ackr_nr_consumed, 0); - if (!tmp && (reason == RXRPC_ACK_DELAY || - reason == RXRPC_ACK_IDLE)) - return 0; + atomic_set(&call->ackr_nr_unacked, 0); + atomic_set(&call->ackr_nr_consumed, 0); + rxrpc_inc_stat(call->rxnet, stat_tx_ack_fill); /* Barrier against rxrpc_input_data(). */ - serial = call->ackr_serial; - hard_ack = READ_ONCE(call->rx_hard_ack); - top = smp_load_acquire(&call->rx_top); - *_hard_ack = hard_ack; - *_top = top; - - pkt->ack.bufferSpace = htons(8); - pkt->ack.maxSkew = htons(0); - pkt->ack.firstPacket = htonl(hard_ack + 1); - pkt->ack.previousPacket = htonl(call->ackr_highest_seq); - pkt->ack.serial = htonl(serial); - pkt->ack.reason = reason; - pkt->ack.nAcks = top - hard_ack; - - if (reason == RXRPC_ACK_PING) - pkt->whdr.flags |= RXRPC_REQUEST_ACK; - - if (after(top, hard_ack)) { - seq = hard_ack + 1; - do { - ix = seq & RXRPC_RXTX_BUFF_MASK; - if (call->rxtx_buffer[ix]) - *ackp++ = RXRPC_ACK_TYPE_ACK; - else - *ackp++ = RXRPC_ACK_TYPE_NACK; - seq++; - } while (before_eq(seq, top)); +retry: + wtmp = atomic64_read_acquire(&call->ackr_window); + window = lower_32_bits(wtmp); + wtop = upper_32_bits(wtmp); + txb->ack.firstPacket = htonl(window); + txb->ack.nAcks = 0; + + if (after(wtop, window)) { + /* Try to copy the SACK ring locklessly. We can use the copy, + * only if the now-current top of the window didn't go past the + * previously read base - otherwise we can't know whether we + * have old data or new data. + */ + memcpy(sack_buffer, call->ackr_sack_table, sizeof(sack_buffer)); + wrap_point = window + RXRPC_SACK_SIZE - 1; + wtmp = atomic64_read_acquire(&call->ackr_window); + window = lower_32_bits(wtmp); + wtop = upper_32_bits(wtmp); + if (after(wtop, wrap_point)) { + cond_resched(); + goto retry; + } + + /* The buffer is maintained as a ring with an invariant mapping + * between bit position and sequence number, so we'll probably + * need to rotate it. + */ + txb->ack.nAcks = wtop - window; + ix = window % RXRPC_SACK_SIZE; + first = sizeof(sack_buffer) - ix; + + if (ix + txb->ack.nAcks <= RXRPC_SACK_SIZE) { + memcpy(txb->acks, sack_buffer + ix, txb->ack.nAcks); + } else { + memcpy(txb->acks, sack_buffer + ix, first); + memcpy(txb->acks + first, sack_buffer, + txb->ack.nAcks - first); + } + + ackp += txb->ack.nAcks; + } else if (before(wtop, window)) { + pr_warn("ack window backward %x %x", window, wtop); + } else if (txb->ack.reason == RXRPC_ACK_DELAY) { + txb->ack.reason = RXRPC_ACK_IDLE; } - mtu = conn->params.peer->if_mtu; - mtu -= conn->params.peer->hdrsize; - jmax = (call->nr_jumbo_bad > 3) ? 1 : rxrpc_rx_jumbo_max; - pkt->ackinfo.rxMTU = htonl(rxrpc_rx_mtu); - pkt->ackinfo.maxMTU = htonl(mtu); - pkt->ackinfo.rwind = htonl(call->rx_winsize); - pkt->ackinfo.jumbo_max = htonl(jmax); + mtu = conn->peer->if_mtu; + mtu -= conn->peer->hdrsize; + jmax = rxrpc_rx_jumbo_max; + qsize = (window - 1) - call->rx_consumed; + rsize = max_t(int, call->rx_winsize - qsize, 0); + ackinfo.rxMTU = htonl(rxrpc_rx_mtu); + ackinfo.maxMTU = htonl(mtu); + ackinfo.rwind = htonl(rsize); + ackinfo.jumbo_max = htonl(jmax); *ackp++ = 0; *ackp++ = 0; *ackp++ = 0; - return top - hard_ack + 3; + memcpy(ackp, &ackinfo, sizeof(ackinfo)); + return txb->ack.nAcks + 3 + sizeof(ackinfo); } /* @@ -174,28 +203,20 @@ static void rxrpc_cancel_rtt_probe(struct rxrpc_call *call, } /* - * Send an ACK call packet. + * Transmit an ACK packet. */ -int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, - rxrpc_serial_t *_serial) +int rxrpc_send_ack_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) { struct rxrpc_connection *conn; - struct rxrpc_ack_buffer *pkt; struct msghdr msg; - struct kvec iov[2]; + struct kvec iov[1]; rxrpc_serial_t serial; - rxrpc_seq_t hard_ack, top; size_t len, n; int ret, rtt_slot = -1; - u8 reason; if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) return -ECONNRESET; - pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); - if (!pkt) - return -ENOMEM; - conn = call->conn; msg.msg_name = &call->peer->srx.transport; @@ -204,79 +225,48 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, msg.msg_controllen = 0; msg.msg_flags = 0; - pkt->whdr.epoch = htonl(conn->proto.epoch); - pkt->whdr.cid = htonl(call->cid); - pkt->whdr.callNumber = htonl(call->call_id); - pkt->whdr.seq = 0; - pkt->whdr.type = RXRPC_PACKET_TYPE_ACK; - pkt->whdr.flags = RXRPC_SLOW_START_OK | conn->out_clientflag; - pkt->whdr.userStatus = 0; - pkt->whdr.securityIndex = call->security_ix; - pkt->whdr._rsvd = 0; - pkt->whdr.serviceId = htons(call->service_id); - - spin_lock_bh(&call->lock); - if (ping) { - reason = RXRPC_ACK_PING; - } else { - reason = call->ackr_reason; - if (!call->ackr_reason) { - spin_unlock_bh(&call->lock); - ret = 0; - goto out; - } - call->ackr_reason = 0; - } - n = rxrpc_fill_out_ack(conn, call, pkt, &hard_ack, &top, reason); + if (txb->ack.reason == RXRPC_ACK_PING) + txb->wire.flags |= RXRPC_REQUEST_ACK; - spin_unlock_bh(&call->lock); - if (n == 0) { - kfree(pkt); + n = rxrpc_fill_out_ack(conn, call, txb); + if (n == 0) return 0; - } - iov[0].iov_base = pkt; - iov[0].iov_len = sizeof(pkt->whdr) + sizeof(pkt->ack) + n; - iov[1].iov_base = &pkt->ackinfo; - iov[1].iov_len = sizeof(pkt->ackinfo); - len = iov[0].iov_len + iov[1].iov_len; + iov[0].iov_base = &txb->wire; + iov[0].iov_len = sizeof(txb->wire) + sizeof(txb->ack) + n; + len = iov[0].iov_len; serial = atomic_inc_return(&conn->serial); - pkt->whdr.serial = htonl(serial); + txb->wire.serial = htonl(serial); trace_rxrpc_tx_ack(call->debug_id, serial, - ntohl(pkt->ack.firstPacket), - ntohl(pkt->ack.serial), - pkt->ack.reason, pkt->ack.nAcks); - if (_serial) - *_serial = serial; + ntohl(txb->ack.firstPacket), + ntohl(txb->ack.serial), txb->ack.reason, txb->ack.nAcks); - if (ping) + if (txb->ack.reason == RXRPC_ACK_PING) rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_ping); - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); - conn->params.peer->last_tx_at = ktime_get_seconds(); + rxrpc_inc_stat(call->rxnet, stat_tx_ack_send); + + /* Grab the highest received seq as late as possible */ + txb->ack.previousPacket = htonl(call->rx_highest_seq); + + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 1, len); + ret = do_udp_sendmsg(conn->local->socket, &msg, len); + call->peer->last_tx_at = ktime_get_seconds(); if (ret < 0) trace_rxrpc_tx_fail(call->debug_id, serial, ret, rxrpc_tx_point_call_ack); else - trace_rxrpc_tx_packet(call->debug_id, &pkt->whdr, + trace_rxrpc_tx_packet(call->debug_id, &txb->wire, rxrpc_tx_point_call_ack); rxrpc_tx_backoff(call, ret); - if (call->state < RXRPC_CALL_COMPLETE) { - if (ret < 0) { + if (!__rxrpc_call_is_complete(call)) { + if (ret < 0) rxrpc_cancel_rtt_probe(call, serial, rtt_slot); - rxrpc_propose_ACK(call, pkt->ack.reason, - ntohl(pkt->ack.serial), - false, true, - rxrpc_propose_ack_retry_tx); - } - rxrpc_set_keepalive(call); } -out: - kfree(pkt); return ret; } @@ -299,7 +289,7 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) * channel instead, thereby closing off this call. */ if (rxrpc_is_client_call(call) && - test_bit(RXRPC_CALL_TX_LAST, &call->flags)) + test_bit(RXRPC_CALL_TX_ALL_ACKED, &call->flags)) return 0; if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) @@ -322,7 +312,7 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) pkt.whdr.userStatus = 0; pkt.whdr.securityIndex = call->security_ix; pkt.whdr._rsvd = 0; - pkt.whdr.serviceId = htons(call->service_id); + pkt.whdr.serviceId = htons(call->dest_srx.srx_service); pkt.abort_code = htonl(call->abort_code); iov[0].iov_base = &pkt; @@ -331,9 +321,9 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) serial = atomic_inc_return(&conn->serial); pkt.whdr.serial = htonl(serial); - ret = kernel_sendmsg(conn->params.local->socket, - &msg, iov, 1, sizeof(pkt)); - conn->params.peer->last_tx_at = ktime_get_seconds(); + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 1, sizeof(pkt)); + ret = do_udp_sendmsg(conn->local->socket, &msg, sizeof(pkt)); + conn->peer->last_tx_at = ktime_get_seconds(); if (ret < 0) trace_rxrpc_tx_fail(call->debug_id, serial, ret, rxrpc_tx_point_call_abort); @@ -347,50 +337,30 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) /* * send a packet through the transport endpoint */ -int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, - bool retrans) +int rxrpc_send_data_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) { + enum rxrpc_req_ack_trace why; struct rxrpc_connection *conn = call->conn; - struct rxrpc_wire_header whdr; - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct msghdr msg; - struct kvec iov[2]; + struct kvec iov[1]; rxrpc_serial_t serial; size_t len; int ret, rtt_slot = -1; - _enter(",{%d}", skb->len); - - if (hlist_unhashed(&call->error_link)) { - spin_lock_bh(&call->peer->lock); - hlist_add_head_rcu(&call->error_link, &call->peer->error_targets); - spin_unlock_bh(&call->peer->lock); - } + _enter("%x,{%d}", txb->seq, txb->len); /* Each transmission of a Tx packet needs a new serial number */ serial = atomic_inc_return(&conn->serial); - - whdr.epoch = htonl(conn->proto.epoch); - whdr.cid = htonl(call->cid); - whdr.callNumber = htonl(call->call_id); - whdr.seq = htonl(sp->hdr.seq); - whdr.serial = htonl(serial); - whdr.type = RXRPC_PACKET_TYPE_DATA; - whdr.flags = sp->hdr.flags; - whdr.userStatus = 0; - whdr.securityIndex = call->security_ix; - whdr._rsvd = htons(sp->hdr._rsvd); - whdr.serviceId = htons(call->service_id); + txb->wire.serial = htonl(serial); if (test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags) && - sp->hdr.seq == 1) - whdr.userStatus = RXRPC_USERSTATUS_SERVICE_UPGRADE; + txb->seq == 1) + txb->wire.userStatus = RXRPC_USERSTATUS_SERVICE_UPGRADE; - iov[0].iov_base = &whdr; - iov[0].iov_len = sizeof(whdr); - iov[1].iov_base = skb->head; - iov[1].iov_len = skb->len; - len = iov[0].iov_len + iov[1].iov_len; + iov[0].iov_base = &txb->wire; + iov[0].iov_len = sizeof(txb->wire) + txb->len; + len = iov[0].iov_len; + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 1, len); msg.msg_name = &call->peer->srx.transport; msg.msg_namelen = call->peer->srx.transport_len; @@ -405,41 +375,64 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, * service call, lest OpenAFS incorrectly send us an ACK with some * soft-ACKs in it and then never follow up with a proper hard ACK. */ - if ((!(sp->hdr.flags & RXRPC_LAST_PACKET) || - rxrpc_to_server(sp) - ) && - (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events) || - retrans || - call->cong_mode == RXRPC_CALL_SLOW_START || - (call->peer->rtt_count < 3 && sp->hdr.seq & 1) || - ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), - ktime_get_real()))) - whdr.flags |= RXRPC_REQUEST_ACK; + if (txb->wire.flags & RXRPC_REQUEST_ACK) + why = rxrpc_reqack_already_on; + else if (test_bit(RXRPC_TXBUF_LAST, &txb->flags) && rxrpc_sending_to_client(txb)) + why = rxrpc_reqack_no_srv_last; + else if (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events)) + why = rxrpc_reqack_ack_lost; + else if (test_bit(RXRPC_TXBUF_RESENT, &txb->flags)) + why = rxrpc_reqack_retrans; + else if (call->cong_mode == RXRPC_CALL_SLOW_START && call->cong_cwnd <= 2) + why = rxrpc_reqack_slow_start; + else if (call->tx_winsize <= 2) + why = rxrpc_reqack_small_txwin; + else if (call->peer->rtt_count < 3 && txb->seq & 1) + why = rxrpc_reqack_more_rtt; + else if (ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), ktime_get_real())) + why = rxrpc_reqack_old_rtt; + else + goto dont_set_request_ack; + + rxrpc_inc_stat(call->rxnet, stat_why_req_ack[why]); + trace_rxrpc_req_ack(call->debug_id, txb->seq, why); + if (why != rxrpc_reqack_no_srv_last) + txb->wire.flags |= RXRPC_REQUEST_ACK; +dont_set_request_ack: if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) { static int lose; if ((lose++ & 7) == 7) { ret = 0; - trace_rxrpc_tx_data(call, sp->hdr.seq, serial, - whdr.flags, retrans, true); + trace_rxrpc_tx_data(call, txb->seq, serial, + txb->wire.flags, + test_bit(RXRPC_TXBUF_RESENT, &txb->flags), + true); goto done; } } - trace_rxrpc_tx_data(call, sp->hdr.seq, serial, whdr.flags, retrans, - false); + trace_rxrpc_tx_data(call, txb->seq, serial, txb->wire.flags, + test_bit(RXRPC_TXBUF_RESENT, &txb->flags), false); + + /* Track what we've attempted to transmit at least once so that the + * retransmission algorithm doesn't try to resend what we haven't sent + * yet. However, this can race as we can receive an ACK before we get + * to this point. But, OTOH, if we won't get an ACK mentioning this + * packet unless the far side received it (though it could have + * discarded it anyway and NAK'd it). + */ + cmpxchg(&call->tx_transmitted, txb->seq - 1, txb->seq); /* send the packet with the don't fragment bit set if we currently * think it's small enough */ - if (iov[1].iov_len >= call->peer->maxdata) + if (txb->len >= call->peer->maxdata) goto send_fragmentable; - down_read(&conn->params.local->defrag_sem); + down_read(&conn->local->defrag_sem); - sp->hdr.serial = serial; - smp_wmb(); /* Set serial before timestamp */ - skb->tstamp = ktime_get_real(); - if (whdr.flags & RXRPC_REQUEST_ACK) + txb->last_sent = ktime_get_real(); + if (txb->wire.flags & RXRPC_REQUEST_ACK) rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_data); /* send the packet by UDP @@ -448,16 +441,18 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, * - in which case, we'll have processed the ICMP error * message and update the peer record */ - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); - conn->params.peer->last_tx_at = ktime_get_seconds(); + rxrpc_inc_stat(call->rxnet, stat_tx_data_send); + ret = do_udp_sendmsg(conn->local->socket, &msg, len); + conn->peer->last_tx_at = ktime_get_seconds(); - up_read(&conn->params.local->defrag_sem); + up_read(&conn->local->defrag_sem); if (ret < 0) { + rxrpc_inc_stat(call->rxnet, stat_tx_data_send_fail); rxrpc_cancel_rtt_probe(call, serial, rtt_slot); trace_rxrpc_tx_fail(call->debug_id, serial, ret, rxrpc_tx_point_call_data_nofrag); } else { - trace_rxrpc_tx_packet(call->debug_id, &whdr, + trace_rxrpc_tx_packet(call->debug_id, &txb->wire, rxrpc_tx_point_call_data_nofrag); } @@ -467,8 +462,9 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, done: if (ret >= 0) { - if (whdr.flags & RXRPC_REQUEST_ACK) { - call->peer->rtt_last_req = skb->tstamp; + call->tx_last_sent = txb->last_sent; + if (txb->wire.flags & RXRPC_REQUEST_ACK) { + call->peer->rtt_last_req = txb->last_sent; if (call->peer->rtt_count > 1) { unsigned long nowj = jiffies, ack_lost_at; @@ -480,7 +476,7 @@ done: } } - if (sp->hdr.seq == 1 && + if (txb->seq == 1 && !test_and_set_bit(RXRPC_CALL_BEGAN_RX_TIMER, &call->flags)) { unsigned long nowj = jiffies, expect_rx_by; @@ -510,25 +506,23 @@ send_fragmentable: /* attempt to send this message with fragmentation enabled */ _debug("send fragment"); - down_write(&conn->params.local->defrag_sem); + down_write(&conn->local->defrag_sem); - sp->hdr.serial = serial; - smp_wmb(); /* Set serial before timestamp */ - skb->tstamp = ktime_get_real(); - if (whdr.flags & RXRPC_REQUEST_ACK) + txb->last_sent = ktime_get_real(); + if (txb->wire.flags & RXRPC_REQUEST_ACK) rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_data); - switch (conn->params.local->srx.transport.family) { + switch (conn->local->srx.transport.family) { case AF_INET6: case AF_INET: - ip_sock_set_mtu_discover(conn->params.local->socket->sk, - IP_PMTUDISC_DONT); - ret = kernel_sendmsg(conn->params.local->socket, &msg, - iov, 2, len); - conn->params.peer->last_tx_at = ktime_get_seconds(); - - ip_sock_set_mtu_discover(conn->params.local->socket->sk, - IP_PMTUDISC_DO); + ip_sock_set_mtu_discover(conn->local->socket->sk, + IP_PMTUDISC_DONT); + rxrpc_inc_stat(call->rxnet, stat_tx_data_send_frag); + ret = do_udp_sendmsg(conn->local->socket, &msg, len); + conn->peer->last_tx_at = ktime_get_seconds(); + + ip_sock_set_mtu_discover(conn->local->socket->sk, + IP_PMTUDISC_DO); break; default: @@ -536,35 +530,91 @@ send_fragmentable: } if (ret < 0) { + rxrpc_inc_stat(call->rxnet, stat_tx_data_send_fail); rxrpc_cancel_rtt_probe(call, serial, rtt_slot); trace_rxrpc_tx_fail(call->debug_id, serial, ret, rxrpc_tx_point_call_data_frag); } else { - trace_rxrpc_tx_packet(call->debug_id, &whdr, + trace_rxrpc_tx_packet(call->debug_id, &txb->wire, rxrpc_tx_point_call_data_frag); } rxrpc_tx_backoff(call, ret); - up_write(&conn->params.local->defrag_sem); + up_write(&conn->local->defrag_sem); goto done; } /* - * reject packets through the local endpoint + * Transmit a connection-level abort. */ -void rxrpc_reject_packets(struct rxrpc_local *local) +void rxrpc_send_conn_abort(struct rxrpc_connection *conn) { - struct sockaddr_rxrpc srx; - struct rxrpc_skb_priv *sp; struct rxrpc_wire_header whdr; - struct sk_buff *skb; + struct msghdr msg; + struct kvec iov[2]; + __be32 word; + size_t len; + u32 serial; + int ret; + + msg.msg_name = &conn->peer->srx.transport; + msg.msg_namelen = conn->peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + whdr.epoch = htonl(conn->proto.epoch); + whdr.cid = htonl(conn->proto.cid); + whdr.callNumber = 0; + whdr.seq = 0; + whdr.type = RXRPC_PACKET_TYPE_ABORT; + whdr.flags = conn->out_clientflag; + whdr.userStatus = 0; + whdr.securityIndex = conn->security_ix; + whdr._rsvd = 0; + whdr.serviceId = htons(conn->service_id); + + word = htonl(conn->abort_code); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = &word; + iov[1].iov_len = sizeof(word); + + len = iov[0].iov_len + iov[1].iov_len; + + serial = atomic_inc_return(&conn->serial); + whdr.serial = htonl(serial); + + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 2, len); + ret = do_udp_sendmsg(conn->local->socket, &msg, len); + if (ret < 0) { + trace_rxrpc_tx_fail(conn->debug_id, serial, ret, + rxrpc_tx_point_conn_abort); + _debug("sendmsg failed: %d", ret); + return; + } + + trace_rxrpc_tx_packet(conn->debug_id, &whdr, rxrpc_tx_point_conn_abort); + + conn->peer->last_tx_at = ktime_get_seconds(); +} + +/* + * Reject a packet through the local endpoint. + */ +void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb) +{ + struct rxrpc_wire_header whdr; + struct sockaddr_rxrpc srx; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct msghdr msg; struct kvec iov[2]; size_t size; __be32 code; int ret, ioc; - _enter("%d", local->debug_id); + rxrpc_see_skb(skb, rxrpc_skb_see_reject); iov[0].iov_base = &whdr; iov[0].iov_len = sizeof(whdr); @@ -578,52 +628,42 @@ void rxrpc_reject_packets(struct rxrpc_local *local) memset(&whdr, 0, sizeof(whdr)); - while ((skb = skb_dequeue(&local->reject_queue))) { - rxrpc_see_skb(skb, rxrpc_skb_seen); - sp = rxrpc_skb(skb); - - switch (skb->mark) { - case RXRPC_SKB_MARK_REJECT_BUSY: - whdr.type = RXRPC_PACKET_TYPE_BUSY; - size = sizeof(whdr); - ioc = 1; - break; - case RXRPC_SKB_MARK_REJECT_ABORT: - whdr.type = RXRPC_PACKET_TYPE_ABORT; - code = htonl(skb->priority); - size = sizeof(whdr) + sizeof(code); - ioc = 2; - break; - default: - rxrpc_free_skb(skb, rxrpc_skb_freed); - continue; - } - - if (rxrpc_extract_addr_from_skb(&srx, skb) == 0) { - msg.msg_namelen = srx.transport_len; - - whdr.epoch = htonl(sp->hdr.epoch); - whdr.cid = htonl(sp->hdr.cid); - whdr.callNumber = htonl(sp->hdr.callNumber); - whdr.serviceId = htons(sp->hdr.serviceId); - whdr.flags = sp->hdr.flags; - whdr.flags ^= RXRPC_CLIENT_INITIATED; - whdr.flags &= RXRPC_CLIENT_INITIATED; - - ret = kernel_sendmsg(local->socket, &msg, - iov, ioc, size); - if (ret < 0) - trace_rxrpc_tx_fail(local->debug_id, 0, ret, - rxrpc_tx_point_reject); - else - trace_rxrpc_tx_packet(local->debug_id, &whdr, - rxrpc_tx_point_reject); - } - - rxrpc_free_skb(skb, rxrpc_skb_freed); + switch (skb->mark) { + case RXRPC_SKB_MARK_REJECT_BUSY: + whdr.type = RXRPC_PACKET_TYPE_BUSY; + size = sizeof(whdr); + ioc = 1; + break; + case RXRPC_SKB_MARK_REJECT_ABORT: + whdr.type = RXRPC_PACKET_TYPE_ABORT; + code = htonl(skb->priority); + size = sizeof(whdr) + sizeof(code); + ioc = 2; + break; + default: + return; } - _leave(""); + if (rxrpc_extract_addr_from_skb(&srx, skb) == 0) { + msg.msg_namelen = srx.transport_len; + + whdr.epoch = htonl(sp->hdr.epoch); + whdr.cid = htonl(sp->hdr.cid); + whdr.callNumber = htonl(sp->hdr.callNumber); + whdr.serviceId = htons(sp->hdr.serviceId); + whdr.flags = sp->hdr.flags; + whdr.flags ^= RXRPC_CLIENT_INITIATED; + whdr.flags &= RXRPC_CLIENT_INITIATED; + + iov_iter_kvec(&msg.msg_iter, WRITE, iov, ioc, size); + ret = do_udp_sendmsg(local->socket, &msg, size); + if (ret < 0) + trace_rxrpc_tx_fail(local->debug_id, 0, ret, + rxrpc_tx_point_reject); + else + trace_rxrpc_tx_packet(local->debug_id, &whdr, + rxrpc_tx_point_reject); + } } /* @@ -664,9 +704,8 @@ void rxrpc_send_keepalive(struct rxrpc_peer *peer) len = iov[0].iov_len + iov[1].iov_len; - _proto("Tx VERSION (keepalive)"); - - ret = kernel_sendmsg(peer->local->socket, &msg, iov, 2, len); + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 2, len); + ret = do_udp_sendmsg(peer->local->socket, &msg, len); if (ret < 0) trace_rxrpc_tx_fail(peer->debug_id, 0, ret, rxrpc_tx_point_version_keepalive); @@ -677,3 +716,43 @@ void rxrpc_send_keepalive(struct rxrpc_peer *peer) peer->last_tx_at = ktime_get_seconds(); _leave(""); } + +/* + * Schedule an instant Tx resend. + */ +static inline void rxrpc_instant_resend(struct rxrpc_call *call, + struct rxrpc_txbuf *txb) +{ + if (!__rxrpc_call_is_complete(call)) + kdebug("resend"); +} + +/* + * Transmit one packet. + */ +void rxrpc_transmit_one(struct rxrpc_call *call, struct rxrpc_txbuf *txb) +{ + int ret; + + ret = rxrpc_send_data_packet(call, txb); + if (ret < 0) { + switch (ret) { + case -ENETUNREACH: + case -EHOSTUNREACH: + case -ECONNREFUSED: + rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, + 0, ret); + break; + default: + _debug("need instant resend %d", ret); + rxrpc_instant_resend(call, txb); + } + } else { + unsigned long now = jiffies; + unsigned long resend_at = now + call->peer->rto_j; + + WRITE_ONCE(call->resend_at, resend_at); + rxrpc_reduce_call_timer(call, resend_at, now, + rxrpc_timer_set_for_send); + } +} diff --git a/net/rxrpc/peer_event.c b/net/rxrpc/peer_event.c index 32561e9567fe..552ba84a255c 100644 --- a/net/rxrpc/peer_event.c +++ b/net/rxrpc/peer_event.c @@ -16,256 +16,11 @@ #include <net/sock.h> #include <net/af_rxrpc.h> #include <net/ip.h> -#include <net/icmp.h> #include "ar-internal.h" -static void rxrpc_adjust_mtu(struct rxrpc_peer *, unsigned int); -static void rxrpc_store_error(struct rxrpc_peer *, struct sock_exterr_skb *); -static void rxrpc_distribute_error(struct rxrpc_peer *, int, - enum rxrpc_call_completion); - -/* - * Find the peer associated with an ICMPv4 packet. - */ -static struct rxrpc_peer *rxrpc_lookup_peer_icmp_rcu(struct rxrpc_local *local, - struct sk_buff *skb, - unsigned int udp_offset, - unsigned int *info, - struct sockaddr_rxrpc *srx) -{ - struct iphdr *ip, *ip0 = ip_hdr(skb); - struct icmphdr *icmp = icmp_hdr(skb); - struct udphdr *udp = (struct udphdr *)(skb->data + udp_offset); - - _enter("%u,%u,%u", ip0->protocol, icmp->type, icmp->code); - - switch (icmp->type) { - case ICMP_DEST_UNREACH: - *info = ntohs(icmp->un.frag.mtu); - fallthrough; - case ICMP_TIME_EXCEEDED: - case ICMP_PARAMETERPROB: - ip = (struct iphdr *)((void *)icmp + 8); - break; - default: - return NULL; - } - - memset(srx, 0, sizeof(*srx)); - srx->transport_type = local->srx.transport_type; - srx->transport_len = local->srx.transport_len; - srx->transport.family = local->srx.transport.family; - - /* Can we see an ICMP4 packet on an ICMP6 listening socket? and vice - * versa? - */ - switch (srx->transport.family) { - case AF_INET: - srx->transport_len = sizeof(srx->transport.sin); - srx->transport.family = AF_INET; - srx->transport.sin.sin_port = udp->dest; - memcpy(&srx->transport.sin.sin_addr, &ip->daddr, - sizeof(struct in_addr)); - break; - -#ifdef CONFIG_AF_RXRPC_IPV6 - case AF_INET6: - srx->transport_len = sizeof(srx->transport.sin); - srx->transport.family = AF_INET; - srx->transport.sin.sin_port = udp->dest; - memcpy(&srx->transport.sin.sin_addr, &ip->daddr, - sizeof(struct in_addr)); - break; -#endif - - default: - WARN_ON_ONCE(1); - return NULL; - } - - _net("ICMP {%pISp}", &srx->transport); - return rxrpc_lookup_peer_rcu(local, srx); -} - -#ifdef CONFIG_AF_RXRPC_IPV6 -/* - * Find the peer associated with an ICMPv6 packet. - */ -static struct rxrpc_peer *rxrpc_lookup_peer_icmp6_rcu(struct rxrpc_local *local, - struct sk_buff *skb, - unsigned int udp_offset, - unsigned int *info, - struct sockaddr_rxrpc *srx) -{ - struct icmp6hdr *icmp = icmp6_hdr(skb); - struct ipv6hdr *ip, *ip0 = ipv6_hdr(skb); - struct udphdr *udp = (struct udphdr *)(skb->data + udp_offset); - - _enter("%u,%u,%u", ip0->nexthdr, icmp->icmp6_type, icmp->icmp6_code); - - switch (icmp->icmp6_type) { - case ICMPV6_DEST_UNREACH: - *info = ntohl(icmp->icmp6_mtu); - fallthrough; - case ICMPV6_PKT_TOOBIG: - case ICMPV6_TIME_EXCEED: - case ICMPV6_PARAMPROB: - ip = (struct ipv6hdr *)((void *)icmp + 8); - break; - default: - return NULL; - } - - memset(srx, 0, sizeof(*srx)); - srx->transport_type = local->srx.transport_type; - srx->transport_len = local->srx.transport_len; - srx->transport.family = local->srx.transport.family; - - /* Can we see an ICMP4 packet on an ICMP6 listening socket? and vice - * versa? - */ - switch (srx->transport.family) { - case AF_INET: - _net("Rx ICMP6 on v4 sock"); - srx->transport_len = sizeof(srx->transport.sin); - srx->transport.family = AF_INET; - srx->transport.sin.sin_port = udp->dest; - memcpy(&srx->transport.sin.sin_addr, - &ip->daddr.s6_addr32[3], sizeof(struct in_addr)); - break; - case AF_INET6: - _net("Rx ICMP6"); - srx->transport.sin.sin_port = udp->dest; - memcpy(&srx->transport.sin6.sin6_addr, &ip->daddr, - sizeof(struct in6_addr)); - break; - default: - WARN_ON_ONCE(1); - return NULL; - } - - _net("ICMP {%pISp}", &srx->transport); - return rxrpc_lookup_peer_rcu(local, srx); -} -#endif /* CONFIG_AF_RXRPC_IPV6 */ - -/* - * Handle an error received on the local endpoint as a tunnel. - */ -void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, - unsigned int udp_offset) -{ - struct sock_extended_err ee; - struct sockaddr_rxrpc srx; - struct rxrpc_local *local; - struct rxrpc_peer *peer; - unsigned int info = 0; - int err; - u8 version = ip_hdr(skb)->version; - u8 type = icmp_hdr(skb)->type; - u8 code = icmp_hdr(skb)->code; - - rcu_read_lock(); - local = rcu_dereference_sk_user_data(sk); - if (unlikely(!local)) { - rcu_read_unlock(); - return; - } - - rxrpc_new_skb(skb, rxrpc_skb_received); - - switch (ip_hdr(skb)->version) { - case IPVERSION: - peer = rxrpc_lookup_peer_icmp_rcu(local, skb, udp_offset, - &info, &srx); - break; -#ifdef CONFIG_AF_RXRPC_IPV6 - case 6: - peer = rxrpc_lookup_peer_icmp6_rcu(local, skb, udp_offset, - &info, &srx); - break; -#endif - default: - rcu_read_unlock(); - return; - } - - if (peer && !rxrpc_get_peer_maybe(peer)) - peer = NULL; - if (!peer) { - rcu_read_unlock(); - return; - } - - memset(&ee, 0, sizeof(ee)); - - switch (version) { - case IPVERSION: - switch (type) { - case ICMP_DEST_UNREACH: - switch (code) { - case ICMP_FRAG_NEEDED: - rxrpc_adjust_mtu(peer, info); - rcu_read_unlock(); - rxrpc_put_peer(peer); - return; - default: - break; - } - - err = EHOSTUNREACH; - if (code <= NR_ICMP_UNREACH) { - /* Might want to do something different with - * non-fatal errors - */ - //harderr = icmp_err_convert[code].fatal; - err = icmp_err_convert[code].errno; - } - break; - - case ICMP_TIME_EXCEEDED: - err = EHOSTUNREACH; - break; - default: - err = EPROTO; - break; - } - - ee.ee_origin = SO_EE_ORIGIN_ICMP; - ee.ee_type = type; - ee.ee_code = code; - ee.ee_errno = err; - break; - -#ifdef CONFIG_AF_RXRPC_IPV6 - case 6: - switch (type) { - case ICMPV6_PKT_TOOBIG: - rxrpc_adjust_mtu(peer, info); - rcu_read_unlock(); - rxrpc_put_peer(peer); - return; - } - - icmpv6_err_convert(type, code, &err); - - if (err == EACCES) - err = EHOSTUNREACH; - - ee.ee_origin = SO_EE_ORIGIN_ICMP6; - ee.ee_type = type; - ee.ee_code = code; - ee.ee_errno = err; - break; -#endif - } - - trace_rxrpc_rx_icmp(peer, &ee, &srx); - - rxrpc_distribute_error(peer, err, RXRPC_CALL_NETWORK_ERROR); - rcu_read_unlock(); - rxrpc_put_peer(peer); -} +static void rxrpc_store_error(struct rxrpc_peer *, struct sk_buff *); +static void rxrpc_distribute_error(struct rxrpc_peer *, struct sk_buff *, + enum rxrpc_call_completion, int); /* * Find the peer associated with a local error. @@ -283,6 +38,9 @@ static struct rxrpc_peer *rxrpc_lookup_peer_local_rcu(struct rxrpc_local *local, srx->transport_len = local->srx.transport_len; srx->transport.family = local->srx.transport.family; + /* Can we see an ICMP4 packet on an ICMP6 listening socket? and vice + * versa? + */ switch (srx->transport.family) { case AF_INET: srx->transport_len = sizeof(srx->transport.sin); @@ -290,13 +48,11 @@ static struct rxrpc_peer *rxrpc_lookup_peer_local_rcu(struct rxrpc_local *local, srx->transport.sin.sin_port = serr->port; switch (serr->ee.ee_origin) { case SO_EE_ORIGIN_ICMP: - _net("Rx ICMP"); memcpy(&srx->transport.sin.sin_addr, skb_network_header(skb) + serr->addr_offset, sizeof(struct in_addr)); break; case SO_EE_ORIGIN_ICMP6: - _net("Rx ICMP6 on v4 sock"); memcpy(&srx->transport.sin.sin_addr, skb_network_header(skb) + serr->addr_offset + 12, sizeof(struct in_addr)); @@ -312,14 +68,12 @@ static struct rxrpc_peer *rxrpc_lookup_peer_local_rcu(struct rxrpc_local *local, case AF_INET6: switch (serr->ee.ee_origin) { case SO_EE_ORIGIN_ICMP6: - _net("Rx ICMP6"); srx->transport.sin6.sin6_port = serr->port; memcpy(&srx->transport.sin6.sin6_addr, skb_network_header(skb) + serr->addr_offset, sizeof(struct in6_addr)); break; case SO_EE_ORIGIN_ICMP: - _net("Rx ICMP on v6 sock"); srx->transport_len = sizeof(srx->transport.sin); srx->transport.family = AF_INET; srx->transport.sin.sin_port = serr->port; @@ -348,13 +102,9 @@ static struct rxrpc_peer *rxrpc_lookup_peer_local_rcu(struct rxrpc_local *local, */ static void rxrpc_adjust_mtu(struct rxrpc_peer *peer, unsigned int mtu) { - _net("Rx ICMP Fragmentation Needed (%d)", mtu); - /* wind down the local interface MTU */ - if (mtu > 0 && peer->if_mtu == 65535 && mtu < peer->if_mtu) { + if (mtu > 0 && peer->if_mtu == 65535 && mtu < peer->if_mtu) peer->if_mtu = mtu; - _net("I/F MTU %u", mtu); - } if (mtu == 0) { /* they didn't give us a size, estimate one */ @@ -371,121 +121,66 @@ static void rxrpc_adjust_mtu(struct rxrpc_peer *peer, unsigned int mtu) } if (mtu < peer->mtu) { - spin_lock_bh(&peer->lock); + spin_lock(&peer->lock); peer->mtu = mtu; peer->maxdata = peer->mtu - peer->hdrsize; - spin_unlock_bh(&peer->lock); - _net("Net MTU %u (maxdata %u)", - peer->mtu, peer->maxdata); + spin_unlock(&peer->lock); } } /* * Handle an error received on the local endpoint. */ -void rxrpc_error_report(struct sock *sk) +void rxrpc_input_error(struct rxrpc_local *local, struct sk_buff *skb) { - struct sock_exterr_skb *serr; + struct sock_exterr_skb *serr = SKB_EXT_ERR(skb); struct sockaddr_rxrpc srx; - struct rxrpc_local *local; struct rxrpc_peer *peer = NULL; - struct sk_buff *skb; - rcu_read_lock(); - local = rcu_dereference_sk_user_data(sk); - if (unlikely(!local)) { - rcu_read_unlock(); + _enter("L=%x", local->debug_id); + + if (!skb->len && serr->ee.ee_origin == SO_EE_ORIGIN_TIMESTAMPING) { + _leave("UDP empty message"); return; } - _enter("%p{%d}", sk, local->debug_id); - /* Clear the outstanding error value on the socket so that it doesn't - * cause kernel_sendmsg() to return it later. - */ - sock_error(sk); - - skb = sock_dequeue_err_skb(sk); - if (!skb) { - rcu_read_unlock(); - _leave("UDP socket errqueue empty"); + rcu_read_lock(); + peer = rxrpc_lookup_peer_local_rcu(local, skb, &srx); + if (peer && !rxrpc_get_peer_maybe(peer, rxrpc_peer_get_input_error)) + peer = NULL; + rcu_read_unlock(); + if (!peer) return; - } - rxrpc_new_skb(skb, rxrpc_skb_received); - serr = SKB_EXT_ERR(skb); - - if (serr->ee.ee_origin == SO_EE_ORIGIN_LOCAL) { - peer = rxrpc_lookup_peer_local_rcu(local, skb, &srx); - if (peer && !rxrpc_get_peer_maybe(peer)) - peer = NULL; - if (peer) { - trace_rxrpc_rx_icmp(peer, &serr->ee, &srx); - rxrpc_store_error(peer, serr); - } + + trace_rxrpc_rx_icmp(peer, &serr->ee, &srx); + + if ((serr->ee.ee_origin == SO_EE_ORIGIN_ICMP && + serr->ee.ee_type == ICMP_DEST_UNREACH && + serr->ee.ee_code == ICMP_FRAG_NEEDED)) { + rxrpc_adjust_mtu(peer, serr->ee.ee_info); + goto out; } - rcu_read_unlock(); - rxrpc_free_skb(skb, rxrpc_skb_freed); - rxrpc_put_peer(peer); - _leave(""); + rxrpc_store_error(peer, skb); +out: + rxrpc_put_peer(peer, rxrpc_peer_put_input_error); } /* * Map an error report to error codes on the peer record. */ -static void rxrpc_store_error(struct rxrpc_peer *peer, - struct sock_exterr_skb *serr) +static void rxrpc_store_error(struct rxrpc_peer *peer, struct sk_buff *skb) { enum rxrpc_call_completion compl = RXRPC_CALL_NETWORK_ERROR; - struct sock_extended_err *ee; - int err; + struct sock_exterr_skb *serr = SKB_EXT_ERR(skb); + struct sock_extended_err *ee = &serr->ee; + int err = ee->ee_errno; _enter(""); - ee = &serr->ee; - - err = ee->ee_errno; - switch (ee->ee_origin) { - case SO_EE_ORIGIN_ICMP: - switch (ee->ee_type) { - case ICMP_DEST_UNREACH: - switch (ee->ee_code) { - case ICMP_NET_UNREACH: - _net("Rx Received ICMP Network Unreachable"); - break; - case ICMP_HOST_UNREACH: - _net("Rx Received ICMP Host Unreachable"); - break; - case ICMP_PORT_UNREACH: - _net("Rx Received ICMP Port Unreachable"); - break; - case ICMP_NET_UNKNOWN: - _net("Rx Received ICMP Unknown Network"); - break; - case ICMP_HOST_UNKNOWN: - _net("Rx Received ICMP Unknown Host"); - break; - default: - _net("Rx Received ICMP DestUnreach code=%u", - ee->ee_code); - break; - } - break; - - case ICMP_TIME_EXCEEDED: - _net("Rx Received ICMP TTL Exceeded"); - break; - - default: - _proto("Rx Received ICMP error { type=%u code=%u }", - ee->ee_type, ee->ee_code); - break; - } - break; - case SO_EE_ORIGIN_NONE: case SO_EE_ORIGIN_LOCAL: - _proto("Rx Received local error { error=%d }", err); compl = RXRPC_CALL_LOCAL_ERROR; break; @@ -493,26 +188,40 @@ static void rxrpc_store_error(struct rxrpc_peer *peer, if (err == EACCES) err = EHOSTUNREACH; fallthrough; + case SO_EE_ORIGIN_ICMP: default: - _proto("Rx Received error report { orig=%u }", ee->ee_origin); break; } - rxrpc_distribute_error(peer, err, compl); + rxrpc_distribute_error(peer, skb, compl, err); } /* * Distribute an error that occurred on a peer. */ -static void rxrpc_distribute_error(struct rxrpc_peer *peer, int error, - enum rxrpc_call_completion compl) +static void rxrpc_distribute_error(struct rxrpc_peer *peer, struct sk_buff *skb, + enum rxrpc_call_completion compl, int err) { struct rxrpc_call *call; + HLIST_HEAD(error_targets); + + spin_lock(&peer->lock); + hlist_move_list(&peer->error_targets, &error_targets); + + while (!hlist_empty(&error_targets)) { + call = hlist_entry(error_targets.first, + struct rxrpc_call, error_link); + hlist_del_init(&call->error_link); + spin_unlock(&peer->lock); - hlist_for_each_entry_rcu(call, &peer->error_targets, error_link) { - rxrpc_see_call(call); - rxrpc_set_call_completion(call, compl, 0, -error); + rxrpc_see_call(call, rxrpc_call_see_distribute_error); + rxrpc_set_call_completion(call, compl, 0, -err); + rxrpc_input_call_event(call, skb); + + spin_lock(&peer->lock); } + + spin_unlock(&peer->lock); } /* @@ -526,21 +235,23 @@ static void rxrpc_peer_keepalive_dispatch(struct rxrpc_net *rxnet, struct rxrpc_peer *peer; const u8 mask = ARRAY_SIZE(rxnet->peer_keepalive) - 1; time64_t keepalive_at; + bool use; int slot; - spin_lock_bh(&rxnet->peer_hash_lock); + spin_lock(&rxnet->peer_hash_lock); while (!list_empty(collector)) { peer = list_entry(collector->next, struct rxrpc_peer, keepalive_link); list_del_init(&peer->keepalive_link); - if (!rxrpc_get_peer_maybe(peer)) + if (!rxrpc_get_peer_maybe(peer, rxrpc_peer_get_keepalive)) continue; - if (__rxrpc_use_local(peer->local)) { - spin_unlock_bh(&rxnet->peer_hash_lock); + use = __rxrpc_use_local(peer->local, rxrpc_local_use_peer_keepalive); + spin_unlock(&rxnet->peer_hash_lock); + if (use) { keepalive_at = peer->last_tx_at + RXRPC_KEEPALIVE_TIME; slot = keepalive_at - base; _debug("%02x peer %u t=%d {%pISp}", @@ -558,15 +269,17 @@ static void rxrpc_peer_keepalive_dispatch(struct rxrpc_net *rxnet, */ slot += cursor; slot &= mask; - spin_lock_bh(&rxnet->peer_hash_lock); + spin_lock(&rxnet->peer_hash_lock); list_add_tail(&peer->keepalive_link, &rxnet->peer_keepalive[slot & mask]); - rxrpc_unuse_local(peer->local); + spin_unlock(&rxnet->peer_hash_lock); + rxrpc_unuse_local(peer->local, rxrpc_local_unuse_peer_keepalive); } - rxrpc_put_peer_locked(peer); + rxrpc_put_peer(peer, rxrpc_peer_put_keepalive); + spin_lock(&rxnet->peer_hash_lock); } - spin_unlock_bh(&rxnet->peer_hash_lock); + spin_unlock(&rxnet->peer_hash_lock); } /* @@ -596,7 +309,7 @@ void rxrpc_peer_keepalive_worker(struct work_struct *work) * second; the bucket at cursor + 1 goes at now + 1s and so * on... */ - spin_lock_bh(&rxnet->peer_hash_lock); + spin_lock(&rxnet->peer_hash_lock); list_splice_init(&rxnet->peer_keepalive_new, &collector); stop = cursor + ARRAY_SIZE(rxnet->peer_keepalive); @@ -608,7 +321,7 @@ void rxrpc_peer_keepalive_worker(struct work_struct *work) } base = now; - spin_unlock_bh(&rxnet->peer_hash_lock); + spin_unlock(&rxnet->peer_hash_lock); rxnet->peer_keepalive_base = base; rxnet->peer_keepalive_cursor = cursor; diff --git a/net/rxrpc/peer_object.c b/net/rxrpc/peer_object.c index 26d2ae9baaf2..8d7a715a0bb1 100644 --- a/net/rxrpc/peer_object.c +++ b/net/rxrpc/peer_object.c @@ -138,10 +138,8 @@ struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *local, unsigned long hash_key = rxrpc_peer_hash_key(local, srx); peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); - if (peer) { - _net("PEER %d {%pISp}", peer->debug_id, &peer->srx.transport); + if (peer) _leave(" = %p {u=%d}", peer, refcount_read(&peer->ref)); - } return peer; } @@ -149,10 +147,10 @@ struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *local, * assess the MTU size for the network interface through which this peer is * reached */ -static void rxrpc_assess_MTU_size(struct rxrpc_sock *rx, +static void rxrpc_assess_MTU_size(struct rxrpc_local *local, struct rxrpc_peer *peer) { - struct net *net = sock_net(&rx->sk); + struct net *net = local->net; struct dst_entry *dst; struct rtable *rt; struct flowi fl; @@ -207,9 +205,9 @@ static void rxrpc_assess_MTU_size(struct rxrpc_sock *rx, /* * Allocate a peer. */ -struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) +struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp, + enum rxrpc_peer_trace why) { - const void *here = __builtin_return_address(0); struct rxrpc_peer *peer; _enter(""); @@ -217,7 +215,7 @@ struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) peer = kzalloc(sizeof(struct rxrpc_peer), gfp); if (peer) { refcount_set(&peer->ref, 1); - peer->local = rxrpc_get_local(local); + peer->local = rxrpc_get_local(local, rxrpc_local_get_peer); INIT_HLIST_HEAD(&peer->error_targets); peer->service_conns = RB_ROOT; seqlock_init(&peer->service_conn_lock); @@ -227,13 +225,8 @@ struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) rxrpc_peer_init_rtt(peer); - if (RXRPC_TX_SMSS > 2190) - peer->cong_cwnd = 2; - else if (RXRPC_TX_SMSS > 1095) - peer->cong_cwnd = 3; - else - peer->cong_cwnd = 4; - trace_rxrpc_peer(peer->debug_id, rxrpc_peer_new, 1, here); + peer->cong_ssthresh = RXRPC_TX_MAX_WINDOW; + trace_rxrpc_peer(peer->debug_id, 1, why); } _leave(" = %p", peer); @@ -243,11 +236,11 @@ struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) /* * Initialise peer record. */ -static void rxrpc_init_peer(struct rxrpc_sock *rx, struct rxrpc_peer *peer, +static void rxrpc_init_peer(struct rxrpc_local *local, struct rxrpc_peer *peer, unsigned long hash_key) { peer->hash_key = hash_key; - rxrpc_assess_MTU_size(rx, peer); + rxrpc_assess_MTU_size(local, peer); peer->mtu = peer->if_mtu; peer->rtt_last_req = ktime_get_real(); @@ -279,8 +272,7 @@ static void rxrpc_init_peer(struct rxrpc_sock *rx, struct rxrpc_peer *peer, /* * Set up a new peer. */ -static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_sock *rx, - struct rxrpc_local *local, +static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_local *local, struct sockaddr_rxrpc *srx, unsigned long hash_key, gfp_t gfp) @@ -289,10 +281,10 @@ static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_sock *rx, _enter(""); - peer = rxrpc_alloc_peer(local, gfp); + peer = rxrpc_alloc_peer(local, gfp, rxrpc_peer_new_client); if (peer) { memcpy(&peer->srx, srx, sizeof(*srx)); - rxrpc_init_peer(rx, peer, hash_key); + rxrpc_init_peer(local, peer, hash_key); } _leave(" = %p", peer); @@ -301,7 +293,8 @@ static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_sock *rx, static void rxrpc_free_peer(struct rxrpc_peer *peer) { - rxrpc_put_local(peer->local); + trace_rxrpc_peer(peer->debug_id, 0, rxrpc_peer_free); + rxrpc_put_local(peer->local, rxrpc_local_put_peer); kfree_rcu(peer, rcu); } @@ -310,14 +303,13 @@ static void rxrpc_free_peer(struct rxrpc_peer *peer) * since we've already done a search in the list from the non-reentrant context * (the data_ready handler) that is the only place we can add new peers. */ -void rxrpc_new_incoming_peer(struct rxrpc_sock *rx, struct rxrpc_local *local, - struct rxrpc_peer *peer) +void rxrpc_new_incoming_peer(struct rxrpc_local *local, struct rxrpc_peer *peer) { struct rxrpc_net *rxnet = local->rxnet; unsigned long hash_key; hash_key = rxrpc_peer_hash_key(local, &peer->srx); - rxrpc_init_peer(rx, peer, hash_key); + rxrpc_init_peer(local, peer, hash_key); spin_lock(&rxnet->peer_hash_lock); hash_add_rcu(rxnet->peer_hash, &peer->hash_link, hash_key); @@ -328,8 +320,7 @@ void rxrpc_new_incoming_peer(struct rxrpc_sock *rx, struct rxrpc_local *local, /* * obtain a remote transport endpoint for the specified address */ -struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, - struct rxrpc_local *local, +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_local *local, struct sockaddr_rxrpc *srx, gfp_t gfp) { struct rxrpc_peer *peer, *candidate; @@ -341,7 +332,7 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, /* search the peer list first */ rcu_read_lock(); peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); - if (peer && !rxrpc_get_peer_maybe(peer)) + if (peer && !rxrpc_get_peer_maybe(peer, rxrpc_peer_get_lookup_client)) peer = NULL; rcu_read_unlock(); @@ -349,17 +340,17 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, /* The peer is not yet present in hash - create a candidate * for a new record and then redo the search. */ - candidate = rxrpc_create_peer(rx, local, srx, hash_key, gfp); + candidate = rxrpc_create_peer(local, srx, hash_key, gfp); if (!candidate) { _leave(" = NULL [nomem]"); return NULL; } - spin_lock_bh(&rxnet->peer_hash_lock); + spin_lock(&rxnet->peer_hash_lock); /* Need to check that we aren't racing with someone else */ peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); - if (peer && !rxrpc_get_peer_maybe(peer)) + if (peer && !rxrpc_get_peer_maybe(peer, rxrpc_peer_get_lookup_client)) peer = NULL; if (!peer) { hash_add_rcu(rxnet->peer_hash, @@ -368,7 +359,7 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, &rxnet->peer_keepalive_new); } - spin_unlock_bh(&rxnet->peer_hash_lock); + spin_unlock(&rxnet->peer_hash_lock); if (peer) rxrpc_free_peer(candidate); @@ -376,8 +367,6 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, peer = candidate; } - _net("PEER %d {%pISp}", peer->debug_id, &peer->srx.transport); - _leave(" = %p {u=%d}", peer, refcount_read(&peer->ref)); return peer; } @@ -385,27 +374,26 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, /* * Get a ref on a peer record. */ -struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *peer) +struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *peer, enum rxrpc_peer_trace why) { - const void *here = __builtin_return_address(0); int r; __refcount_inc(&peer->ref, &r); - trace_rxrpc_peer(peer->debug_id, rxrpc_peer_got, r + 1, here); + trace_rxrpc_peer(peer->debug_id, r + 1, why); return peer; } /* * Get a ref on a peer record unless its usage has already reached 0. */ -struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *peer) +struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *peer, + enum rxrpc_peer_trace why) { - const void *here = __builtin_return_address(0); int r; if (peer) { if (__refcount_inc_not_zero(&peer->ref, &r)) - trace_rxrpc_peer(peer->debug_id, rxrpc_peer_got, r + 1, here); + trace_rxrpc_peer(peer->debug_id, r + 1, why); else peer = NULL; } @@ -421,10 +409,10 @@ static void __rxrpc_put_peer(struct rxrpc_peer *peer) ASSERT(hlist_empty(&peer->error_targets)); - spin_lock_bh(&rxnet->peer_hash_lock); + spin_lock(&rxnet->peer_hash_lock); hash_del_rcu(&peer->hash_link); list_del_init(&peer->keepalive_link); - spin_unlock_bh(&rxnet->peer_hash_lock); + spin_unlock(&rxnet->peer_hash_lock); rxrpc_free_peer(peer); } @@ -432,9 +420,8 @@ static void __rxrpc_put_peer(struct rxrpc_peer *peer) /* * Drop a ref on a peer record. */ -void rxrpc_put_peer(struct rxrpc_peer *peer) +void rxrpc_put_peer(struct rxrpc_peer *peer, enum rxrpc_peer_trace why) { - const void *here = __builtin_return_address(0); unsigned int debug_id; bool dead; int r; @@ -442,33 +429,13 @@ void rxrpc_put_peer(struct rxrpc_peer *peer) if (peer) { debug_id = peer->debug_id; dead = __refcount_dec_and_test(&peer->ref, &r); - trace_rxrpc_peer(debug_id, rxrpc_peer_put, r - 1, here); + trace_rxrpc_peer(debug_id, r - 1, why); if (dead) __rxrpc_put_peer(peer); } } /* - * Drop a ref on a peer record where the caller already holds the - * peer_hash_lock. - */ -void rxrpc_put_peer_locked(struct rxrpc_peer *peer) -{ - const void *here = __builtin_return_address(0); - unsigned int debug_id = peer->debug_id; - bool dead; - int r; - - dead = __refcount_dec_and_test(&peer->ref, &r); - trace_rxrpc_peer(debug_id, rxrpc_peer_put, r - 1, here); - if (dead) { - hash_del_rcu(&peer->hash_link); - list_del_init(&peer->keepalive_link); - rxrpc_free_peer(peer); - } -} - -/* * Make sure all peer records have been discarded. */ void rxrpc_destroy_all_peers(struct rxrpc_net *rxnet) diff --git a/net/rxrpc/proc.c b/net/rxrpc/proc.c index 245418943e01..750158a085cd 100644 --- a/net/rxrpc/proc.c +++ b/net/rxrpc/proc.c @@ -12,13 +12,13 @@ static const char *const rxrpc_conn_states[RXRPC_CONN__NR_STATES] = { [RXRPC_CONN_UNUSED] = "Unused ", + [RXRPC_CONN_CLIENT_UNSECURED] = "ClUnsec ", [RXRPC_CONN_CLIENT] = "Client ", [RXRPC_CONN_SERVICE_PREALLOC] = "SvPrealc", [RXRPC_CONN_SERVICE_UNSECURED] = "SvUnsec ", [RXRPC_CONN_SERVICE_CHALLENGING] = "SvChall ", [RXRPC_CONN_SERVICE] = "SvSecure", - [RXRPC_CONN_REMOTELY_ABORTED] = "RmtAbort", - [RXRPC_CONN_LOCALLY_ABORTED] = "LocAbort", + [RXRPC_CONN_ABORTED] = "Aborted ", }; /* @@ -49,65 +49,58 @@ static void rxrpc_call_seq_stop(struct seq_file *seq, void *v) static int rxrpc_call_seq_show(struct seq_file *seq, void *v) { struct rxrpc_local *local; - struct rxrpc_sock *rx; - struct rxrpc_peer *peer; struct rxrpc_call *call; struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + enum rxrpc_call_state state; unsigned long timeout = 0; - rxrpc_seq_t tx_hard_ack, rx_hard_ack; + rxrpc_seq_t acks_hard_ack; char lbuff[50], rbuff[50]; + u64 wtmp; if (v == &rxnet->calls) { seq_puts(seq, "Proto Local " " Remote " " SvID ConnID CallID End Use State Abort " - " DebugId TxSeq TW RxSeq RW RxSerial RxTimo\n"); + " DebugId TxSeq TW RxSeq RW RxSerial CW RxTimo\n"); return 0; } call = list_entry(v, struct rxrpc_call, link); - rx = rcu_dereference(call->socket); - if (rx) { - local = READ_ONCE(rx->local); - if (local) - sprintf(lbuff, "%pISpc", &local->srx.transport); - else - strcpy(lbuff, "no_local"); - } else { - strcpy(lbuff, "no_socket"); - } - - peer = call->peer; - if (peer) - sprintf(rbuff, "%pISpc", &peer->srx.transport); + local = call->local; + if (local) + sprintf(lbuff, "%pISpc", &local->srx.transport); else - strcpy(rbuff, "no_connection"); + strcpy(lbuff, "no_local"); + + sprintf(rbuff, "%pISpc", &call->dest_srx.transport); - if (call->state != RXRPC_CALL_SERVER_PREALLOC) { + state = rxrpc_call_state(call); + if (state != RXRPC_CALL_SERVER_PREALLOC) { timeout = READ_ONCE(call->expect_rx_by); timeout -= jiffies; } - tx_hard_ack = READ_ONCE(call->tx_hard_ack); - rx_hard_ack = READ_ONCE(call->rx_hard_ack); + acks_hard_ack = READ_ONCE(call->acks_hard_ack); + wtmp = atomic64_read_acquire(&call->ackr_window); seq_printf(seq, "UDP %-47.47s %-47.47s %4x %08x %08x %s %3u" - " %-8.8s %08x %08x %08x %02x %08x %02x %08x %06lx\n", + " %-8.8s %08x %08x %08x %02x %08x %02x %08x %02x %06lx\n", lbuff, rbuff, - call->service_id, + call->dest_srx.srx_service, call->cid, call->call_id, rxrpc_is_service_call(call) ? "Svc" : "Clt", refcount_read(&call->ref), - rxrpc_call_states[call->state], + rxrpc_call_states[state], call->abort_code, call->debug_id, - tx_hard_ack, READ_ONCE(call->tx_top) - tx_hard_ack, - rx_hard_ack, READ_ONCE(call->rx_top) - rx_hard_ack, + acks_hard_ack, READ_ONCE(call->tx_top) - acks_hard_ack, + lower_32_bits(wtmp), upper_32_bits(wtmp) - lower_32_bits(wtmp), call->rx_serial, + call->cong_cwnd, timeout); return 0; @@ -152,13 +145,14 @@ static int rxrpc_connection_seq_show(struct seq_file *seq, void *v) { struct rxrpc_connection *conn; struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + const char *state; char lbuff[50], rbuff[50]; if (v == &rxnet->conn_proc_list) { seq_puts(seq, "Proto Local " " Remote " - " SvID ConnID End Use State Key " + " SvID ConnID End Ref Act State Key " " Serial ISerial CallId0 CallId1 CallId2 CallId3\n" ); return 0; @@ -171,12 +165,14 @@ static int rxrpc_connection_seq_show(struct seq_file *seq, void *v) goto print; } - sprintf(lbuff, "%pISpc", &conn->params.local->srx.transport); - - sprintf(rbuff, "%pISpc", &conn->params.peer->srx.transport); + sprintf(lbuff, "%pISpc", &conn->local->srx.transport); + sprintf(rbuff, "%pISpc", &conn->peer->srx.transport); print: + state = rxrpc_is_conn_aborted(conn) ? + rxrpc_call_completions[conn->completion] : + rxrpc_conn_states[conn->state]; seq_printf(seq, - "UDP %-47.47s %-47.47s %4x %08x %s %3u" + "UDP %-47.47s %-47.47s %4x %08x %s %3u %3d" " %s %08x %08x %08x %08x %08x %08x %08x\n", lbuff, rbuff, @@ -184,8 +180,9 @@ print: conn->proto.cid, rxrpc_conn_is_service(conn) ? "Svc" : "Clt", refcount_read(&conn->ref), - rxrpc_conn_states[conn->state], - key_serial(conn->params.key), + atomic_read(&conn->active), + state, + key_serial(conn->key), atomic_read(&conn->serial), conn->hi_serial, conn->channels[0].call_id, @@ -216,7 +213,7 @@ static int rxrpc_peer_seq_show(struct seq_file *seq, void *v) seq_puts(seq, "Proto Local " " Remote " - " Use CW MTU LastUse RTT RTO\n" + " Use SST MTU LastUse RTT RTO\n" ); return 0; } @@ -234,7 +231,7 @@ static int rxrpc_peer_seq_show(struct seq_file *seq, void *v) lbuff, rbuff, refcount_read(&peer->ref), - peer->cong_cwnd, + peer->cong_ssthresh, peer->mtu, now - peer->last_tx_at, peer->srtt_us >> 3, @@ -340,7 +337,7 @@ static int rxrpc_local_seq_show(struct seq_file *seq, void *v) if (v == SEQ_START_TOKEN) { seq_puts(seq, "Proto Local " - " Use Act\n"); + " Use Act RxQ\n"); return 0; } @@ -349,10 +346,11 @@ static int rxrpc_local_seq_show(struct seq_file *seq, void *v) sprintf(lbuff, "%pISpc", &local->srx.transport); seq_printf(seq, - "UDP %-47.47s %3u %3u\n", + "UDP %-47.47s %3u %3u %3u\n", lbuff, refcount_read(&local->ref), - atomic_read(&local->active_users)); + atomic_read(&local->active_users), + local->rx_queue.qlen); return 0; } @@ -397,3 +395,109 @@ const struct seq_operations rxrpc_local_seq_ops = { .stop = rxrpc_local_seq_stop, .show = rxrpc_local_seq_show, }; + +/* + * Display stats in /proc/net/rxrpc/stats + */ +int rxrpc_stats_show(struct seq_file *seq, void *v) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_single_net(seq)); + + seq_printf(seq, + "Data : send=%u sendf=%u fail=%u\n", + atomic_read(&rxnet->stat_tx_data_send), + atomic_read(&rxnet->stat_tx_data_send_frag), + atomic_read(&rxnet->stat_tx_data_send_fail)); + seq_printf(seq, + "Data-Tx : nr=%u retrans=%u uf=%u cwr=%u\n", + atomic_read(&rxnet->stat_tx_data), + atomic_read(&rxnet->stat_tx_data_retrans), + atomic_read(&rxnet->stat_tx_data_underflow), + atomic_read(&rxnet->stat_tx_data_cwnd_reset)); + seq_printf(seq, + "Data-Rx : nr=%u reqack=%u jumbo=%u\n", + atomic_read(&rxnet->stat_rx_data), + atomic_read(&rxnet->stat_rx_data_reqack), + atomic_read(&rxnet->stat_rx_data_jumbo)); + seq_printf(seq, + "Ack : fill=%u send=%u skip=%u\n", + atomic_read(&rxnet->stat_tx_ack_fill), + atomic_read(&rxnet->stat_tx_ack_send), + atomic_read(&rxnet->stat_tx_ack_skip)); + seq_printf(seq, + "Ack-Tx : req=%u dup=%u oos=%u exw=%u nos=%u png=%u prs=%u dly=%u idl=%u\n", + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_REQUESTED]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_DUPLICATE]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_OUT_OF_SEQUENCE]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_EXCEEDS_WINDOW]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_NOSPACE]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_PING]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_PING_RESPONSE]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_DELAY]), + atomic_read(&rxnet->stat_tx_acks[RXRPC_ACK_IDLE])); + seq_printf(seq, + "Ack-Rx : req=%u dup=%u oos=%u exw=%u nos=%u png=%u prs=%u dly=%u idl=%u\n", + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_REQUESTED]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_DUPLICATE]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_OUT_OF_SEQUENCE]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_EXCEEDS_WINDOW]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_NOSPACE]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_PING]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_PING_RESPONSE]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_DELAY]), + atomic_read(&rxnet->stat_rx_acks[RXRPC_ACK_IDLE])); + seq_printf(seq, + "Why-Req-A: acklost=%u already=%u mrtt=%u ortt=%u\n", + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_ack_lost]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_already_on]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_more_rtt]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_old_rtt])); + seq_printf(seq, + "Why-Req-A: nolast=%u retx=%u slows=%u smtxw=%u\n", + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_no_srv_last]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_retrans]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_slow_start]), + atomic_read(&rxnet->stat_why_req_ack[rxrpc_reqack_small_txwin])); + seq_printf(seq, + "Buffers : txb=%u rxb=%u\n", + atomic_read(&rxrpc_nr_txbuf), + atomic_read(&rxrpc_n_rx_skbs)); + seq_printf(seq, + "IO-thread: loops=%u\n", + atomic_read(&rxnet->stat_io_loop)); + return 0; +} + +/* + * Clear stats if /proc/net/rxrpc/stats is written to. + */ +int rxrpc_stats_clear(struct file *file, char *buf, size_t size) +{ + struct seq_file *m = file->private_data; + struct rxrpc_net *rxnet = rxrpc_net(seq_file_single_net(m)); + + if (size > 1 || (size == 1 && buf[0] != '\n')) + return -EINVAL; + + atomic_set(&rxnet->stat_tx_data, 0); + atomic_set(&rxnet->stat_tx_data_retrans, 0); + atomic_set(&rxnet->stat_tx_data_underflow, 0); + atomic_set(&rxnet->stat_tx_data_cwnd_reset, 0); + atomic_set(&rxnet->stat_tx_data_send, 0); + atomic_set(&rxnet->stat_tx_data_send_frag, 0); + atomic_set(&rxnet->stat_tx_data_send_fail, 0); + atomic_set(&rxnet->stat_rx_data, 0); + atomic_set(&rxnet->stat_rx_data_reqack, 0); + atomic_set(&rxnet->stat_rx_data_jumbo, 0); + + atomic_set(&rxnet->stat_tx_ack_fill, 0); + atomic_set(&rxnet->stat_tx_ack_send, 0); + atomic_set(&rxnet->stat_tx_ack_skip, 0); + memset(&rxnet->stat_tx_acks, 0, sizeof(rxnet->stat_tx_acks)); + memset(&rxnet->stat_rx_acks, 0, sizeof(rxnet->stat_rx_acks)); + + memset(&rxnet->stat_why_req_ack, 0, sizeof(rxnet->stat_why_req_ack)); + + atomic_set(&rxnet->stat_io_loop, 0); + return size; +} diff --git a/net/rxrpc/protocol.h b/net/rxrpc/protocol.h index d2cf8e1d218f..6760cb99c6d6 100644 --- a/net/rxrpc/protocol.h +++ b/net/rxrpc/protocol.h @@ -84,7 +84,7 @@ struct rxrpc_jumbo_header { __be16 _rsvd; /* reserved */ __be16 cksum; /* kerberos security checksum */ }; -}; +} __packed; #define RXRPC_JUMBO_DATALEN 1412 /* non-terminal jumbo packet data length */ #define RXRPC_JUMBO_SUBPKTLEN (RXRPC_JUMBO_DATALEN + sizeof(struct rxrpc_jumbo_header)) @@ -132,13 +132,6 @@ struct rxrpc_ackpacket { } __packed; -/* Some ACKs refer to specific packets and some are general and can be updated. */ -#define RXRPC_ACK_UPDATEABLE ((1 << RXRPC_ACK_REQUESTED) | \ - (1 << RXRPC_ACK_PING_RESPONSE) | \ - (1 << RXRPC_ACK_DELAY) | \ - (1 << RXRPC_ACK_IDLE)) - - /* * ACK packets can have a further piece of information tagged on the end */ diff --git a/net/rxrpc/recvmsg.c b/net/rxrpc/recvmsg.c index 7e39c262fd79..dd54ceee7bcc 100644 --- a/net/rxrpc/recvmsg.c +++ b/net/rxrpc/recvmsg.c @@ -36,16 +36,16 @@ void rxrpc_notify_socket(struct rxrpc_call *call) sk = &rx->sk; if (rx && sk->sk_state < RXRPC_CLOSE) { if (call->notify_rx) { - spin_lock_bh(&call->notify_lock); + spin_lock(&call->notify_lock); call->notify_rx(sk, call, call->user_call_ID); - spin_unlock_bh(&call->notify_lock); + spin_unlock(&call->notify_lock); } else { - write_lock_bh(&rx->recvmsg_lock); + write_lock(&rx->recvmsg_lock); if (list_empty(&call->recvmsg_link)) { - rxrpc_get_call(call, rxrpc_call_got); + rxrpc_get_call(call, rxrpc_call_get_notify_socket); list_add_tail(&call->recvmsg_link, &rx->recvmsg_q); } - write_unlock_bh(&rx->recvmsg_lock); + write_unlock(&rx->recvmsg_lock); if (!sock_flag(sk, SOCK_DEAD)) { _debug("call %ps", sk->sk_data_ready); @@ -59,85 +59,6 @@ void rxrpc_notify_socket(struct rxrpc_call *call) } /* - * Transition a call to the complete state. - */ -bool __rxrpc_set_call_completion(struct rxrpc_call *call, - enum rxrpc_call_completion compl, - u32 abort_code, - int error) -{ - if (call->state < RXRPC_CALL_COMPLETE) { - call->abort_code = abort_code; - call->error = error; - call->completion = compl; - call->state = RXRPC_CALL_COMPLETE; - trace_rxrpc_call_complete(call); - wake_up(&call->waitq); - rxrpc_notify_socket(call); - return true; - } - return false; -} - -bool rxrpc_set_call_completion(struct rxrpc_call *call, - enum rxrpc_call_completion compl, - u32 abort_code, - int error) -{ - bool ret = false; - - if (call->state < RXRPC_CALL_COMPLETE) { - write_lock_bh(&call->state_lock); - ret = __rxrpc_set_call_completion(call, compl, abort_code, error); - write_unlock_bh(&call->state_lock); - } - return ret; -} - -/* - * Record that a call successfully completed. - */ -bool __rxrpc_call_completed(struct rxrpc_call *call) -{ - return __rxrpc_set_call_completion(call, RXRPC_CALL_SUCCEEDED, 0, 0); -} - -bool rxrpc_call_completed(struct rxrpc_call *call) -{ - bool ret = false; - - if (call->state < RXRPC_CALL_COMPLETE) { - write_lock_bh(&call->state_lock); - ret = __rxrpc_call_completed(call); - write_unlock_bh(&call->state_lock); - } - return ret; -} - -/* - * Record that a call is locally aborted. - */ -bool __rxrpc_abort_call(const char *why, struct rxrpc_call *call, - rxrpc_seq_t seq, u32 abort_code, int error) -{ - trace_rxrpc_abort(call->debug_id, why, call->cid, call->call_id, seq, - abort_code, error); - return __rxrpc_set_call_completion(call, RXRPC_CALL_LOCALLY_ABORTED, - abort_code, error); -} - -bool rxrpc_abort_call(const char *why, struct rxrpc_call *call, - rxrpc_seq_t seq, u32 abort_code, int error) -{ - bool ret; - - write_lock_bh(&call->state_lock); - ret = __rxrpc_abort_call(why, call, seq, abort_code, error); - write_unlock_bh(&call->state_lock); - return ret; -} - -/* * Pass a call terminating message to userspace. */ static int rxrpc_recvmsg_term(struct rxrpc_call *call, struct msghdr *msg) @@ -168,55 +89,18 @@ static int rxrpc_recvmsg_term(struct rxrpc_call *call, struct msghdr *msg) ret = put_cmsg(msg, SOL_RXRPC, RXRPC_LOCAL_ERROR, 4, &tmp); break; default: - pr_err("Invalid terminal call state %u\n", call->state); + pr_err("Invalid terminal call state %u\n", call->completion); BUG(); break; } - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_terminal, call->rx_hard_ack, - call->rx_pkt_offset, call->rx_pkt_len, ret); + trace_rxrpc_recvdata(call, rxrpc_recvmsg_terminal, + lower_32_bits(atomic64_read(&call->ackr_window)) - 1, + call->rx_pkt_offset, call->rx_pkt_len, ret); return ret; } /* - * End the packet reception phase. - */ -static void rxrpc_end_rx_phase(struct rxrpc_call *call, rxrpc_serial_t serial) -{ - _enter("%d,%s", call->debug_id, rxrpc_call_states[call->state]); - - trace_rxrpc_receive(call, rxrpc_receive_end, 0, call->rx_top); - ASSERTCMP(call->rx_hard_ack, ==, call->rx_top); - - if (call->state == RXRPC_CALL_CLIENT_RECV_REPLY) { - rxrpc_propose_ACK(call, RXRPC_ACK_IDLE, serial, false, true, - rxrpc_propose_ack_terminal_ack); - //rxrpc_send_ack_packet(call, false, NULL); - } - - write_lock_bh(&call->state_lock); - - switch (call->state) { - case RXRPC_CALL_CLIENT_RECV_REPLY: - __rxrpc_call_completed(call); - write_unlock_bh(&call->state_lock); - break; - - case RXRPC_CALL_SERVER_RECV_REQUEST: - call->tx_phase = true; - call->state = RXRPC_CALL_SERVER_ACK_REQUEST; - call->expect_req_by = jiffies + MAX_JIFFY_OFFSET; - write_unlock_bh(&call->state_lock); - rxrpc_propose_ACK(call, RXRPC_ACK_DELAY, serial, false, true, - rxrpc_propose_ack_processing_op); - break; - default: - write_unlock_bh(&call->state_lock); - break; - } -} - -/* * Discard a packet we've used up and advance the Rx window by one. */ static void rxrpc_rotate_rx_window(struct rxrpc_call *call) @@ -224,132 +108,57 @@ static void rxrpc_rotate_rx_window(struct rxrpc_call *call) struct rxrpc_skb_priv *sp; struct sk_buff *skb; rxrpc_serial_t serial; - rxrpc_seq_t hard_ack, top; - bool last = false; - u8 subpacket; - int ix; + rxrpc_seq_t old_consumed = call->rx_consumed, tseq; + bool last; + int acked; _enter("%d", call->debug_id); - hard_ack = call->rx_hard_ack; - top = smp_load_acquire(&call->rx_top); - ASSERT(before(hard_ack, top)); + skb = skb_dequeue(&call->recvmsg_queue); + rxrpc_see_skb(skb, rxrpc_skb_see_rotate); - hard_ack++; - ix = hard_ack & RXRPC_RXTX_BUFF_MASK; - skb = call->rxtx_buffer[ix]; - rxrpc_see_skb(skb, rxrpc_skb_rotated); sp = rxrpc_skb(skb); + tseq = sp->hdr.seq; + serial = sp->hdr.serial; + last = sp->hdr.flags & RXRPC_LAST_PACKET; - subpacket = call->rxtx_annotations[ix] & RXRPC_RX_ANNO_SUBPACKET; - serial = sp->hdr.serial + subpacket; - - if (subpacket == sp->nr_subpackets - 1 && - sp->rx_flags & RXRPC_SKB_INCL_LAST) - last = true; - - call->rxtx_buffer[ix] = NULL; - call->rxtx_annotations[ix] = 0; /* Barrier against rxrpc_input_data(). */ - smp_store_release(&call->rx_hard_ack, hard_ack); - - rxrpc_free_skb(skb, rxrpc_skb_freed); - - trace_rxrpc_receive(call, rxrpc_receive_rotate, serial, hard_ack); - if (last) { - rxrpc_end_rx_phase(call, serial); - } else { - /* Check to see if there's an ACK that needs sending. */ - if (atomic_inc_return(&call->ackr_nr_consumed) > 2) - rxrpc_propose_ACK(call, RXRPC_ACK_IDLE, serial, - true, false, - rxrpc_propose_ack_rotate_rx); - if (call->ackr_reason && call->ackr_reason != RXRPC_ACK_DELAY) - rxrpc_send_ack_packet(call, false, NULL); - } -} + if (after(tseq, call->rx_consumed)) + smp_store_release(&call->rx_consumed, tseq); -/* - * Decrypt and verify a (sub)packet. The packet's length may be changed due to - * padding, but if this is the case, the packet length will be resident in the - * socket buffer. Note that we can't modify the master skb info as the skb may - * be the home to multiple subpackets. - */ -static int rxrpc_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, - u8 annotation, - unsigned int offset, unsigned int len) -{ - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - rxrpc_seq_t seq = sp->hdr.seq; - u16 cksum = sp->hdr.cksum; - u8 subpacket = annotation & RXRPC_RX_ANNO_SUBPACKET; + rxrpc_free_skb(skb, rxrpc_skb_put_rotate); - _enter(""); + trace_rxrpc_receive(call, last ? rxrpc_receive_rotate_last : rxrpc_receive_rotate, + serial, call->rx_consumed); - /* For all but the head jumbo subpacket, the security checksum is in a - * jumbo header immediately prior to the data. - */ - if (subpacket > 0) { - __be16 tmp; - if (skb_copy_bits(skb, offset - 2, &tmp, 2) < 0) - BUG(); - cksum = ntohs(tmp); - seq += subpacket; - } + if (last) + set_bit(RXRPC_CALL_RECVMSG_READ_ALL, &call->flags); - return call->security->verify_packet(call, skb, offset, len, - seq, cksum); + /* Check to see if there's an ACK that needs sending. */ + acked = atomic_add_return(call->rx_consumed - old_consumed, + &call->ackr_nr_consumed); + if (acked > 2 && + !test_and_set_bit(RXRPC_CALL_RX_IS_IDLE, &call->flags)) + rxrpc_poke_call(call, rxrpc_call_poke_idle); } /* - * Locate the data within a packet. This is complicated by: - * - * (1) An skb may contain a jumbo packet - so we have to find the appropriate - * subpacket. - * - * (2) The (sub)packets may be encrypted and, if so, the encrypted portion - * contains an extra header which includes the true length of the data, - * excluding any encrypted padding. + * Decrypt and verify a DATA packet. */ -static int rxrpc_locate_data(struct rxrpc_call *call, struct sk_buff *skb, - u8 *_annotation, - unsigned int *_offset, unsigned int *_len, - bool *_last) +static int rxrpc_verify_data(struct rxrpc_call *call, struct sk_buff *skb) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - unsigned int offset = sizeof(struct rxrpc_wire_header); - unsigned int len; - bool last = false; - int ret; - u8 annotation = *_annotation; - u8 subpacket = annotation & RXRPC_RX_ANNO_SUBPACKET; - - /* Locate the subpacket */ - offset += subpacket * RXRPC_JUMBO_SUBPKTLEN; - len = skb->len - offset; - if (subpacket < sp->nr_subpackets - 1) - len = RXRPC_JUMBO_DATALEN; - else if (sp->rx_flags & RXRPC_SKB_INCL_LAST) - last = true; - - if (!(annotation & RXRPC_RX_ANNO_VERIFIED)) { - ret = rxrpc_verify_packet(call, skb, annotation, offset, len); - if (ret < 0) - return ret; - *_annotation |= RXRPC_RX_ANNO_VERIFIED; - } - *_offset = offset; - *_len = len; - *_last = last; - call->security->locate_data(call, skb, _offset, _len); - return 0; + if (sp->flags & RXRPC_RX_VERIFIED) + return 0; + return call->security->verify_packet(call, skb); } /* * Deliver messages to a call. This keeps processing packets until the buffer * is filled and we find either more DATA (returns 0) or the end of the DATA - * (returns 1). If more packets are required, it returns -EAGAIN. + * (returns 1). If more packets are required, it returns -EAGAIN and if the + * call has failed it returns -EIO. */ static int rxrpc_recvmsg_data(struct socket *sock, struct rxrpc_call *call, struct msghdr *msg, struct iov_iter *iter, @@ -357,69 +166,56 @@ static int rxrpc_recvmsg_data(struct socket *sock, struct rxrpc_call *call, { struct rxrpc_skb_priv *sp; struct sk_buff *skb; - rxrpc_serial_t serial; - rxrpc_seq_t hard_ack, top, seq; + rxrpc_seq_t seq = 0; size_t remain; - bool rx_pkt_last; unsigned int rx_pkt_offset, rx_pkt_len; - int ix, copy, ret = -EAGAIN, ret2; - - if (test_and_clear_bit(RXRPC_CALL_RX_UNDERRUN, &call->flags) && - call->ackr_reason) - rxrpc_send_ack_packet(call, false, NULL); + int copy, ret = -EAGAIN, ret2; rx_pkt_offset = call->rx_pkt_offset; rx_pkt_len = call->rx_pkt_len; - rx_pkt_last = call->rx_pkt_last; - if (call->state >= RXRPC_CALL_SERVER_ACK_REQUEST) { - seq = call->rx_hard_ack; + if (rxrpc_call_has_failed(call)) { + seq = lower_32_bits(atomic64_read(&call->ackr_window)) - 1; + ret = -EIO; + goto done; + } + + if (test_bit(RXRPC_CALL_RECVMSG_READ_ALL, &call->flags)) { + seq = lower_32_bits(atomic64_read(&call->ackr_window)) - 1; ret = 1; goto done; } - /* Barriers against rxrpc_input_data(). */ - hard_ack = call->rx_hard_ack; - seq = hard_ack + 1; - - while (top = smp_load_acquire(&call->rx_top), - before_eq(seq, top) - ) { - ix = seq & RXRPC_RXTX_BUFF_MASK; - skb = call->rxtx_buffer[ix]; - if (!skb) { - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_hole, seq, - rx_pkt_offset, rx_pkt_len, 0); - break; - } - smp_rmb(); - rxrpc_see_skb(skb, rxrpc_skb_seen); + /* No one else can be removing stuff from the queue, so we shouldn't + * need the Rx lock to walk it. + */ + skb = skb_peek(&call->recvmsg_queue); + while (skb) { + rxrpc_see_skb(skb, rxrpc_skb_see_recvmsg); sp = rxrpc_skb(skb); + seq = sp->hdr.seq; - if (!(flags & MSG_PEEK)) { - serial = sp->hdr.serial; - serial += call->rxtx_annotations[ix] & RXRPC_RX_ANNO_SUBPACKET; + if (!(flags & MSG_PEEK)) trace_rxrpc_receive(call, rxrpc_receive_front, - serial, seq); - } + sp->hdr.serial, seq); if (msg) sock_recv_timestamp(msg, sock->sk, skb); if (rx_pkt_offset == 0) { - ret2 = rxrpc_locate_data(call, skb, - &call->rxtx_annotations[ix], - &rx_pkt_offset, &rx_pkt_len, - &rx_pkt_last); - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_next, seq, - rx_pkt_offset, rx_pkt_len, ret2); + ret2 = rxrpc_verify_data(call, skb); + trace_rxrpc_recvdata(call, rxrpc_recvmsg_next, seq, + sp->offset, sp->len, ret2); if (ret2 < 0) { + kdebug("verify = %d", ret2); ret = ret2; goto out; } + rx_pkt_offset = sp->offset; + rx_pkt_len = sp->len; } else { - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_cont, seq, - rx_pkt_offset, rx_pkt_len, 0); + trace_rxrpc_recvdata(call, rxrpc_recvmsg_cont, seq, + rx_pkt_offset, rx_pkt_len, 0); } /* We have to handle short, empty and used-up DATA packets. */ @@ -442,39 +238,35 @@ static int rxrpc_recvmsg_data(struct socket *sock, struct rxrpc_call *call, } if (rx_pkt_len > 0) { - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_full, seq, - rx_pkt_offset, rx_pkt_len, 0); + trace_rxrpc_recvdata(call, rxrpc_recvmsg_full, seq, + rx_pkt_offset, rx_pkt_len, 0); ASSERTCMP(*_offset, ==, len); ret = 0; break; } /* The whole packet has been transferred. */ - if (!(flags & MSG_PEEK)) - rxrpc_rotate_rx_window(call); + if (sp->hdr.flags & RXRPC_LAST_PACKET) + ret = 1; rx_pkt_offset = 0; rx_pkt_len = 0; - if (rx_pkt_last) { - ASSERTCMP(seq, ==, READ_ONCE(call->rx_top)); - ret = 1; - goto out; - } + skb = skb_peek_next(skb, &call->recvmsg_queue); - seq++; + if (!(flags & MSG_PEEK)) + rxrpc_rotate_rx_window(call); } out: if (!(flags & MSG_PEEK)) { call->rx_pkt_offset = rx_pkt_offset; call->rx_pkt_len = rx_pkt_len; - call->rx_pkt_last = rx_pkt_last; } done: - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_data_return, seq, - rx_pkt_offset, rx_pkt_len, ret); + trace_rxrpc_recvdata(call, rxrpc_recvmsg_data_return, seq, + rx_pkt_offset, rx_pkt_len, ret); if (ret == -EAGAIN) - set_bit(RXRPC_CALL_RX_UNDERRUN, &call->flags); + set_bit(RXRPC_CALL_RX_IS_IDLE, &call->flags); return ret; } @@ -489,13 +281,14 @@ int rxrpc_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, struct rxrpc_call *call; struct rxrpc_sock *rx = rxrpc_sk(sock->sk); struct list_head *l; + unsigned int call_debug_id = 0; size_t copied = 0; long timeo; int ret; DEFINE_WAIT(wait); - trace_rxrpc_recvmsg(NULL, rxrpc_recvmsg_enter, 0, 0, 0, 0); + trace_rxrpc_recvmsg(0, rxrpc_recvmsg_enter, 0); if (flags & (MSG_OOB | MSG_TRUNC)) return -EOPNOTSUPP; @@ -532,8 +325,7 @@ try_again: if (list_empty(&rx->recvmsg_q)) { if (signal_pending(current)) goto wait_interrupted; - trace_rxrpc_recvmsg(NULL, rxrpc_recvmsg_wait, - 0, 0, 0, 0); + trace_rxrpc_recvmsg(0, rxrpc_recvmsg_wait, 0); timeo = schedule_timeout(timeo); } finish_wait(sk_sleep(&rx->sk), &wait); @@ -543,16 +335,17 @@ try_again: /* Find the next call and dequeue it if we're not just peeking. If we * do dequeue it, that comes with a ref that we will need to release. */ - write_lock_bh(&rx->recvmsg_lock); + write_lock(&rx->recvmsg_lock); l = rx->recvmsg_q.next; call = list_entry(l, struct rxrpc_call, recvmsg_link); if (!(flags & MSG_PEEK)) list_del_init(&call->recvmsg_link); else - rxrpc_get_call(call, rxrpc_call_got); - write_unlock_bh(&rx->recvmsg_lock); + rxrpc_get_call(call, rxrpc_call_get_recvmsg); + write_unlock(&rx->recvmsg_lock); - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_dequeue, 0, 0, 0, 0); + call_debug_id = call->debug_id; + trace_rxrpc_recvmsg(call_debug_id, rxrpc_recvmsg_dequeue, 0); /* We're going to drop the socket lock, so we need to lock the call * against interference by sendmsg. @@ -588,45 +381,42 @@ try_again: } if (msg->msg_name && call->peer) { - struct sockaddr_rxrpc *srx = msg->msg_name; - size_t len = sizeof(call->peer->srx); + size_t len = sizeof(call->dest_srx); - memcpy(msg->msg_name, &call->peer->srx, len); - srx->srx_service = call->service_id; + memcpy(msg->msg_name, &call->dest_srx, len); msg->msg_namelen = len; } - switch (READ_ONCE(call->state)) { - case RXRPC_CALL_CLIENT_RECV_REPLY: - case RXRPC_CALL_SERVER_RECV_REQUEST: - case RXRPC_CALL_SERVER_ACK_REQUEST: - ret = rxrpc_recvmsg_data(sock, call, msg, &msg->msg_iter, len, - flags, &copied); - if (ret == -EAGAIN) - ret = 0; - - if (after(call->rx_top, call->rx_hard_ack) && - call->rxtx_buffer[(call->rx_hard_ack + 1) & RXRPC_RXTX_BUFF_MASK]) - rxrpc_notify_socket(call); - break; - default: + ret = rxrpc_recvmsg_data(sock, call, msg, &msg->msg_iter, len, + flags, &copied); + if (ret == -EAGAIN) ret = 0; - break; - } - + if (ret == -EIO) + goto call_failed; if (ret < 0) goto error_unlock_call; - if (call->state == RXRPC_CALL_COMPLETE) { - ret = rxrpc_recvmsg_term(call, msg); - if (ret < 0) - goto error_unlock_call; - if (!(flags & MSG_PEEK)) - rxrpc_release_call(rx, call); - msg->msg_flags |= MSG_EOR; - ret = 1; - } + if (rxrpc_call_is_complete(call) && + skb_queue_empty(&call->recvmsg_queue)) + goto call_complete; + if (rxrpc_call_has_failed(call)) + goto call_failed; + + rxrpc_notify_socket(call); + goto not_yet_complete; + +call_failed: + rxrpc_purge_queue(&call->recvmsg_queue); +call_complete: + ret = rxrpc_recvmsg_term(call, msg); + if (ret < 0) + goto error_unlock_call; + if (!(flags & MSG_PEEK)) + rxrpc_release_call(rx, call); + msg->msg_flags |= MSG_EOR; + ret = 1; +not_yet_complete: if (ret == 0) msg->msg_flags |= MSG_MORE; else @@ -635,23 +425,23 @@ try_again: error_unlock_call: mutex_unlock(&call->user_mutex); - rxrpc_put_call(call, rxrpc_call_put); - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_return, 0, 0, 0, ret); + rxrpc_put_call(call, rxrpc_call_put_recvmsg); + trace_rxrpc_recvmsg(call_debug_id, rxrpc_recvmsg_return, ret); return ret; error_requeue_call: if (!(flags & MSG_PEEK)) { - write_lock_bh(&rx->recvmsg_lock); + write_lock(&rx->recvmsg_lock); list_add(&call->recvmsg_link, &rx->recvmsg_q); - write_unlock_bh(&rx->recvmsg_lock); - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_requeue, 0, 0, 0, 0); + write_unlock(&rx->recvmsg_lock); + trace_rxrpc_recvmsg(call_debug_id, rxrpc_recvmsg_requeue, 0); } else { - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_recvmsg); } error_no_call: release_sock(&rx->sk); error_trace: - trace_rxrpc_recvmsg(call, rxrpc_recvmsg_return, 0, 0, 0, ret); + trace_rxrpc_recvmsg(call_debug_id, rxrpc_recvmsg_return, ret); return ret; wait_interrupted: @@ -689,78 +479,56 @@ int rxrpc_kernel_recv_data(struct socket *sock, struct rxrpc_call *call, size_t offset = 0; int ret; - _enter("{%d,%s},%zu,%d", - call->debug_id, rxrpc_call_states[call->state], - *_len, want_more); - - ASSERTCMP(call->state, !=, RXRPC_CALL_SERVER_SECURING); + _enter("{%d},%zu,%d", call->debug_id, *_len, want_more); mutex_lock(&call->user_mutex); - switch (READ_ONCE(call->state)) { - case RXRPC_CALL_CLIENT_RECV_REPLY: - case RXRPC_CALL_SERVER_RECV_REQUEST: - case RXRPC_CALL_SERVER_ACK_REQUEST: - ret = rxrpc_recvmsg_data(sock, call, NULL, iter, - *_len, 0, &offset); - *_len -= offset; - if (ret < 0) - goto out; - - /* We can only reach here with a partially full buffer if we - * have reached the end of the data. We must otherwise have a - * full buffer or have been given -EAGAIN. - */ - if (ret == 1) { - if (iov_iter_count(iter) > 0) - goto short_data; - if (!want_more) - goto read_phase_complete; - ret = 0; - goto out; - } - - if (!want_more) - goto excess_data; + ret = rxrpc_recvmsg_data(sock, call, NULL, iter, *_len, 0, &offset); + *_len -= offset; + if (ret == -EIO) + goto call_failed; + if (ret < 0) goto out; - case RXRPC_CALL_COMPLETE: - goto call_complete; - - default: - ret = -EINPROGRESS; + /* We can only reach here with a partially full buffer if we have + * reached the end of the data. We must otherwise have a full buffer + * or have been given -EAGAIN. + */ + if (ret == 1) { + if (iov_iter_count(iter) > 0) + goto short_data; + if (!want_more) + goto read_phase_complete; + ret = 0; goto out; } + if (!want_more) + goto excess_data; + goto out; + read_phase_complete: ret = 1; out: - switch (call->ackr_reason) { - case RXRPC_ACK_IDLE: - break; - case RXRPC_ACK_DELAY: - if (ret != -EAGAIN) - break; - fallthrough; - default: - rxrpc_send_ack_packet(call, false, NULL); - } - if (_service) - *_service = call->service_id; + *_service = call->dest_srx.srx_service; mutex_unlock(&call->user_mutex); _leave(" = %d [%zu,%d]", ret, iov_iter_count(iter), *_abort); return ret; short_data: - trace_rxrpc_rx_eproto(call, 0, tracepoint_string("short_data")); + trace_rxrpc_abort(call->debug_id, rxrpc_recvmsg_short_data, + call->cid, call->call_id, call->rx_consumed, + 0, -EBADMSG); ret = -EBADMSG; goto out; excess_data: - trace_rxrpc_rx_eproto(call, 0, tracepoint_string("excess_data")); + trace_rxrpc_abort(call->debug_id, rxrpc_recvmsg_excess_data, + call->cid, call->call_id, call->rx_consumed, + 0, -EMSGSIZE); ret = -EMSGSIZE; goto out; -call_complete: +call_failed: *_abort = call->abort_code; ret = call->error; if (call->completion == RXRPC_CALL_SUCCEEDED) { diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c index 78fa0524156f..1bf571a66e02 100644 --- a/net/rxrpc/rxkad.c +++ b/net/rxrpc/rxkad.c @@ -103,7 +103,7 @@ static int rxkad_init_connection_security(struct rxrpc_connection *conn, struct crypto_sync_skcipher *ci; int ret; - _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key)); + _enter("{%d},{%x}", conn->debug_id, key_serial(conn->key)); conn->security_ix = token->security_index; @@ -118,7 +118,7 @@ static int rxkad_init_connection_security(struct rxrpc_connection *conn, sizeof(token->kad->session_key)) < 0) BUG(); - switch (conn->params.security_level) { + switch (conn->security_level) { case RXRPC_SECURITY_PLAIN: case RXRPC_SECURITY_AUTH: case RXRPC_SECURITY_ENCRYPT: @@ -150,7 +150,7 @@ static int rxkad_how_much_data(struct rxrpc_call *call, size_t remain, { size_t shdr, buf_size, chunk; - switch (call->conn->params.security_level) { + switch (call->conn->security_level) { default: buf_size = chunk = min_t(size_t, remain, RXRPC_JUMBO_DATALEN); shdr = 0; @@ -192,7 +192,7 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn, _enter(""); - if (!conn->params.key) + if (!conn->key) return 0; tmpbuf = kmalloc(tmpsize, GFP_KERNEL); @@ -205,7 +205,7 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn, return -ENOMEM; } - token = conn->params.key->payload.data[0]; + token = conn->key->payload.data[0]; memcpy(&iv, token->kad->session_key, sizeof(iv)); tmpbuf[0] = htonl(conn->proto.epoch); @@ -233,16 +233,8 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn, static struct skcipher_request *rxkad_get_call_crypto(struct rxrpc_call *call) { struct crypto_skcipher *tfm = &call->conn->rxkad.cipher->base; - struct skcipher_request *cipher_req = call->cipher_req; - if (!cipher_req) { - cipher_req = skcipher_request_alloc(tfm, GFP_NOFS); - if (!cipher_req) - return NULL; - call->cipher_req = cipher_req; - } - - return cipher_req; + return skcipher_request_alloc(tfm, GFP_NOFS); } /* @@ -250,20 +242,16 @@ static struct skcipher_request *rxkad_get_call_crypto(struct rxrpc_call *call) */ static void rxkad_free_call_crypto(struct rxrpc_call *call) { - if (call->cipher_req) - skcipher_request_free(call->cipher_req); - call->cipher_req = NULL; } /* * partially encrypt a packet (level 1 security) */ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, - struct sk_buff *skb, u32 data_size, + struct rxrpc_txbuf *txb, struct skcipher_request *req) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - struct rxkad_level1_hdr hdr; + struct rxkad_level1_hdr *hdr = (void *)txb->data; struct rxrpc_crypt iv; struct scatterlist sg; size_t pad; @@ -271,22 +259,22 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, _enter(""); - check = sp->hdr.seq ^ call->call_id; - data_size |= (u32)check << 16; + check = txb->seq ^ ntohl(txb->wire.callNumber); + hdr->data_size = htonl((u32)check << 16 | txb->len); - hdr.data_size = htonl(data_size); - memcpy(skb->head, &hdr, sizeof(hdr)); - - pad = sizeof(struct rxkad_level1_hdr) + data_size; + txb->len += sizeof(struct rxkad_level1_hdr); + pad = txb->len; pad = RXKAD_ALIGN - pad; pad &= RXKAD_ALIGN - 1; - if (pad) - skb_put_zero(skb, pad); + if (pad) { + memset(txb->data + txb->offset, 0, pad); + txb->len += pad; + } /* start the encryption afresh */ memset(&iv, 0, sizeof(iv)); - sg_init_one(&sg, skb->head, 8); + sg_init_one(&sg, txb->data, 8); skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); @@ -301,92 +289,68 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, * wholly encrypt a packet (level 2 security) */ static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call, - struct sk_buff *skb, - u32 data_size, + struct rxrpc_txbuf *txb, struct skcipher_request *req) { const struct rxrpc_key_token *token; - struct rxkad_level2_hdr rxkhdr; - struct rxrpc_skb_priv *sp; + struct rxkad_level2_hdr *rxkhdr = (void *)txb->data; struct rxrpc_crypt iv; - struct scatterlist sg[16]; - unsigned int len; + struct scatterlist sg; size_t pad; u16 check; - int err; - - sp = rxrpc_skb(skb); + int ret; _enter(""); - check = sp->hdr.seq ^ call->call_id; + check = txb->seq ^ ntohl(txb->wire.callNumber); - rxkhdr.data_size = htonl(data_size | (u32)check << 16); - rxkhdr.checksum = 0; - memcpy(skb->head, &rxkhdr, sizeof(rxkhdr)); + rxkhdr->data_size = htonl(txb->len | (u32)check << 16); + rxkhdr->checksum = 0; - pad = sizeof(struct rxkad_level2_hdr) + data_size; + txb->len += sizeof(struct rxkad_level2_hdr); + pad = txb->len; pad = RXKAD_ALIGN - pad; pad &= RXKAD_ALIGN - 1; - if (pad) - skb_put_zero(skb, pad); + if (pad) { + memset(txb->data + txb->offset, 0, pad); + txb->len += pad; + } /* encrypt from the session key */ - token = call->conn->params.key->payload.data[0]; + token = call->conn->key->payload.data[0]; memcpy(&iv, token->kad->session_key, sizeof(iv)); - sg_init_one(&sg[0], skb->head, sizeof(rxkhdr)); + sg_init_one(&sg, txb->data, txb->len); skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x); - crypto_skcipher_encrypt(req); - - /* we want to encrypt the skbuff in-place */ - err = -EMSGSIZE; - if (skb_shinfo(skb)->nr_frags > 16) - goto out; - - len = round_up(data_size, RXKAD_ALIGN); - - sg_init_table(sg, ARRAY_SIZE(sg)); - err = skb_to_sgvec(skb, sg, 8, len); - if (unlikely(err < 0)) - goto out; - skcipher_request_set_crypt(req, sg, sg, len, iv.x); - crypto_skcipher_encrypt(req); - - _leave(" = 0"); - err = 0; - -out: + skcipher_request_set_crypt(req, &sg, &sg, txb->len, iv.x); + ret = crypto_skcipher_encrypt(req); skcipher_request_zero(req); - return err; + return ret; } /* * checksum an RxRPC packet header */ -static int rxkad_secure_packet(struct rxrpc_call *call, - struct sk_buff *skb, - size_t data_size) +static int rxkad_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) { - struct rxrpc_skb_priv *sp; struct skcipher_request *req; struct rxrpc_crypt iv; struct scatterlist sg; + union { + __be32 buf[2]; + } crypto __aligned(8); u32 x, y; int ret; - sp = rxrpc_skb(skb); - - _enter("{%d{%x}},{#%u},%zu,", - call->debug_id, key_serial(call->conn->params.key), - sp->hdr.seq, data_size); + _enter("{%d{%x}},{#%u},%u,", + call->debug_id, key_serial(call->conn->key), + txb->seq, txb->len); if (!call->conn->rxkad.cipher) return 0; - ret = key_validate(call->conn->params.key); + ret = key_validate(call->conn->key); if (ret < 0) return ret; @@ -398,39 +362,40 @@ static int rxkad_secure_packet(struct rxrpc_call *call, memcpy(&iv, call->conn->rxkad.csum_iv.x, sizeof(iv)); /* calculate the security checksum */ - x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT); - x |= sp->hdr.seq & 0x3fffffff; - call->crypto_buf[0] = htonl(call->call_id); - call->crypto_buf[1] = htonl(x); + x = (ntohl(txb->wire.cid) & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT); + x |= txb->seq & 0x3fffffff; + crypto.buf[0] = txb->wire.callNumber; + crypto.buf[1] = htonl(x); - sg_init_one(&sg, call->crypto_buf, 8); + sg_init_one(&sg, crypto.buf, 8); skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); crypto_skcipher_encrypt(req); skcipher_request_zero(req); - y = ntohl(call->crypto_buf[1]); + y = ntohl(crypto.buf[1]); y = (y >> 16) & 0xffff; if (y == 0) y = 1; /* zero checksums are not permitted */ - sp->hdr.cksum = y; + txb->wire.cksum = htons(y); - switch (call->conn->params.security_level) { + switch (call->conn->security_level) { case RXRPC_SECURITY_PLAIN: ret = 0; break; case RXRPC_SECURITY_AUTH: - ret = rxkad_secure_packet_auth(call, skb, data_size, req); + ret = rxkad_secure_packet_auth(call, txb, req); break; case RXRPC_SECURITY_ENCRYPT: - ret = rxkad_secure_packet_encrypt(call, skb, data_size, req); + ret = rxkad_secure_packet_encrypt(call, txb, req); break; default: ret = -EPERM; break; } + skcipher_request_free(req); _leave(" = %d [set %x]", ret, y); return ret; } @@ -439,31 +404,28 @@ static int rxkad_secure_packet(struct rxrpc_call *call, * decrypt partial encryption on a packet (level 1 security) */ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int offset, unsigned int len, rxrpc_seq_t seq, struct skcipher_request *req) { struct rxkad_level1_hdr sechdr; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_crypt iv; struct scatterlist sg[16]; - bool aborted; u32 data_size, buf; u16 check; int ret; _enter(""); - if (len < 8) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H", - RXKADSEALEDINCON); - goto protocol_error; - } + if (sp->len < 8) + return rxrpc_abort_eproto(call, skb, RXKADSEALEDINCON, + rxkad_abort_1_short_header); /* Decrypt the skbuff in-place. TODO: We really want to decrypt * directly into the target buffer. */ sg_init_table(sg, ARRAY_SIZE(sg)); - ret = skb_to_sgvec(skb, sg, offset, 8); + ret = skb_to_sgvec(skb, sg, sp->offset, 8); if (unlikely(ret < 0)) return ret; @@ -477,12 +439,11 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, skcipher_request_zero(req); /* Extract the decrypted packet length */ - if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1", - RXKADDATALEN); - goto protocol_error; - } - len -= sizeof(sechdr); + if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) + return rxrpc_abort_eproto(call, skb, RXKADDATALEN, + rxkad_abort_1_short_encdata); + sp->offset += sizeof(sechdr); + sp->len -= sizeof(sechdr); buf = ntohl(sechdr.data_size); data_size = buf & 0xffff; @@ -490,51 +451,39 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, check = buf >> 16; check ^= seq ^ call->call_id; check &= 0xffff; - if (check != 0) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C", - RXKADSEALEDINCON); - goto protocol_error; - } - - if (data_size > len) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L", - RXKADDATALEN); - goto protocol_error; - } + if (check != 0) + return rxrpc_abort_eproto(call, skb, RXKADSEALEDINCON, + rxkad_abort_1_short_check); + if (data_size > sp->len) + return rxrpc_abort_eproto(call, skb, RXKADDATALEN, + rxkad_abort_1_short_data); + sp->len = data_size; _leave(" = 0 [dlen=%x]", data_size); return 0; - -protocol_error: - if (aborted) - rxrpc_send_abort_packet(call); - return -EPROTO; } /* * wholly decrypt a packet (level 2 security) */ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int offset, unsigned int len, rxrpc_seq_t seq, struct skcipher_request *req) { const struct rxrpc_key_token *token; struct rxkad_level2_hdr sechdr; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_crypt iv; struct scatterlist _sg[4], *sg; - bool aborted; u32 data_size, buf; u16 check; int nsg, ret; - _enter(",{%d}", skb->len); + _enter(",{%d}", sp->len); - if (len < 8) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H", - RXKADSEALEDINCON); - goto protocol_error; - } + if (sp->len < 8) + return rxrpc_abort_eproto(call, skb, RXKADSEALEDINCON, + rxkad_abort_2_short_header); /* Decrypt the skbuff in-place. TODO: We really want to decrypt * directly into the target buffer. @@ -546,11 +495,11 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, } else { sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO); if (!sg) - goto nomem; + return -ENOMEM; } sg_init_table(sg, nsg); - ret = skb_to_sgvec(skb, sg, offset, len); + ret = skb_to_sgvec(skb, sg, sp->offset, sp->len); if (unlikely(ret < 0)) { if (sg != _sg) kfree(sg); @@ -558,24 +507,23 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, } /* decrypt from the session key */ - token = call->conn->params.key->payload.data[0]; + token = call->conn->key->payload.data[0]; memcpy(&iv, token->kad->session_key, sizeof(iv)); skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, sg, sg, len, iv.x); + skcipher_request_set_crypt(req, sg, sg, sp->len, iv.x); crypto_skcipher_decrypt(req); skcipher_request_zero(req); if (sg != _sg) kfree(sg); /* Extract the decrypted packet length */ - if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2", - RXKADDATALEN); - goto protocol_error; - } - len -= sizeof(sechdr); + if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) + return rxrpc_abort_eproto(call, skb, RXKADDATALEN, + rxkad_abort_2_short_len); + sp->offset += sizeof(sechdr); + sp->len -= sizeof(sechdr); buf = ntohl(sechdr.data_size); data_size = buf & 0xffff; @@ -583,48 +531,38 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, check = buf >> 16; check ^= seq ^ call->call_id; check &= 0xffff; - if (check != 0) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C", - RXKADSEALEDINCON); - goto protocol_error; - } + if (check != 0) + return rxrpc_abort_eproto(call, skb, RXKADSEALEDINCON, + rxkad_abort_2_short_check); - if (data_size > len) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L", - RXKADDATALEN); - goto protocol_error; - } + if (data_size > sp->len) + return rxrpc_abort_eproto(call, skb, RXKADDATALEN, + rxkad_abort_2_short_data); + sp->len = data_size; _leave(" = 0 [dlen=%x]", data_size); return 0; - -protocol_error: - if (aborted) - rxrpc_send_abort_packet(call); - return -EPROTO; - -nomem: - _leave(" = -ENOMEM"); - return -ENOMEM; } /* - * Verify the security on a received packet or subpacket (if part of a - * jumbo packet). + * Verify the security on a received packet and the subpackets therein. */ -static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int offset, unsigned int len, - rxrpc_seq_t seq, u16 expected_cksum) +static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb) { + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct skcipher_request *req; struct rxrpc_crypt iv; struct scatterlist sg; - bool aborted; + union { + __be32 buf[2]; + } crypto __aligned(8); + rxrpc_seq_t seq = sp->hdr.seq; + int ret; u16 cksum; u32 x, y; _enter("{%d{%x}},{#%u}", - call->debug_id, key_serial(call->conn->params.key), seq); + call->debug_id, key_serial(call->conn->key), seq); if (!call->conn->rxkad.cipher) return 0; @@ -639,88 +577,45 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, /* validate the security checksum */ x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT); x |= seq & 0x3fffffff; - call->crypto_buf[0] = htonl(call->call_id); - call->crypto_buf[1] = htonl(x); + crypto.buf[0] = htonl(call->call_id); + crypto.buf[1] = htonl(x); - sg_init_one(&sg, call->crypto_buf, 8); + sg_init_one(&sg, crypto.buf, 8); skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); crypto_skcipher_encrypt(req); skcipher_request_zero(req); - y = ntohl(call->crypto_buf[1]); + y = ntohl(crypto.buf[1]); cksum = (y >> 16) & 0xffff; if (cksum == 0) cksum = 1; /* zero checksums are not permitted */ - if (cksum != expected_cksum) { - aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK", - RXKADSEALEDINCON); - goto protocol_error; + if (cksum != sp->hdr.cksum) { + ret = rxrpc_abort_eproto(call, skb, RXKADSEALEDINCON, + rxkad_abort_bad_checksum); + goto out; } - switch (call->conn->params.security_level) { + switch (call->conn->security_level) { case RXRPC_SECURITY_PLAIN: - return 0; + ret = 0; + break; case RXRPC_SECURITY_AUTH: - return rxkad_verify_packet_1(call, skb, offset, len, seq, req); + ret = rxkad_verify_packet_1(call, skb, seq, req); + break; case RXRPC_SECURITY_ENCRYPT: - return rxkad_verify_packet_2(call, skb, offset, len, seq, req); + ret = rxkad_verify_packet_2(call, skb, seq, req); + break; default: - return -ENOANO; + ret = -ENOANO; + break; } -protocol_error: - if (aborted) - rxrpc_send_abort_packet(call); - return -EPROTO; -} - -/* - * Locate the data contained in a packet that was partially encrypted. - */ -static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int *_offset, unsigned int *_len) -{ - struct rxkad_level1_hdr sechdr; - - if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0) - BUG(); - *_offset += sizeof(sechdr); - *_len = ntohl(sechdr.data_size) & 0xffff; -} - -/* - * Locate the data contained in a packet that was completely encrypted. - */ -static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int *_offset, unsigned int *_len) -{ - struct rxkad_level2_hdr sechdr; - - if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0) - BUG(); - *_offset += sizeof(sechdr); - *_len = ntohl(sechdr.data_size) & 0xffff; -} - -/* - * Locate the data contained in an already decrypted packet. - */ -static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb, - unsigned int *_offset, unsigned int *_len) -{ - switch (call->conn->params.security_level) { - case RXRPC_SECURITY_AUTH: - rxkad_locate_data_1(call, skb, _offset, _len); - return; - case RXRPC_SECURITY_ENCRYPT: - rxkad_locate_data_2(call, skb, _offset, _len); - return; - default: - return; - } +out: + skcipher_request_free(req); + return ret; } /* @@ -745,8 +640,8 @@ static int rxkad_issue_challenge(struct rxrpc_connection *conn) challenge.min_level = htonl(0); challenge.__padding = 0; - msg.msg_name = &conn->params.peer->srx.transport; - msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_name = &conn->peer->srx.transport; + msg.msg_namelen = conn->peer->srx.transport_len; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; @@ -771,16 +666,15 @@ static int rxkad_issue_challenge(struct rxrpc_connection *conn) serial = atomic_inc_return(&conn->serial); whdr.serial = htonl(serial); - _proto("Tx CHALLENGE %%%u", serial); - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); + ret = kernel_sendmsg(conn->local->socket, &msg, iov, 2, len); if (ret < 0) { trace_rxrpc_tx_fail(conn->debug_id, serial, ret, rxrpc_tx_point_rxkad_challenge); return -EAGAIN; } - conn->params.peer->last_tx_at = ktime_get_seconds(); + conn->peer->last_tx_at = ktime_get_seconds(); trace_rxrpc_tx_packet(conn->debug_id, &whdr, rxrpc_tx_point_rxkad_challenge); _leave(" = 0"); @@ -804,8 +698,8 @@ static int rxkad_send_response(struct rxrpc_connection *conn, _enter(""); - msg.msg_name = &conn->params.peer->srx.transport; - msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_name = &conn->peer->srx.transport; + msg.msg_namelen = conn->peer->srx.transport_len; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; @@ -829,16 +723,15 @@ static int rxkad_send_response(struct rxrpc_connection *conn, serial = atomic_inc_return(&conn->serial); whdr.serial = htonl(serial); - _proto("Tx RESPONSE %%%u", serial); - ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len); + ret = kernel_sendmsg(conn->local->socket, &msg, iov, 3, len); if (ret < 0) { trace_rxrpc_tx_fail(conn->debug_id, serial, ret, rxrpc_tx_point_rxkad_response); return -EAGAIN; } - conn->params.peer->last_tx_at = ktime_get_seconds(); + conn->peer->last_tx_at = ktime_get_seconds(); _leave(" = 0"); return 0; } @@ -890,53 +783,46 @@ static int rxkad_encrypt_response(struct rxrpc_connection *conn, * respond to a challenge packet */ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, - struct sk_buff *skb, - u32 *_abort_code) + struct sk_buff *skb) { const struct rxrpc_key_token *token; struct rxkad_challenge challenge; struct rxkad_response *resp; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - const char *eproto; - u32 version, nonce, min_level, abort_code; - int ret; + u32 version, nonce, min_level; + int ret = -EPROTO; - _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key)); + _enter("{%d,%x}", conn->debug_id, key_serial(conn->key)); - eproto = tracepoint_string("chall_no_key"); - abort_code = RX_PROTOCOL_ERROR; - if (!conn->params.key) - goto protocol_error; + if (!conn->key) + return rxrpc_abort_conn(conn, skb, RX_PROTOCOL_ERROR, -EPROTO, + rxkad_abort_chall_no_key); - abort_code = RXKADEXPIRED; - ret = key_validate(conn->params.key); + ret = key_validate(conn->key); if (ret < 0) - goto other_error; + return rxrpc_abort_conn(conn, skb, RXKADEXPIRED, ret, + rxkad_abort_chall_key_expired); - eproto = tracepoint_string("chall_short"); - abort_code = RXKADPACKETSHORT; if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), &challenge, sizeof(challenge)) < 0) - goto protocol_error; + return rxrpc_abort_conn(conn, skb, RXKADPACKETSHORT, -EPROTO, + rxkad_abort_chall_short); version = ntohl(challenge.version); nonce = ntohl(challenge.nonce); min_level = ntohl(challenge.min_level); - _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }", - sp->hdr.serial, version, nonce, min_level); + trace_rxrpc_rx_challenge(conn, sp->hdr.serial, version, nonce, min_level); - eproto = tracepoint_string("chall_ver"); - abort_code = RXKADINCONSISTENCY; if (version != RXKAD_VERSION) - goto protocol_error; + return rxrpc_abort_conn(conn, skb, RXKADINCONSISTENCY, -EPROTO, + rxkad_abort_chall_version); - abort_code = RXKADLEVELFAIL; - ret = -EACCES; - if (conn->params.security_level < min_level) - goto other_error; + if (conn->security_level < min_level) + return rxrpc_abort_conn(conn, skb, RXKADLEVELFAIL, -EACCES, + rxkad_abort_chall_level); - token = conn->params.key->payload.data[0]; + token = conn->key->payload.data[0]; /* build the response packet */ resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS); @@ -948,7 +834,7 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, resp->encrypted.cid = htonl(conn->proto.cid); resp->encrypted.securityIndex = htonl(conn->security_ix); resp->encrypted.inc_nonce = htonl(nonce + 1); - resp->encrypted.level = htonl(conn->params.security_level); + resp->encrypted.level = htonl(conn->security_level); resp->kvno = htonl(token->kad->kvno); resp->ticket_len = htonl(token->kad->ticket_len); resp->encrypted.call_id[0] = htonl(conn->channels[0].call_counter); @@ -963,13 +849,6 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad); kfree(resp); return ret; - -protocol_error: - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); - ret = -EPROTO; -other_error: - *_abort_code = abort_code; - return ret; } /* @@ -980,20 +859,15 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, struct sk_buff *skb, void *ticket, size_t ticket_len, struct rxrpc_crypt *_session_key, - time64_t *_expiry, - u32 *_abort_code) + time64_t *_expiry) { struct skcipher_request *req; - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_crypt iv, key; struct scatterlist sg[1]; struct in_addr addr; unsigned int life; - const char *eproto; time64_t issue, now; bool little_endian; - int ret; - u32 abort_code; u8 *p, *q, *name, *end; _enter("{%d},{%x}", conn->debug_id, key_serial(server_key)); @@ -1005,10 +879,9 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, memcpy(&iv, &server_key->payload.data[2], sizeof(iv)); - ret = -ENOMEM; req = skcipher_request_alloc(server_key->payload.data[0], GFP_NOFS); if (!req) - goto temporary_error; + return -ENOMEM; sg_init_one(&sg[0], ticket, ticket_len); skcipher_request_set_callback(req, 0, NULL, NULL); @@ -1019,18 +892,21 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, p = ticket; end = p + ticket_len; -#define Z(field) \ - ({ \ - u8 *__str = p; \ - eproto = tracepoint_string("rxkad_bad_"#field); \ - q = memchr(p, 0, end - p); \ - if (!q || q - p > (field##_SZ)) \ - goto bad_ticket; \ - for (; p < q; p++) \ - if (!isprint(*p)) \ - goto bad_ticket; \ - p++; \ - __str; \ +#define Z(field, fieldl) \ + ({ \ + u8 *__str = p; \ + q = memchr(p, 0, end - p); \ + if (!q || q - p > field##_SZ) \ + return rxrpc_abort_conn( \ + conn, skb, RXKADBADTICKET, -EPROTO, \ + rxkad_abort_resp_tkt_##fieldl); \ + for (; p < q; p++) \ + if (!isprint(*p)) \ + return rxrpc_abort_conn( \ + conn, skb, RXKADBADTICKET, -EPROTO, \ + rxkad_abort_resp_tkt_##fieldl); \ + p++; \ + __str; \ }) /* extract the ticket flags */ @@ -1039,20 +915,20 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, p++; /* extract the authentication name */ - name = Z(ANAME); + name = Z(ANAME, aname); _debug("KIV ANAME: %s", name); /* extract the principal's instance */ - name = Z(INST); + name = Z(INST, inst); _debug("KIV INST : %s", name); /* extract the principal's authentication domain */ - name = Z(REALM); + name = Z(REALM, realm); _debug("KIV REALM: %s", name); - eproto = tracepoint_string("rxkad_bad_len"); if (end - p < 4 + 8 + 4 + 2) - goto bad_ticket; + return rxrpc_abort_conn(conn, skb, RXKADBADTICKET, -EPROTO, + rxkad_abort_resp_tkt_short); /* get the IPv4 address of the entity that requested the ticket */ memcpy(&addr, p, sizeof(addr)); @@ -1084,38 +960,23 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, _debug("KIV ISSUE: %llx [%llx]", issue, now); /* check the ticket is in date */ - if (issue > now) { - abort_code = RXKADNOAUTH; - ret = -EKEYREJECTED; - goto other_error; - } - - if (issue < now - life) { - abort_code = RXKADEXPIRED; - ret = -EKEYEXPIRED; - goto other_error; - } + if (issue > now) + return rxrpc_abort_conn(conn, skb, RXKADNOAUTH, -EKEYREJECTED, + rxkad_abort_resp_tkt_future); + if (issue < now - life) + return rxrpc_abort_conn(conn, skb, RXKADEXPIRED, -EKEYEXPIRED, + rxkad_abort_resp_tkt_expired); *_expiry = issue + life; /* get the service name */ - name = Z(SNAME); + name = Z(SNAME, sname); _debug("KIV SNAME: %s", name); /* get the service instance name */ - name = Z(INST); + name = Z(INST, sinst); _debug("KIV SINST: %s", name); return 0; - -bad_ticket: - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); - abort_code = RXKADBADTICKET; - ret = -EPROTO; -other_error: - *_abort_code = abort_code; - return ret; -temporary_error: - return ret; } /* @@ -1156,17 +1017,15 @@ static void rxkad_decrypt_response(struct rxrpc_connection *conn, * verify a response */ static int rxkad_verify_response(struct rxrpc_connection *conn, - struct sk_buff *skb, - u32 *_abort_code) + struct sk_buff *skb) { struct rxkad_response *response; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_crypt session_key; struct key *server_key; - const char *eproto; time64_t expiry; void *ticket; - u32 abort_code, version, kvno, ticket_len, level; + u32 version, kvno, ticket_len, level; __be32 csum; int ret, i; @@ -1174,22 +1033,18 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, server_key = rxrpc_look_up_server_security(conn, skb, 0, 0); if (IS_ERR(server_key)) { - switch (PTR_ERR(server_key)) { + ret = PTR_ERR(server_key); + switch (ret) { case -ENOKEY: - abort_code = RXKADUNKNOWNKEY; - break; + return rxrpc_abort_conn(conn, skb, RXKADUNKNOWNKEY, ret, + rxkad_abort_resp_nokey); case -EKEYEXPIRED: - abort_code = RXKADEXPIRED; - break; + return rxrpc_abort_conn(conn, skb, RXKADEXPIRED, ret, + rxkad_abort_resp_key_expired); default: - abort_code = RXKADNOAUTH; - break; + return rxrpc_abort_conn(conn, skb, RXKADNOAUTH, ret, + rxkad_abort_resp_key_rejected); } - trace_rxrpc_abort(0, "SVK", - sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - abort_code, PTR_ERR(server_key)); - *_abort_code = abort_code; - return -EPROTO; } ret = -ENOMEM; @@ -1197,32 +1052,36 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, if (!response) goto temporary_error; - eproto = tracepoint_string("rxkad_rsp_short"); - abort_code = RXKADPACKETSHORT; if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), - response, sizeof(*response)) < 0) + response, sizeof(*response)) < 0) { + rxrpc_abort_conn(conn, skb, RXKADPACKETSHORT, -EPROTO, + rxkad_abort_resp_short); goto protocol_error; + } version = ntohl(response->version); ticket_len = ntohl(response->ticket_len); kvno = ntohl(response->kvno); - _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }", - sp->hdr.serial, version, kvno, ticket_len); - eproto = tracepoint_string("rxkad_rsp_ver"); - abort_code = RXKADINCONSISTENCY; - if (version != RXKAD_VERSION) + trace_rxrpc_rx_response(conn, sp->hdr.serial, version, kvno, ticket_len); + + if (version != RXKAD_VERSION) { + rxrpc_abort_conn(conn, skb, RXKADINCONSISTENCY, -EPROTO, + rxkad_abort_resp_version); goto protocol_error; + } - eproto = tracepoint_string("rxkad_rsp_tktlen"); - abort_code = RXKADTICKETLEN; - if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN) + if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN) { + rxrpc_abort_conn(conn, skb, RXKADTICKETLEN, -EPROTO, + rxkad_abort_resp_tkt_len); goto protocol_error; + } - eproto = tracepoint_string("rxkad_rsp_unkkey"); - abort_code = RXKADUNKNOWNKEY; - if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5) + if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5) { + rxrpc_abort_conn(conn, skb, RXKADUNKNOWNKEY, -EPROTO, + rxkad_abort_resp_unknown_tkt); goto protocol_error; + } /* extract the kerberos ticket and decrypt and decode it */ ret = -ENOMEM; @@ -1230,14 +1089,15 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, if (!ticket) goto temporary_error_free_resp; - eproto = tracepoint_string("rxkad_tkt_short"); - abort_code = RXKADPACKETSHORT; if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header) + sizeof(*response), - ticket, ticket_len) < 0) - goto protocol_error_free; + ticket, ticket_len) < 0) { + rxrpc_abort_conn(conn, skb, RXKADPACKETSHORT, -EPROTO, + rxkad_abort_resp_short_tkt); + goto protocol_error; + } ret = rxkad_decrypt_ticket(conn, server_key, skb, ticket, ticket_len, - &session_key, &expiry, _abort_code); + &session_key, &expiry); if (ret < 0) goto temporary_error_free_ticket; @@ -1245,57 +1105,62 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, * response */ rxkad_decrypt_response(conn, response, &session_key); - eproto = tracepoint_string("rxkad_rsp_param"); - abort_code = RXKADSEALEDINCON; - if (ntohl(response->encrypted.epoch) != conn->proto.epoch) - goto protocol_error_free; - if (ntohl(response->encrypted.cid) != conn->proto.cid) - goto protocol_error_free; - if (ntohl(response->encrypted.securityIndex) != conn->security_ix) + if (ntohl(response->encrypted.epoch) != conn->proto.epoch || + ntohl(response->encrypted.cid) != conn->proto.cid || + ntohl(response->encrypted.securityIndex) != conn->security_ix) { + rxrpc_abort_conn(conn, skb, RXKADSEALEDINCON, -EPROTO, + rxkad_abort_resp_bad_param); goto protocol_error_free; + } + csum = response->encrypted.checksum; response->encrypted.checksum = 0; rxkad_calc_response_checksum(response); - eproto = tracepoint_string("rxkad_rsp_csum"); - if (response->encrypted.checksum != csum) + if (response->encrypted.checksum != csum) { + rxrpc_abort_conn(conn, skb, RXKADSEALEDINCON, -EPROTO, + rxkad_abort_resp_bad_checksum); goto protocol_error_free; + } - spin_lock(&conn->bundle->channel_lock); for (i = 0; i < RXRPC_MAXCALLS; i++) { - struct rxrpc_call *call; u32 call_id = ntohl(response->encrypted.call_id[i]); + u32 counter = READ_ONCE(conn->channels[i].call_counter); + + if (call_id > INT_MAX) { + rxrpc_abort_conn(conn, skb, RXKADSEALEDINCON, -EPROTO, + rxkad_abort_resp_bad_callid); + goto protocol_error_free; + } - eproto = tracepoint_string("rxkad_rsp_callid"); - if (call_id > INT_MAX) - goto protocol_error_unlock; - - eproto = tracepoint_string("rxkad_rsp_callctr"); - if (call_id < conn->channels[i].call_counter) - goto protocol_error_unlock; - - eproto = tracepoint_string("rxkad_rsp_callst"); - if (call_id > conn->channels[i].call_counter) { - call = rcu_dereference_protected( - conn->channels[i].call, - lockdep_is_held(&conn->bundle->channel_lock)); - if (call && call->state < RXRPC_CALL_COMPLETE) - goto protocol_error_unlock; + if (call_id < counter) { + rxrpc_abort_conn(conn, skb, RXKADSEALEDINCON, -EPROTO, + rxkad_abort_resp_call_ctr); + goto protocol_error_free; + } + + if (call_id > counter) { + if (conn->channels[i].call) { + rxrpc_abort_conn(conn, skb, RXKADSEALEDINCON, -EPROTO, + rxkad_abort_resp_call_state); + goto protocol_error_free; + } conn->channels[i].call_counter = call_id; } } - spin_unlock(&conn->bundle->channel_lock); - eproto = tracepoint_string("rxkad_rsp_seq"); - abort_code = RXKADOUTOFSEQUENCE; - if (ntohl(response->encrypted.inc_nonce) != conn->rxkad.nonce + 1) + if (ntohl(response->encrypted.inc_nonce) != conn->rxkad.nonce + 1) { + rxrpc_abort_conn(conn, skb, RXKADOUTOFSEQUENCE, -EPROTO, + rxkad_abort_resp_ooseq); goto protocol_error_free; + } - eproto = tracepoint_string("rxkad_rsp_level"); - abort_code = RXKADLEVELFAIL; level = ntohl(response->encrypted.level); - if (level > RXRPC_SECURITY_ENCRYPT) + if (level > RXRPC_SECURITY_ENCRYPT) { + rxrpc_abort_conn(conn, skb, RXKADLEVELFAIL, -EPROTO, + rxkad_abort_resp_level); goto protocol_error_free; - conn->params.security_level = level; + } + conn->security_level = level; /* create a key to hold the security data and expiration time - after * this the connection security can be handled in exactly the same way @@ -1309,15 +1174,11 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, _leave(" = 0"); return 0; -protocol_error_unlock: - spin_unlock(&conn->bundle->channel_lock); protocol_error_free: kfree(ticket); protocol_error: kfree(response); - trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); key_put(server_key); - *_abort_code = abort_code; return -EPROTO; temporary_error_free_ticket: @@ -1397,7 +1258,6 @@ const struct rxrpc_security rxkad = { .secure_packet = rxkad_secure_packet, .verify_packet = rxkad_verify_packet, .free_call_crypto = rxkad_free_call_crypto, - .locate_data = rxkad_locate_data, .issue_challenge = rxkad_issue_challenge, .respond_to_challenge = rxkad_respond_to_challenge, .verify_response = rxkad_verify_response, diff --git a/net/rxrpc/rxperf.c b/net/rxrpc/rxperf.c new file mode 100644 index 000000000000..16dcabb71ebe --- /dev/null +++ b/net/rxrpc/rxperf.c @@ -0,0 +1,626 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* In-kernel rxperf server for testing purposes. + * + * Copyright (C) 2022 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) "rxperf: " fmt +#include <linux/module.h> +#include <linux/slab.h> +#include <net/sock.h> +#include <net/af_rxrpc.h> +#define RXRPC_TRACE_ONLY_DEFINE_ENUMS +#include <trace/events/rxrpc.h> + +MODULE_DESCRIPTION("rxperf test server (afs)"); +MODULE_AUTHOR("Red Hat, Inc."); +MODULE_LICENSE("GPL"); + +#define RXPERF_PORT 7009 +#define RX_PERF_SERVICE 147 +#define RX_PERF_VERSION 3 +#define RX_PERF_SEND 0 +#define RX_PERF_RECV 1 +#define RX_PERF_RPC 3 +#define RX_PERF_FILE 4 +#define RX_PERF_MAGIC_COOKIE 0x4711 + +struct rxperf_proto_params { + __be32 version; + __be32 type; + __be32 rsize; + __be32 wsize; +} __packed; + +static const u8 rxperf_magic_cookie[] = { 0x00, 0x00, 0x47, 0x11 }; +static const u8 secret[8] = { 0xa7, 0x83, 0x8a, 0xcb, 0xc7, 0x83, 0xec, 0x94 }; + +enum rxperf_call_state { + RXPERF_CALL_SV_AWAIT_PARAMS, /* Server: Awaiting parameter block */ + RXPERF_CALL_SV_AWAIT_REQUEST, /* Server: Awaiting request data */ + RXPERF_CALL_SV_REPLYING, /* Server: Replying */ + RXPERF_CALL_SV_AWAIT_ACK, /* Server: Awaiting final ACK */ + RXPERF_CALL_COMPLETE, /* Completed or failed */ +}; + +struct rxperf_call { + struct rxrpc_call *rxcall; + struct iov_iter iter; + struct kvec kvec[1]; + struct work_struct work; + const char *type; + size_t iov_len; + size_t req_len; /* Size of request blob */ + size_t reply_len; /* Size of reply blob */ + unsigned int debug_id; + unsigned int operation_id; + struct rxperf_proto_params params; + __be32 tmp[2]; + s32 abort_code; + enum rxperf_call_state state; + short error; + unsigned short unmarshal; + u16 service_id; + int (*deliver)(struct rxperf_call *call); + void (*processor)(struct work_struct *work); +}; + +static struct socket *rxperf_socket; +static struct key *rxperf_sec_keyring; /* Ring of security/crypto keys */ +static struct workqueue_struct *rxperf_workqueue; + +static void rxperf_deliver_to_call(struct work_struct *work); +static int rxperf_deliver_param_block(struct rxperf_call *call); +static int rxperf_deliver_request(struct rxperf_call *call); +static int rxperf_process_call(struct rxperf_call *call); +static void rxperf_charge_preallocation(struct work_struct *work); + +static DECLARE_WORK(rxperf_charge_preallocation_work, + rxperf_charge_preallocation); + +static inline void rxperf_set_call_state(struct rxperf_call *call, + enum rxperf_call_state to) +{ + call->state = to; +} + +static inline void rxperf_set_call_complete(struct rxperf_call *call, + int error, s32 remote_abort) +{ + if (call->state != RXPERF_CALL_COMPLETE) { + call->abort_code = remote_abort; + call->error = error; + call->state = RXPERF_CALL_COMPLETE; + } +} + +static void rxperf_rx_discard_new_call(struct rxrpc_call *rxcall, + unsigned long user_call_ID) +{ + kfree((struct rxperf_call *)user_call_ID); +} + +static void rxperf_rx_new_call(struct sock *sk, struct rxrpc_call *rxcall, + unsigned long user_call_ID) +{ + queue_work(rxperf_workqueue, &rxperf_charge_preallocation_work); +} + +static void rxperf_queue_call_work(struct rxperf_call *call) +{ + queue_work(rxperf_workqueue, &call->work); +} + +static void rxperf_notify_rx(struct sock *sk, struct rxrpc_call *rxcall, + unsigned long call_user_ID) +{ + struct rxperf_call *call = (struct rxperf_call *)call_user_ID; + + if (call->state != RXPERF_CALL_COMPLETE) + rxperf_queue_call_work(call); +} + +static void rxperf_rx_attach(struct rxrpc_call *rxcall, unsigned long user_call_ID) +{ + struct rxperf_call *call = (struct rxperf_call *)user_call_ID; + + call->rxcall = rxcall; +} + +static void rxperf_notify_end_reply_tx(struct sock *sock, + struct rxrpc_call *rxcall, + unsigned long call_user_ID) +{ + rxperf_set_call_state((struct rxperf_call *)call_user_ID, + RXPERF_CALL_SV_AWAIT_ACK); +} + +/* + * Charge the incoming call preallocation. + */ +static void rxperf_charge_preallocation(struct work_struct *work) +{ + struct rxperf_call *call; + + for (;;) { + call = kzalloc(sizeof(*call), GFP_KERNEL); + if (!call) + break; + + call->type = "unset"; + call->debug_id = atomic_inc_return(&rxrpc_debug_id); + call->deliver = rxperf_deliver_param_block; + call->state = RXPERF_CALL_SV_AWAIT_PARAMS; + call->service_id = RX_PERF_SERVICE; + call->iov_len = sizeof(call->params); + call->kvec[0].iov_len = sizeof(call->params); + call->kvec[0].iov_base = &call->params; + iov_iter_kvec(&call->iter, READ, call->kvec, 1, call->iov_len); + INIT_WORK(&call->work, rxperf_deliver_to_call); + + if (rxrpc_kernel_charge_accept(rxperf_socket, + rxperf_notify_rx, + rxperf_rx_attach, + (unsigned long)call, + GFP_KERNEL, + call->debug_id) < 0) + break; + call = NULL; + } + + kfree(call); +} + +/* + * Open an rxrpc socket and bind it to be a server for callback notifications + * - the socket is left in blocking mode and non-blocking ops use MSG_DONTWAIT + */ +static int rxperf_open_socket(void) +{ + struct sockaddr_rxrpc srx; + struct socket *socket; + int ret; + + ret = sock_create_kern(&init_net, AF_RXRPC, SOCK_DGRAM, PF_INET6, + &socket); + if (ret < 0) + goto error_1; + + socket->sk->sk_allocation = GFP_NOFS; + + /* bind the callback manager's address to make this a server socket */ + memset(&srx, 0, sizeof(srx)); + srx.srx_family = AF_RXRPC; + srx.srx_service = RX_PERF_SERVICE; + srx.transport_type = SOCK_DGRAM; + srx.transport_len = sizeof(srx.transport.sin6); + srx.transport.sin6.sin6_family = AF_INET6; + srx.transport.sin6.sin6_port = htons(RXPERF_PORT); + + ret = rxrpc_sock_set_min_security_level(socket->sk, + RXRPC_SECURITY_ENCRYPT); + if (ret < 0) + goto error_2; + + ret = rxrpc_sock_set_security_keyring(socket->sk, rxperf_sec_keyring); + + ret = kernel_bind(socket, (struct sockaddr *)&srx, sizeof(srx)); + if (ret < 0) + goto error_2; + + rxrpc_kernel_new_call_notification(socket, rxperf_rx_new_call, + rxperf_rx_discard_new_call); + + ret = kernel_listen(socket, INT_MAX); + if (ret < 0) + goto error_2; + + rxperf_socket = socket; + rxperf_charge_preallocation(&rxperf_charge_preallocation_work); + return 0; + +error_2: + sock_release(socket); +error_1: + pr_err("Can't set up rxperf socket: %d\n", ret); + return ret; +} + +/* + * close the rxrpc socket rxperf was using + */ +static void rxperf_close_socket(void) +{ + kernel_listen(rxperf_socket, 0); + kernel_sock_shutdown(rxperf_socket, SHUT_RDWR); + flush_workqueue(rxperf_workqueue); + sock_release(rxperf_socket); +} + +/* + * Log remote abort codes that indicate that we have a protocol disagreement + * with the server. + */ +static void rxperf_log_error(struct rxperf_call *call, s32 remote_abort) +{ + static int max = 0; + const char *msg; + int m; + + switch (remote_abort) { + case RX_EOF: msg = "unexpected EOF"; break; + case RXGEN_CC_MARSHAL: msg = "client marshalling"; break; + case RXGEN_CC_UNMARSHAL: msg = "client unmarshalling"; break; + case RXGEN_SS_MARSHAL: msg = "server marshalling"; break; + case RXGEN_SS_UNMARSHAL: msg = "server unmarshalling"; break; + case RXGEN_DECODE: msg = "opcode decode"; break; + case RXGEN_SS_XDRFREE: msg = "server XDR cleanup"; break; + case RXGEN_CC_XDRFREE: msg = "client XDR cleanup"; break; + case -32: msg = "insufficient data"; break; + default: + return; + } + + m = max; + if (m < 3) { + max = m + 1; + pr_info("Peer reported %s failure on %s\n", msg, call->type); + } +} + +/* + * deliver messages to a call + */ +static void rxperf_deliver_to_call(struct work_struct *work) +{ + struct rxperf_call *call = container_of(work, struct rxperf_call, work); + enum rxperf_call_state state; + u32 abort_code, remote_abort = 0; + int ret = 0; + + if (call->state == RXPERF_CALL_COMPLETE) + return; + + while (state = call->state, + state == RXPERF_CALL_SV_AWAIT_PARAMS || + state == RXPERF_CALL_SV_AWAIT_REQUEST || + state == RXPERF_CALL_SV_AWAIT_ACK + ) { + if (state == RXPERF_CALL_SV_AWAIT_ACK) { + if (!rxrpc_kernel_check_life(rxperf_socket, call->rxcall)) + goto call_complete; + return; + } + + ret = call->deliver(call); + if (ret == 0) + ret = rxperf_process_call(call); + + switch (ret) { + case 0: + continue; + case -EINPROGRESS: + case -EAGAIN: + return; + case -ECONNABORTED: + rxperf_log_error(call, call->abort_code); + goto call_complete; + case -EOPNOTSUPP: + abort_code = RXGEN_OPCODE; + rxrpc_kernel_abort_call(rxperf_socket, call->rxcall, + abort_code, ret, + rxperf_abort_op_not_supported); + goto call_complete; + case -ENOTSUPP: + abort_code = RX_USER_ABORT; + rxrpc_kernel_abort_call(rxperf_socket, call->rxcall, + abort_code, ret, + rxperf_abort_op_not_supported); + goto call_complete; + case -EIO: + pr_err("Call %u in bad state %u\n", + call->debug_id, call->state); + fallthrough; + case -ENODATA: + case -EBADMSG: + case -EMSGSIZE: + case -ENOMEM: + case -EFAULT: + rxrpc_kernel_abort_call(rxperf_socket, call->rxcall, + RXGEN_SS_UNMARSHAL, ret, + rxperf_abort_unmarshal_error); + goto call_complete; + default: + rxrpc_kernel_abort_call(rxperf_socket, call->rxcall, + RX_CALL_DEAD, ret, + rxperf_abort_general_error); + goto call_complete; + } + } + +call_complete: + rxperf_set_call_complete(call, ret, remote_abort); + /* The call may have been requeued */ + rxrpc_kernel_end_call(rxperf_socket, call->rxcall); + cancel_work(&call->work); + kfree(call); +} + +/* + * Extract a piece of data from the received data socket buffers. + */ +static int rxperf_extract_data(struct rxperf_call *call, bool want_more) +{ + u32 remote_abort = 0; + int ret; + + ret = rxrpc_kernel_recv_data(rxperf_socket, call->rxcall, &call->iter, + &call->iov_len, want_more, &remote_abort, + &call->service_id); + pr_debug("Extract i=%zu l=%zu m=%u ret=%d\n", + iov_iter_count(&call->iter), call->iov_len, want_more, ret); + if (ret == 0 || ret == -EAGAIN) + return ret; + + if (ret == 1) { + switch (call->state) { + case RXPERF_CALL_SV_AWAIT_REQUEST: + rxperf_set_call_state(call, RXPERF_CALL_SV_REPLYING); + break; + case RXPERF_CALL_COMPLETE: + pr_debug("premature completion %d", call->error); + return call->error; + default: + break; + } + return 0; + } + + rxperf_set_call_complete(call, ret, remote_abort); + return ret; +} + +/* + * Grab the operation ID from an incoming manager call. + */ +static int rxperf_deliver_param_block(struct rxperf_call *call) +{ + u32 version; + int ret; + + /* Extract the parameter block */ + ret = rxperf_extract_data(call, true); + if (ret < 0) + return ret; + + version = ntohl(call->params.version); + call->operation_id = ntohl(call->params.type); + call->deliver = rxperf_deliver_request; + + if (version != RX_PERF_VERSION) { + pr_info("Version mismatch %x\n", version); + return -ENOTSUPP; + } + + switch (call->operation_id) { + case RX_PERF_SEND: + call->type = "send"; + call->reply_len = 0; + call->iov_len = 4; /* Expect req size */ + break; + case RX_PERF_RECV: + call->type = "recv"; + call->req_len = 0; + call->iov_len = 4; /* Expect reply size */ + break; + case RX_PERF_RPC: + call->type = "rpc"; + call->iov_len = 8; /* Expect req size and reply size */ + break; + case RX_PERF_FILE: + call->type = "file"; + fallthrough; + default: + return -EOPNOTSUPP; + } + + rxperf_set_call_state(call, RXPERF_CALL_SV_AWAIT_REQUEST); + return call->deliver(call); +} + +/* + * Deliver the request data. + */ +static int rxperf_deliver_request(struct rxperf_call *call) +{ + int ret; + + switch (call->unmarshal) { + case 0: + call->kvec[0].iov_len = call->iov_len; + call->kvec[0].iov_base = call->tmp; + iov_iter_kvec(&call->iter, READ, call->kvec, 1, call->iov_len); + call->unmarshal++; + fallthrough; + case 1: + ret = rxperf_extract_data(call, true); + if (ret < 0) + return ret; + + switch (call->operation_id) { + case RX_PERF_SEND: + call->type = "send"; + call->req_len = ntohl(call->tmp[0]); + call->reply_len = 0; + break; + case RX_PERF_RECV: + call->type = "recv"; + call->req_len = 0; + call->reply_len = ntohl(call->tmp[0]); + break; + case RX_PERF_RPC: + call->type = "rpc"; + call->req_len = ntohl(call->tmp[0]); + call->reply_len = ntohl(call->tmp[1]); + break; + default: + pr_info("Can't parse extra params\n"); + return -EIO; + } + + pr_debug("CALL op=%s rq=%zx rp=%zx\n", + call->type, call->req_len, call->reply_len); + + call->iov_len = call->req_len; + iov_iter_discard(&call->iter, READ, call->req_len); + call->unmarshal++; + fallthrough; + case 2: + ret = rxperf_extract_data(call, false); + if (ret < 0) + return ret; + call->unmarshal++; + fallthrough; + default: + return 0; + } +} + +/* + * Process a call for which we've received the request. + */ +static int rxperf_process_call(struct rxperf_call *call) +{ + struct msghdr msg = {}; + struct bio_vec bv[1]; + struct kvec iov[1]; + ssize_t n; + size_t reply_len = call->reply_len, len; + + rxrpc_kernel_set_tx_length(rxperf_socket, call->rxcall, + reply_len + sizeof(rxperf_magic_cookie)); + + while (reply_len > 0) { + len = min_t(size_t, reply_len, PAGE_SIZE); + bv[0].bv_page = ZERO_PAGE(0); + bv[0].bv_offset = 0; + bv[0].bv_len = len; + iov_iter_bvec(&msg.msg_iter, WRITE, bv, 1, len); + msg.msg_flags = MSG_MORE; + n = rxrpc_kernel_send_data(rxperf_socket, call->rxcall, &msg, + len, rxperf_notify_end_reply_tx); + if (n < 0) + return n; + if (n == 0) + return -EIO; + reply_len -= n; + } + + len = sizeof(rxperf_magic_cookie); + iov[0].iov_base = (void *)rxperf_magic_cookie; + iov[0].iov_len = len; + iov_iter_kvec(&msg.msg_iter, WRITE, iov, 1, len); + msg.msg_flags = 0; + n = rxrpc_kernel_send_data(rxperf_socket, call->rxcall, &msg, len, + rxperf_notify_end_reply_tx); + if (n >= 0) + return 0; /* Success */ + + if (n == -ENOMEM) + rxrpc_kernel_abort_call(rxperf_socket, call->rxcall, + RXGEN_SS_MARSHAL, -ENOMEM, + rxperf_abort_oom); + return n; +} + +/* + * Add a key to the security keyring. + */ +static int rxperf_add_key(struct key *keyring) +{ + key_ref_t kref; + int ret; + + kref = key_create_or_update(make_key_ref(keyring, true), + "rxrpc_s", + __stringify(RX_PERF_SERVICE) ":2", + secret, + sizeof(secret), + KEY_POS_VIEW | KEY_POS_READ | KEY_POS_SEARCH + | KEY_USR_VIEW, + KEY_ALLOC_NOT_IN_QUOTA); + + if (IS_ERR(kref)) { + pr_err("Can't allocate rxperf server key: %ld\n", PTR_ERR(kref)); + return PTR_ERR(kref); + } + + ret = key_link(keyring, key_ref_to_ptr(kref)); + if (ret < 0) + pr_err("Can't link rxperf server key: %d\n", ret); + key_ref_put(kref); + return ret; +} + +/* + * Initialise the rxperf server. + */ +static int __init rxperf_init(void) +{ + struct key *keyring; + int ret = -ENOMEM; + + pr_info("Server registering\n"); + + rxperf_workqueue = alloc_workqueue("rxperf", 0, 0); + if (!rxperf_workqueue) + goto error_workqueue; + + keyring = keyring_alloc("rxperf_server", + GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, current_cred(), + KEY_POS_VIEW | KEY_POS_READ | KEY_POS_SEARCH | + KEY_POS_WRITE | + KEY_USR_VIEW | KEY_USR_READ | KEY_USR_SEARCH | + KEY_USR_WRITE | + KEY_OTH_VIEW | KEY_OTH_READ | KEY_OTH_SEARCH, + KEY_ALLOC_NOT_IN_QUOTA, + NULL, NULL); + if (IS_ERR(keyring)) { + pr_err("Can't allocate rxperf server keyring: %ld\n", + PTR_ERR(keyring)); + goto error_keyring; + } + rxperf_sec_keyring = keyring; + ret = rxperf_add_key(keyring); + if (ret < 0) + goto error_key; + + ret = rxperf_open_socket(); + if (ret < 0) + goto error_socket; + return 0; + +error_socket: +error_key: + key_put(rxperf_sec_keyring); +error_keyring: + destroy_workqueue(rxperf_workqueue); + rcu_barrier(); +error_workqueue: + pr_err("Failed to register: %d\n", ret); + return ret; +} +late_initcall(rxperf_init); /* Must be called after net/ to create socket */ + +static void __exit rxperf_exit(void) +{ + pr_info("Server unregistering.\n"); + + rxperf_close_socket(); + key_put(rxperf_sec_keyring); + destroy_workqueue(rxperf_workqueue); + rcu_barrier(); +} +module_exit(rxperf_exit); + diff --git a/net/rxrpc/security.c b/net/rxrpc/security.c index 50cb5f1ee0c0..cb8dd1d3b1d4 100644 --- a/net/rxrpc/security.c +++ b/net/rxrpc/security.c @@ -63,19 +63,17 @@ const struct rxrpc_security *rxrpc_security_lookup(u8 security_index) } /* - * initialise the security on a client connection + * Initialise the security on a client call. */ -int rxrpc_init_client_conn_security(struct rxrpc_connection *conn) +int rxrpc_init_client_call_security(struct rxrpc_call *call) { - const struct rxrpc_security *sec; + const struct rxrpc_security *sec = &rxrpc_no_security; struct rxrpc_key_token *token; - struct key *key = conn->params.key; + struct key *key = call->key; int ret; - _enter("{%d},{%x}", conn->debug_id, key_serial(key)); - if (!key) - return 0; + goto found; ret = key_validate(key); if (ret < 0) @@ -89,16 +87,41 @@ int rxrpc_init_client_conn_security(struct rxrpc_connection *conn) return -EKEYREJECTED; found: - conn->security = sec; + call->security = sec; + call->security_ix = sec->security_index; + return 0; +} - ret = conn->security->init_connection_security(conn, token); - if (ret < 0) { - conn->security = &rxrpc_no_security; - return ret; +/* + * initialise the security on a client connection + */ +int rxrpc_init_client_conn_security(struct rxrpc_connection *conn) +{ + struct rxrpc_key_token *token; + struct key *key = conn->key; + int ret = 0; + + _enter("{%d},{%x}", conn->debug_id, key_serial(key)); + + for (token = key->payload.data[0]; token; token = token->next) { + if (token->security_index == conn->security->security_index) + goto found; } + return -EKEYREJECTED; - _leave(" = 0"); - return 0; +found: + mutex_lock(&conn->security_lock); + if (conn->state == RXRPC_CONN_CLIENT_UNSECURED) { + ret = conn->security->init_connection_security(conn, token); + if (ret == 0) { + spin_lock(&conn->state_lock); + if (conn->state == RXRPC_CONN_CLIENT_UNSECURED) + conn->state = RXRPC_CONN_CLIENT; + spin_unlock(&conn->state_lock); + } + } + mutex_unlock(&conn->security_lock); + return ret; } /* @@ -114,21 +137,15 @@ const struct rxrpc_security *rxrpc_get_incoming_security(struct rxrpc_sock *rx, sec = rxrpc_security_lookup(sp->hdr.securityIndex); if (!sec) { - trace_rxrpc_abort(0, "SVS", - sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_INVALID_OPERATION, EKEYREJECTED); - skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; - skb->priority = RX_INVALID_OPERATION; + rxrpc_direct_abort(skb, rxrpc_abort_unsupported_security, + RX_INVALID_OPERATION, -EKEYREJECTED); return NULL; } if (sp->hdr.securityIndex != RXRPC_SECURITY_NONE && !rx->securities) { - trace_rxrpc_abort(0, "SVR", - sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_INVALID_OPERATION, EKEYREJECTED); - skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; - skb->priority = sec->no_key_abort; + rxrpc_direct_abort(skb, rxrpc_abort_no_service_key, + sec->no_key_abort, -EKEYREJECTED); return NULL; } @@ -161,9 +178,9 @@ struct key *rxrpc_look_up_server_security(struct rxrpc_connection *conn, sprintf(kdesc, "%u:%u", sp->hdr.serviceId, sp->hdr.securityIndex); - rcu_read_lock(); + read_lock(&conn->local->services_lock); - rx = rcu_dereference(conn->params.local->service); + rx = conn->local->service; if (!rx) goto out; @@ -185,6 +202,6 @@ struct key *rxrpc_look_up_server_security(struct rxrpc_connection *conn, } out: - rcu_read_unlock(); + read_unlock(&conn->local->services_lock); return key; } diff --git a/net/rxrpc/sendmsg.c b/net/rxrpc/sendmsg.c index 3c3a626459de..da49fcf1c456 100644 --- a/net/rxrpc/sendmsg.c +++ b/net/rxrpc/sendmsg.c @@ -18,18 +18,88 @@ #include "ar-internal.h" /* + * Propose an abort to be made in the I/O thread. + */ +bool rxrpc_propose_abort(struct rxrpc_call *call, s32 abort_code, int error, + enum rxrpc_abort_reason why) +{ + _enter("{%d},%d,%d,%u", call->debug_id, abort_code, error, why); + + if (!call->send_abort && !rxrpc_call_is_complete(call)) { + call->send_abort_why = why; + call->send_abort_err = error; + call->send_abort_seq = 0; + /* Request abort locklessly vs rxrpc_input_call_event(). */ + smp_store_release(&call->send_abort, abort_code); + rxrpc_poke_call(call, rxrpc_call_poke_abort); + return true; + } + + return false; +} + +/* + * Wait for a call to become connected. Interruption here doesn't cause the + * call to be aborted. + */ +static int rxrpc_wait_to_be_connected(struct rxrpc_call *call, long *timeo) +{ + DECLARE_WAITQUEUE(myself, current); + int ret = 0; + + _enter("%d", call->debug_id); + + if (rxrpc_call_state(call) != RXRPC_CALL_CLIENT_AWAIT_CONN) + return call->error; + + add_wait_queue_exclusive(&call->waitq, &myself); + + for (;;) { + ret = call->error; + if (ret < 0) + break; + + switch (call->interruptibility) { + case RXRPC_INTERRUPTIBLE: + case RXRPC_PREINTERRUPTIBLE: + set_current_state(TASK_INTERRUPTIBLE); + break; + case RXRPC_UNINTERRUPTIBLE: + default: + set_current_state(TASK_UNINTERRUPTIBLE); + break; + } + if (rxrpc_call_state(call) != RXRPC_CALL_CLIENT_AWAIT_CONN) { + ret = call->error; + break; + } + if ((call->interruptibility == RXRPC_INTERRUPTIBLE || + call->interruptibility == RXRPC_PREINTERRUPTIBLE) && + signal_pending(current)) { + ret = sock_intr_errno(*timeo); + break; + } + *timeo = schedule_timeout(*timeo); + } + + remove_wait_queue(&call->waitq, &myself); + __set_current_state(TASK_RUNNING); + + if (ret == 0 && rxrpc_call_is_complete(call)) + ret = call->error; + + _leave(" = %d", ret); + return ret; +} + +/* * Return true if there's sufficient Tx queue space. */ static bool rxrpc_check_tx_space(struct rxrpc_call *call, rxrpc_seq_t *_tx_win) { - unsigned int win_size = - min_t(unsigned int, call->tx_winsize, - call->cong_cwnd + call->cong_extra); - rxrpc_seq_t tx_win = READ_ONCE(call->tx_hard_ack); - if (_tx_win) - *_tx_win = tx_win; - return call->tx_top - tx_win < win_size; + *_tx_win = call->tx_bottom; + return call->tx_prepared - call->tx_bottom < 256; } /* @@ -44,13 +114,13 @@ static int rxrpc_wait_for_tx_window_intr(struct rxrpc_sock *rx, if (rxrpc_check_tx_space(call, NULL)) return 0; - if (call->state >= RXRPC_CALL_COMPLETE) + if (rxrpc_call_is_complete(call)) return call->error; if (signal_pending(current)) return sock_intr_errno(*timeo); - trace_rxrpc_transmit(call, rxrpc_transmit_wait); + trace_rxrpc_txqueue(call, rxrpc_txqueue_wait); *timeo = schedule_timeout(*timeo); } } @@ -71,16 +141,15 @@ static int rxrpc_wait_for_tx_window_waitall(struct rxrpc_sock *rx, rtt = 2; timeout = rtt; - tx_start = READ_ONCE(call->tx_hard_ack); + tx_start = smp_load_acquire(&call->acks_hard_ack); for (;;) { set_current_state(TASK_UNINTERRUPTIBLE); - tx_win = READ_ONCE(call->tx_hard_ack); if (rxrpc_check_tx_space(call, &tx_win)) return 0; - if (call->state >= RXRPC_CALL_COMPLETE) + if (rxrpc_call_is_complete(call)) return call->error; if (timeout == 0 && @@ -92,7 +161,7 @@ static int rxrpc_wait_for_tx_window_waitall(struct rxrpc_sock *rx, tx_start = tx_win; } - trace_rxrpc_transmit(call, rxrpc_transmit_wait); + trace_rxrpc_txqueue(call, rxrpc_txqueue_wait); timeout = schedule_timeout(timeout); } } @@ -109,10 +178,10 @@ static int rxrpc_wait_for_tx_window_nonintr(struct rxrpc_sock *rx, if (rxrpc_check_tx_space(call, NULL)) return 0; - if (call->state >= RXRPC_CALL_COMPLETE) + if (rxrpc_call_is_complete(call)) return call->error; - trace_rxrpc_transmit(call, rxrpc_transmit_wait); + trace_rxrpc_txqueue(call, rxrpc_txqueue_wait); *timeo = schedule_timeout(*timeo); } } @@ -129,8 +198,8 @@ static int rxrpc_wait_for_tx_window(struct rxrpc_sock *rx, DECLARE_WAITQUEUE(myself, current); int ret; - _enter(",{%u,%u,%u}", - call->tx_hard_ack, call->tx_top, call->tx_winsize); + _enter(",{%u,%u,%u,%u}", + call->tx_bottom, call->acks_hard_ack, call->tx_top, call->tx_winsize); add_wait_queue(&call->waitq, &myself); @@ -155,24 +224,6 @@ static int rxrpc_wait_for_tx_window(struct rxrpc_sock *rx, } /* - * Schedule an instant Tx resend. - */ -static inline void rxrpc_instant_resend(struct rxrpc_call *call, int ix) -{ - spin_lock_bh(&call->lock); - - if (call->state < RXRPC_CALL_COMPLETE) { - call->rxtx_annotations[ix] = - (call->rxtx_annotations[ix] & RXRPC_TX_ANNO_LAST) | - RXRPC_TX_ANNO_RETRANS; - if (!test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) - rxrpc_queue_call(call); - } - - spin_unlock_bh(&call->lock); -} - -/* * Notify the owner of the call that the transmit phase is ended and the last * packet has been queued. */ @@ -188,95 +239,38 @@ static void rxrpc_notify_end_tx(struct rxrpc_sock *rx, struct rxrpc_call *call, * the packet immediately. Returns the error from rxrpc_send_data_packet() * in case the caller wants to do something with it. */ -static int rxrpc_queue_packet(struct rxrpc_sock *rx, struct rxrpc_call *call, - struct sk_buff *skb, bool last, - rxrpc_notify_end_tx_t notify_end_tx) +static void rxrpc_queue_packet(struct rxrpc_sock *rx, struct rxrpc_call *call, + struct rxrpc_txbuf *txb, + rxrpc_notify_end_tx_t notify_end_tx) { - struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - unsigned long now; - rxrpc_seq_t seq = sp->hdr.seq; - int ret, ix; - u8 annotation = RXRPC_TX_ANNO_UNACK; + rxrpc_seq_t seq = txb->seq; + bool last = test_bit(RXRPC_TXBUF_LAST, &txb->flags), poke; - _net("queue skb %p [%d]", skb, seq); + rxrpc_inc_stat(call->rxnet, stat_tx_data); - ASSERTCMP(seq, ==, call->tx_top + 1); - - if (last) - annotation |= RXRPC_TX_ANNO_LAST; + ASSERTCMP(txb->seq, ==, call->tx_prepared + 1); /* We have to set the timestamp before queueing as the retransmit * algorithm can see the packet as soon as we queue it. */ - skb->tstamp = ktime_get_real(); - - ix = seq & RXRPC_RXTX_BUFF_MASK; - rxrpc_get_skb(skb, rxrpc_skb_got); - call->rxtx_annotations[ix] = annotation; - smp_wmb(); - call->rxtx_buffer[ix] = skb; - call->tx_top = seq; + txb->last_sent = ktime_get_real(); + if (last) - trace_rxrpc_transmit(call, rxrpc_transmit_queue_last); + trace_rxrpc_txqueue(call, rxrpc_txqueue_queue_last); else - trace_rxrpc_transmit(call, rxrpc_transmit_queue); - - if (last || call->state == RXRPC_CALL_SERVER_ACK_REQUEST) { - _debug("________awaiting reply/ACK__________"); - write_lock_bh(&call->state_lock); - switch (call->state) { - case RXRPC_CALL_CLIENT_SEND_REQUEST: - call->state = RXRPC_CALL_CLIENT_AWAIT_REPLY; - rxrpc_notify_end_tx(rx, call, notify_end_tx); - break; - case RXRPC_CALL_SERVER_ACK_REQUEST: - call->state = RXRPC_CALL_SERVER_SEND_REPLY; - now = jiffies; - WRITE_ONCE(call->ack_at, now + MAX_JIFFY_OFFSET); - if (call->ackr_reason == RXRPC_ACK_DELAY) - call->ackr_reason = 0; - trace_rxrpc_timer(call, rxrpc_timer_init_for_send_reply, now); - if (!last) - break; - fallthrough; - case RXRPC_CALL_SERVER_SEND_REPLY: - call->state = RXRPC_CALL_SERVER_AWAIT_ACK; - rxrpc_notify_end_tx(rx, call, notify_end_tx); - break; - default: - break; - } - write_unlock_bh(&call->state_lock); - } + trace_rxrpc_txqueue(call, rxrpc_txqueue_queue); - if (seq == 1 && rxrpc_is_client_call(call)) - rxrpc_expose_client_call(call); - - ret = rxrpc_send_data_packet(call, skb, false); - if (ret < 0) { - switch (ret) { - case -ENETUNREACH: - case -EHOSTUNREACH: - case -ECONNREFUSED: - rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, - 0, ret); - goto out; - } - _debug("need instant resend %d", ret); - rxrpc_instant_resend(call, ix); - } else { - unsigned long now = jiffies; - unsigned long resend_at = now + call->peer->rto_j; - - WRITE_ONCE(call->resend_at, resend_at); - rxrpc_reduce_call_timer(call, resend_at, now, - rxrpc_timer_set_for_send); - } + /* Add the packet to the call's output buffer */ + spin_lock(&call->tx_lock); + poke = list_empty(&call->tx_sendmsg); + list_add_tail(&txb->call_link, &call->tx_sendmsg); + call->tx_prepared = seq; + if (last) + rxrpc_notify_end_tx(rx, call, notify_end_tx); + spin_unlock(&call->tx_lock); -out: - rxrpc_free_skb(skb, rxrpc_skb_freed); - _leave(" = %d", ret); - return ret; + if (poke) + rxrpc_poke_call(call, rxrpc_call_poke_start); } /* @@ -290,8 +284,7 @@ static int rxrpc_send_data(struct rxrpc_sock *rx, rxrpc_notify_end_tx_t notify_end_tx, bool *_dropped_lock) { - struct rxrpc_skb_priv *sp; - struct sk_buff *skb; + struct rxrpc_txbuf *txb; struct sock *sk = &rx->sk; enum rxrpc_call_state state; long timeo; @@ -300,6 +293,16 @@ static int rxrpc_send_data(struct rxrpc_sock *rx, timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + ret = rxrpc_wait_to_be_connected(call, &timeo); + if (ret < 0) + return ret; + + if (call->conn->state == RXRPC_CONN_CLIENT_UNSECURED) { + ret = rxrpc_init_client_conn_security(call->conn); + if (ret < 0) + return ret; + } + /* this should be in poll */ sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); @@ -307,15 +310,20 @@ reload: ret = -EPIPE; if (sk->sk_shutdown & SEND_SHUTDOWN) goto maybe_error; - state = READ_ONCE(call->state); + state = rxrpc_call_state(call); ret = -ESHUTDOWN; if (state >= RXRPC_CALL_COMPLETE) goto maybe_error; ret = -EPROTO; if (state != RXRPC_CALL_CLIENT_SEND_REQUEST && state != RXRPC_CALL_SERVER_ACK_REQUEST && - state != RXRPC_CALL_SERVER_SEND_REPLY) + state != RXRPC_CALL_SERVER_SEND_REPLY) { + /* Request phase complete for this client call */ + trace_rxrpc_abort(call->debug_id, rxrpc_sendmsg_late_send, + call->cid, call->call_id, call->rx_consumed, + 0, -EPROTO); goto maybe_error; + } ret = -EMSGSIZE; if (call->tx_total_len != -1) { @@ -325,16 +333,13 @@ reload: goto maybe_error; } - skb = call->tx_pending; + txb = call->tx_pending; call->tx_pending = NULL; - rxrpc_see_skb(skb, rxrpc_skb_seen); + if (txb) + rxrpc_see_txbuf(txb, rxrpc_txbuf_see_send_more); do { - /* Check to see if there's a ping ACK to reply to. */ - if (call->ackr_reason == RXRPC_ACK_PING_RESPONSE) - rxrpc_send_ack_packet(call, false, NULL); - - if (!skb) { + if (!txb) { size_t remain, bufsize, chunk, offset; _debug("alloc"); @@ -355,53 +360,31 @@ reload: _debug("SIZE: %zu/%zu @%zu", chunk, bufsize, offset); /* create a buffer that we can retain until it's ACK'd */ - skb = sock_alloc_send_skb( - sk, bufsize, msg->msg_flags & MSG_DONTWAIT, &ret); - if (!skb) + ret = -ENOMEM; + txb = rxrpc_alloc_txbuf(call, RXRPC_PACKET_TYPE_DATA, + GFP_KERNEL); + if (!txb) goto maybe_error; - sp = rxrpc_skb(skb); - sp->rx_flags |= RXRPC_SKB_TX_BUFFER; - rxrpc_new_skb(skb, rxrpc_skb_new); - - _debug("ALLOC SEND %p", skb); - - ASSERTCMP(skb->mark, ==, 0); - - __skb_put(skb, offset); - - sp->remain = chunk; - if (sp->remain > skb_tailroom(skb)) - sp->remain = skb_tailroom(skb); - - _net("skb: hr %d, tr %d, hl %d, rm %d", - skb_headroom(skb), - skb_tailroom(skb), - skb_headlen(skb), - sp->remain); - - skb->ip_summed = CHECKSUM_UNNECESSARY; + txb->offset = offset; + txb->space -= offset; + txb->space = min_t(size_t, chunk, txb->space); } _debug("append"); - sp = rxrpc_skb(skb); /* append next segment of data to the current buffer */ if (msg_data_left(msg) > 0) { - int copy = skb_tailroom(skb); - ASSERTCMP(copy, >, 0); - if (copy > msg_data_left(msg)) - copy = msg_data_left(msg); - if (copy > sp->remain) - copy = sp->remain; - - _debug("add"); - ret = skb_add_data(skb, &msg->msg_iter, copy); - _debug("added"); - if (ret < 0) + size_t copy = min_t(size_t, txb->space, msg_data_left(msg)); + + _debug("add %zu", copy); + if (!copy_from_iter_full(txb->data + txb->offset, copy, + &msg->msg_iter)) goto efault; - sp->remain -= copy; - skb->mark += copy; + _debug("added"); + txb->space -= copy; + txb->len += copy; + txb->offset += copy; copied += copy; if (call->tx_total_len != -1) call->tx_total_len -= copy; @@ -409,54 +392,41 @@ reload: /* check for the far side aborting the call or a network error * occurring */ - if (call->state == RXRPC_CALL_COMPLETE) + if (rxrpc_call_is_complete(call)) goto call_terminated; /* add the packet to the send queue if it's now full */ - if (sp->remain <= 0 || + if (!txb->space || (msg_data_left(msg) == 0 && !more)) { - struct rxrpc_connection *conn = call->conn; - uint32_t seq; - - seq = call->tx_top + 1; - - sp->hdr.seq = seq; - sp->hdr._rsvd = 0; - sp->hdr.flags = conn->out_clientflag; - - if (msg_data_left(msg) == 0 && !more) - sp->hdr.flags |= RXRPC_LAST_PACKET; - else if (call->tx_top - call->tx_hard_ack < + if (msg_data_left(msg) == 0 && !more) { + txb->wire.flags |= RXRPC_LAST_PACKET; + __set_bit(RXRPC_TXBUF_LAST, &txb->flags); + } + else if (call->tx_top - call->acks_hard_ack < call->tx_winsize) - sp->hdr.flags |= RXRPC_MORE_PACKETS; + txb->wire.flags |= RXRPC_MORE_PACKETS; - ret = call->security->secure_packet(call, skb, skb->mark); + ret = call->security->secure_packet(call, txb); if (ret < 0) goto out; - ret = rxrpc_queue_packet(rx, call, skb, - !msg_data_left(msg) && !more, - notify_end_tx); - /* Should check for failure here */ - skb = NULL; + rxrpc_queue_packet(rx, call, txb, notify_end_tx); + txb = NULL; } } while (msg_data_left(msg) > 0); success: ret = copied; - if (READ_ONCE(call->state) == RXRPC_CALL_COMPLETE) { - read_lock_bh(&call->state_lock); - if (call->error < 0) - ret = call->error; - read_unlock_bh(&call->state_lock); - } + if (rxrpc_call_is_complete(call) && + call->error < 0) + ret = call->error; out: - call->tx_pending = skb; + call->tx_pending = txb; _leave(" = %d", ret); return ret; call_terminated: - rxrpc_free_skb(skb, rxrpc_skb_freed); + rxrpc_put_txbuf(txb, rxrpc_txbuf_put_send_aborted); _leave(" = %d", call->error); return call->error; @@ -633,7 +603,6 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, atomic_inc_return(&rxrpc_debug_id)); /* The socket is now unlocked */ - rxrpc_put_peer(cp.peer); _leave(" = %p\n", call); return call; } @@ -645,9 +614,7 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, */ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) __releases(&rx->sk.sk_lock.slock) - __releases(&call->user_mutex) { - enum rxrpc_call_state state; struct rxrpc_call *call; unsigned long now, j; bool dropped_lock = false; @@ -689,15 +656,15 @@ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) return PTR_ERR(call); /* ... and we have the call lock. */ ret = 0; - if (READ_ONCE(call->state) == RXRPC_CALL_COMPLETE) + if (rxrpc_call_is_complete(call)) goto out_put_unlock; } else { - switch (READ_ONCE(call->state)) { + switch (rxrpc_call_state(call)) { case RXRPC_CALL_UNINITIALISED: case RXRPC_CALL_CLIENT_AWAIT_CONN: case RXRPC_CALL_SERVER_PREALLOC: case RXRPC_CALL_SERVER_SECURING: - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_sendmsg); ret = -EBUSY; goto error_release_sock; default: @@ -716,7 +683,7 @@ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) if (call->tx_total_len != -1 || call->tx_pending || call->tx_top != 0) - goto error_put; + goto out_put_unlock; call->tx_total_len = p.call.tx_total_len; } } @@ -746,17 +713,13 @@ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) break; } - state = READ_ONCE(call->state); - _debug("CALL %d USR %lx ST %d on CONN %p", - call->debug_id, call->user_call_ID, state, call->conn); - - if (state >= RXRPC_CALL_COMPLETE) { + if (rxrpc_call_is_complete(call)) { /* it's too late for this call */ ret = -ESHUTDOWN; } else if (p.command == RXRPC_CMD_SEND_ABORT) { + rxrpc_propose_abort(call, p.abort_code, -ECONNABORTED, + rxrpc_abort_call_sendmsg); ret = 0; - if (rxrpc_abort_call("CMD", call, 0, p.abort_code, -ECONNABORTED)) - ret = rxrpc_send_abort_packet(call); } else if (p.command != RXRPC_CMD_SEND_DATA) { ret = -EINVAL; } else { @@ -767,7 +730,7 @@ out_put_unlock: if (!dropped_lock) mutex_unlock(&call->user_mutex); error_put: - rxrpc_put_call(call, rxrpc_call_put); + rxrpc_put_call(call, rxrpc_call_put_sendmsg); _leave(" = %d", ret); return ret; @@ -796,34 +759,17 @@ int rxrpc_kernel_send_data(struct socket *sock, struct rxrpc_call *call, bool dropped_lock = false; int ret; - _enter("{%d,%s},", call->debug_id, rxrpc_call_states[call->state]); + _enter("{%d},", call->debug_id); ASSERTCMP(msg->msg_name, ==, NULL); ASSERTCMP(msg->msg_control, ==, NULL); mutex_lock(&call->user_mutex); - _debug("CALL %d USR %lx ST %d on CONN %p", - call->debug_id, call->user_call_ID, call->state, call->conn); - - switch (READ_ONCE(call->state)) { - case RXRPC_CALL_CLIENT_SEND_REQUEST: - case RXRPC_CALL_SERVER_ACK_REQUEST: - case RXRPC_CALL_SERVER_SEND_REPLY: - ret = rxrpc_send_data(rxrpc_sk(sock->sk), call, msg, len, - notify_end_tx, &dropped_lock); - break; - case RXRPC_CALL_COMPLETE: - read_lock_bh(&call->state_lock); + ret = rxrpc_send_data(rxrpc_sk(sock->sk), call, msg, len, + notify_end_tx, &dropped_lock); + if (ret == -ESHUTDOWN) ret = call->error; - read_unlock_bh(&call->state_lock); - break; - default: - /* Request phase complete for this client call */ - trace_rxrpc_rx_eproto(call, 0, tracepoint_string("late_send")); - ret = -EPROTO; - break; - } if (!dropped_lock) mutex_unlock(&call->user_mutex); @@ -838,24 +784,20 @@ EXPORT_SYMBOL(rxrpc_kernel_send_data); * @call: The call to be aborted * @abort_code: The abort code to stick into the ABORT packet * @error: Local error value - * @why: 3-char string indicating why. + * @why: Indication as to why. * * Allow a kernel service to abort a call, if it's still in an abortable state * and return true if the call was aborted, false if it was already complete. */ bool rxrpc_kernel_abort_call(struct socket *sock, struct rxrpc_call *call, - u32 abort_code, int error, const char *why) + u32 abort_code, int error, enum rxrpc_abort_reason why) { bool aborted; - _enter("{%d},%d,%d,%s", call->debug_id, abort_code, error, why); + _enter("{%d},%d,%d,%u", call->debug_id, abort_code, error, why); mutex_lock(&call->user_mutex); - - aborted = rxrpc_abort_call(why, call, 0, abort_code, error); - if (aborted) - rxrpc_send_abort_packet(call); - + aborted = rxrpc_propose_abort(call, abort_code, error, why); mutex_unlock(&call->user_mutex); return aborted; } diff --git a/net/rxrpc/server_key.c b/net/rxrpc/server_key.c index ee269e0e6ee8..e51940589ee5 100644 --- a/net/rxrpc/server_key.c +++ b/net/rxrpc/server_key.c @@ -144,3 +144,28 @@ int rxrpc_server_keyring(struct rxrpc_sock *rx, sockptr_t optval, int optlen) _leave(" = 0 [key %x]", key->serial); return 0; } + +/** + * rxrpc_sock_set_security_keyring - Set the security keyring for a kernel service + * @sk: The socket to set the keyring on + * @keyring: The keyring to set + * + * Set the server security keyring on an rxrpc socket. This is used to provide + * the encryption keys for a kernel service. + */ +int rxrpc_sock_set_security_keyring(struct sock *sk, struct key *keyring) +{ + struct rxrpc_sock *rx = rxrpc_sk(sk); + int ret = 0; + + lock_sock(sk); + if (rx->securities) + ret = -EINVAL; + else if (rx->sk.sk_state != RXRPC_UNBOUND) + ret = -EISCONN; + else + rx->securities = key_get(keyring); + release_sock(sk); + return ret; +} +EXPORT_SYMBOL(rxrpc_sock_set_security_keyring); diff --git a/net/rxrpc/skbuff.c b/net/rxrpc/skbuff.c index 580a5acffee7..ebe0c75e7b07 100644 --- a/net/rxrpc/skbuff.c +++ b/net/rxrpc/skbuff.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0-or-later -/* ar-skbuff.c: socket buffer destruction handling +/* Socket buffer accounting * * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. * Written by David Howells (dhowells@redhat.com) @@ -14,66 +14,55 @@ #include <net/af_rxrpc.h> #include "ar-internal.h" -#define is_tx_skb(skb) (rxrpc_skb(skb)->rx_flags & RXRPC_SKB_TX_BUFFER) -#define select_skb_count(skb) (is_tx_skb(skb) ? &rxrpc_n_tx_skbs : &rxrpc_n_rx_skbs) +#define select_skb_count(skb) (&rxrpc_n_rx_skbs) /* * Note the allocation or reception of a socket buffer. */ -void rxrpc_new_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +void rxrpc_new_skb(struct sk_buff *skb, enum rxrpc_skb_trace why) { - const void *here = __builtin_return_address(0); int n = atomic_inc_return(select_skb_count(skb)); - trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, - rxrpc_skb(skb)->rx_flags, here); + trace_rxrpc_skb(skb, refcount_read(&skb->users), n, why); } /* * Note the re-emergence of a socket buffer from a queue or buffer. */ -void rxrpc_see_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +void rxrpc_see_skb(struct sk_buff *skb, enum rxrpc_skb_trace why) { - const void *here = __builtin_return_address(0); if (skb) { int n = atomic_read(select_skb_count(skb)); - trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, - rxrpc_skb(skb)->rx_flags, here); + trace_rxrpc_skb(skb, refcount_read(&skb->users), n, why); } } /* * Note the addition of a ref on a socket buffer. */ -void rxrpc_get_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +void rxrpc_get_skb(struct sk_buff *skb, enum rxrpc_skb_trace why) { - const void *here = __builtin_return_address(0); int n = atomic_inc_return(select_skb_count(skb)); - trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, - rxrpc_skb(skb)->rx_flags, here); + trace_rxrpc_skb(skb, refcount_read(&skb->users), n, why); skb_get(skb); } /* * Note the dropping of a ref on a socket buffer by the core. */ -void rxrpc_eaten_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +void rxrpc_eaten_skb(struct sk_buff *skb, enum rxrpc_skb_trace why) { - const void *here = __builtin_return_address(0); int n = atomic_inc_return(&rxrpc_n_rx_skbs); - trace_rxrpc_skb(skb, op, 0, n, 0, here); + trace_rxrpc_skb(skb, 0, n, why); } /* * Note the destruction of a socket buffer. */ -void rxrpc_free_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +void rxrpc_free_skb(struct sk_buff *skb, enum rxrpc_skb_trace why) { - const void *here = __builtin_return_address(0); if (skb) { - int n; - n = atomic_dec_return(select_skb_count(skb)); - trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, - rxrpc_skb(skb)->rx_flags, here); + int n = atomic_dec_return(select_skb_count(skb)); + trace_rxrpc_skb(skb, refcount_read(&skb->users), n, why); kfree_skb(skb); } } @@ -83,13 +72,12 @@ void rxrpc_free_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) */ void rxrpc_purge_queue(struct sk_buff_head *list) { - const void *here = __builtin_return_address(0); struct sk_buff *skb; + while ((skb = skb_dequeue((list))) != NULL) { int n = atomic_dec_return(select_skb_count(skb)); - trace_rxrpc_skb(skb, rxrpc_skb_purged, - refcount_read(&skb->users), n, - rxrpc_skb(skb)->rx_flags, here); + trace_rxrpc_skb(skb, refcount_read(&skb->users), n, + rxrpc_skb_put_purge); kfree_skb(skb); } } diff --git a/net/rxrpc/sysctl.c b/net/rxrpc/sysctl.c index 555e0910786b..cde3224a5cd2 100644 --- a/net/rxrpc/sysctl.c +++ b/net/rxrpc/sysctl.c @@ -14,7 +14,7 @@ static struct ctl_table_header *rxrpc_sysctl_reg_table; static const unsigned int four = 4; static const unsigned int max_backlog = RXRPC_BACKLOG_MAX - 1; static const unsigned int n_65535 = 65535; -static const unsigned int n_max_acks = RXRPC_RXTX_BUFF_SIZE - 1; +static const unsigned int n_max_acks = 255; static const unsigned long one_jiffy = 1; static const unsigned long max_jiffies = MAX_JIFFY_OFFSET; @@ -27,15 +27,6 @@ static const unsigned long max_jiffies = MAX_JIFFY_OFFSET; static struct ctl_table rxrpc_sysctl_table[] = { /* Values measured in milliseconds but used in jiffies */ { - .procname = "req_ack_delay", - .data = &rxrpc_requested_ack_delay, - .maxlen = sizeof(unsigned long), - .mode = 0644, - .proc_handler = proc_doulongvec_ms_jiffies_minmax, - .extra1 = (void *)&one_jiffy, - .extra2 = (void *)&max_jiffies, - }, - { .procname = "soft_ack_delay", .data = &rxrpc_soft_ack_delay, .maxlen = sizeof(unsigned long), diff --git a/net/rxrpc/txbuf.c b/net/rxrpc/txbuf.c new file mode 100644 index 000000000000..d2cf2aac3adb --- /dev/null +++ b/net/rxrpc/txbuf.c @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC Tx data buffering. + * + * Copyright (C) 2022 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include <linux/slab.h> +#include "ar-internal.h" + +static atomic_t rxrpc_txbuf_debug_ids; +atomic_t rxrpc_nr_txbuf; + +/* + * Allocate and partially initialise an I/O request structure. + */ +struct rxrpc_txbuf *rxrpc_alloc_txbuf(struct rxrpc_call *call, u8 packet_type, + gfp_t gfp) +{ + struct rxrpc_txbuf *txb; + + txb = kmalloc(sizeof(*txb), gfp); + if (txb) { + INIT_LIST_HEAD(&txb->call_link); + INIT_LIST_HEAD(&txb->tx_link); + refcount_set(&txb->ref, 1); + txb->call_debug_id = call->debug_id; + txb->debug_id = atomic_inc_return(&rxrpc_txbuf_debug_ids); + txb->space = sizeof(txb->data); + txb->len = 0; + txb->offset = 0; + txb->flags = 0; + txb->ack_why = 0; + txb->seq = call->tx_prepared + 1; + txb->wire.epoch = htonl(call->conn->proto.epoch); + txb->wire.cid = htonl(call->cid); + txb->wire.callNumber = htonl(call->call_id); + txb->wire.seq = htonl(txb->seq); + txb->wire.type = packet_type; + txb->wire.flags = call->conn->out_clientflag; + txb->wire.userStatus = 0; + txb->wire.securityIndex = call->security_ix; + txb->wire._rsvd = 0; + txb->wire.serviceId = htons(call->dest_srx.srx_service); + + trace_rxrpc_txbuf(txb->debug_id, + txb->call_debug_id, txb->seq, 1, + packet_type == RXRPC_PACKET_TYPE_DATA ? + rxrpc_txbuf_alloc_data : + rxrpc_txbuf_alloc_ack); + atomic_inc(&rxrpc_nr_txbuf); + } + + return txb; +} + +void rxrpc_get_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what) +{ + int r; + + __refcount_inc(&txb->ref, &r); + trace_rxrpc_txbuf(txb->debug_id, txb->call_debug_id, txb->seq, r + 1, what); +} + +void rxrpc_see_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what) +{ + int r = refcount_read(&txb->ref); + + trace_rxrpc_txbuf(txb->debug_id, txb->call_debug_id, txb->seq, r, what); +} + +static void rxrpc_free_txbuf(struct rcu_head *rcu) +{ + struct rxrpc_txbuf *txb = container_of(rcu, struct rxrpc_txbuf, rcu); + + trace_rxrpc_txbuf(txb->debug_id, txb->call_debug_id, txb->seq, 0, + rxrpc_txbuf_free); + kfree(txb); + atomic_dec(&rxrpc_nr_txbuf); +} + +void rxrpc_put_txbuf(struct rxrpc_txbuf *txb, enum rxrpc_txbuf_trace what) +{ + unsigned int debug_id, call_debug_id; + rxrpc_seq_t seq; + bool dead; + int r; + + if (txb) { + debug_id = txb->debug_id; + call_debug_id = txb->call_debug_id; + seq = txb->seq; + dead = __refcount_dec_and_test(&txb->ref, &r); + trace_rxrpc_txbuf(debug_id, call_debug_id, seq, r - 1, what); + if (dead) + call_rcu(&txb->rcu, rxrpc_free_txbuf); + } +} + +/* + * Shrink the transmit buffer. + */ +void rxrpc_shrink_call_tx_buffer(struct rxrpc_call *call) +{ + struct rxrpc_txbuf *txb; + rxrpc_seq_t hard_ack = smp_load_acquire(&call->acks_hard_ack); + bool wake = false; + + _enter("%x/%x/%x", call->tx_bottom, call->acks_hard_ack, call->tx_top); + + for (;;) { + spin_lock(&call->tx_lock); + txb = list_first_entry_or_null(&call->tx_buffer, + struct rxrpc_txbuf, call_link); + if (!txb) + break; + hard_ack = smp_load_acquire(&call->acks_hard_ack); + if (before(hard_ack, txb->seq)) + break; + + if (txb->seq != call->tx_bottom + 1) + rxrpc_see_txbuf(txb, rxrpc_txbuf_see_out_of_step); + ASSERTCMP(txb->seq, ==, call->tx_bottom + 1); + smp_store_release(&call->tx_bottom, call->tx_bottom + 1); + list_del_rcu(&txb->call_link); + + trace_rxrpc_txqueue(call, rxrpc_txqueue_dequeue); + + spin_unlock(&call->tx_lock); + + rxrpc_put_txbuf(txb, rxrpc_txbuf_put_rotated); + if (after(call->acks_hard_ack, call->tx_bottom + 128)) + wake = true; + } + + spin_unlock(&call->tx_lock); + + if (wake) + wake_up(&call->waitq); +} diff --git a/net/sched/Kconfig b/net/sched/Kconfig index 1e8ab4749c6c..777d6b50505c 100644 --- a/net/sched/Kconfig +++ b/net/sched/Kconfig @@ -976,7 +976,8 @@ config NET_ACT_TUNNEL_KEY config NET_ACT_CT tristate "connection tracking tc action" - depends on NET_CLS_ACT && NF_CONNTRACK && NF_NAT && NF_FLOW_TABLE + depends on NET_CLS_ACT && NF_CONNTRACK && (!NF_NAT || NF_NAT) && NF_FLOW_TABLE + select NF_NAT_OVS if NF_NAT help Say Y here to allow sending the packets to conntrack module. diff --git a/net/sched/act_api.c b/net/sched/act_api.c index 9b31a10cc639..5b3c0ac495be 100644 --- a/net/sched/act_api.c +++ b/net/sched/act_api.c @@ -23,6 +23,7 @@ #include <net/act_api.h> #include <net/netlink.h> #include <net/flow_offload.h> +#include <net/tc_wrapper.h> #ifdef CONFIG_INET DEFINE_STATIC_KEY_FALSE(tcf_frag_xmit_count); @@ -1080,7 +1081,7 @@ restart_act_graph: repeat_ttl = 32; repeat: - ret = a->ops->act(skb, a, res); + ret = tc_act(skb, a, res); if (unlikely(ret == TC_ACT_REPEAT)) { if (--repeat_ttl != 0) goto repeat; diff --git a/net/sched/act_bpf.c b/net/sched/act_bpf.c index b79eee44e24e..b0455fda7d0b 100644 --- a/net/sched/act_bpf.c +++ b/net/sched/act_bpf.c @@ -18,6 +18,7 @@ #include <linux/tc_act/tc_bpf.h> #include <net/tc_act/tc_bpf.h> +#include <net/tc_wrapper.h> #define ACT_BPF_NAME_LEN 256 @@ -31,8 +32,9 @@ struct tcf_bpf_cfg { static struct tc_action_ops act_bpf_ops; -static int tcf_bpf_act(struct sk_buff *skb, const struct tc_action *act, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_bpf_act(struct sk_buff *skb, + const struct tc_action *act, + struct tcf_result *res) { bool at_ingress = skb_at_tc_ingress(skb); struct tcf_bpf *prog = to_bpf(act); diff --git a/net/sched/act_connmark.c b/net/sched/act_connmark.c index 66b143bb04ac..7e63ff7e3ed7 100644 --- a/net/sched/act_connmark.c +++ b/net/sched/act_connmark.c @@ -20,6 +20,7 @@ #include <net/pkt_cls.h> #include <uapi/linux/tc_act/tc_connmark.h> #include <net/tc_act/tc_connmark.h> +#include <net/tc_wrapper.h> #include <net/netfilter/nf_conntrack.h> #include <net/netfilter/nf_conntrack_core.h> @@ -27,8 +28,9 @@ static struct tc_action_ops act_connmark_ops; -static int tcf_connmark_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_connmark_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { const struct nf_conntrack_tuple_hash *thash; struct nf_conntrack_tuple tuple; @@ -61,7 +63,7 @@ static int tcf_connmark_act(struct sk_buff *skb, const struct tc_action *a, c = nf_ct_get(skb, &ctinfo); if (c) { - skb->mark = c->mark; + skb->mark = READ_ONCE(c->mark); /* using overlimits stats to count how many packets marked */ ca->tcf_qstats.overlimits++; goto out; @@ -81,7 +83,7 @@ static int tcf_connmark_act(struct sk_buff *skb, const struct tc_action *a, c = nf_ct_tuplehash_to_ctrack(thash); /* using overlimits stats to count how many packets marked */ ca->tcf_qstats.overlimits++; - skb->mark = c->mark; + skb->mark = READ_ONCE(c->mark); nf_ct_put(c); out: diff --git a/net/sched/act_csum.c b/net/sched/act_csum.c index 1366adf9b909..95e9304024b7 100644 --- a/net/sched/act_csum.c +++ b/net/sched/act_csum.c @@ -32,6 +32,7 @@ #include <linux/tc_act/tc_csum.h> #include <net/tc_act/tc_csum.h> +#include <net/tc_wrapper.h> static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = { [TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), }, @@ -563,8 +564,9 @@ fail: return 0; } -static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_csum_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_csum *p = to_tcf_csum(a); bool orig_vlan_tag_present = false; diff --git a/net/sched/act_ct.c b/net/sched/act_ct.c index b38d91d6b249..0ca2bb8ed026 100644 --- a/net/sched/act_ct.c +++ b/net/sched/act_ct.c @@ -24,6 +24,7 @@ #include <net/ipv6_frag.h> #include <uapi/linux/tc_act/tc_ct.h> #include <net/tc_act/tc_ct.h> +#include <net/tc_wrapper.h> #include <net/netfilter/nf_flow_table.h> #include <net/netfilter/nf_conntrack.h> @@ -33,6 +34,7 @@ #include <net/netfilter/nf_conntrack_acct.h> #include <net/netfilter/ipv6/nf_defrag_ipv6.h> #include <net/netfilter/nf_conntrack_act_ct.h> +#include <net/netfilter/nf_conntrack_seqadj.h> #include <uapi/linux/netfilter/nf_nat.h> static struct workqueue_struct *act_ct_wq; @@ -178,7 +180,7 @@ static void tcf_ct_flow_table_add_action_meta(struct nf_conn *ct, entry = tcf_ct_flow_table_flow_action_get_next(action); entry->id = FLOW_ACTION_CT_METADATA; #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) - entry->ct_metadata.mark = ct->mark; + entry->ct_metadata.mark = READ_ONCE(ct->mark); #endif ctinfo = dir == IP_CT_DIR_ORIGINAL ? IP_CT_ESTABLISHED : IP_CT_ESTABLISHED_REPLY; @@ -345,11 +347,9 @@ static void tcf_ct_flow_table_cleanup_work(struct work_struct *work) module_put(THIS_MODULE); } -static void tcf_ct_flow_table_put(struct tcf_ct_params *params) +static void tcf_ct_flow_table_put(struct tcf_ct_flow_table *ct_ft) { - struct tcf_ct_flow_table *ct_ft = params->ct_ft; - - if (refcount_dec_and_test(¶ms->ct_ft->ref)) { + if (refcount_dec_and_test(&ct_ft->ref)) { rhashtable_remove_fast(&zones_ht, &ct_ft->node, zones_params); INIT_RCU_WORK(&ct_ft->rwork, tcf_ct_flow_table_cleanup_work); queue_rcu_work(act_ct_wq, &ct_ft->rwork); @@ -657,7 +657,7 @@ struct tc_ct_action_net { /* Determine whether skb->_nfct is equal to the result of conntrack lookup. */ static bool tcf_ct_skb_nfct_cached(struct net *net, struct sk_buff *skb, - u16 zone_id, bool force) + struct tcf_ct_params *p) { enum ip_conntrack_info ctinfo; struct nf_conn *ct; @@ -667,11 +667,19 @@ static bool tcf_ct_skb_nfct_cached(struct net *net, struct sk_buff *skb, return false; if (!net_eq(net, read_pnet(&ct->ct_net))) goto drop_ct; - if (nf_ct_zone(ct)->id != zone_id) + if (nf_ct_zone(ct)->id != p->zone) goto drop_ct; + if (p->helper) { + struct nf_conn_help *help; + + help = nf_ct_ext_find(ct, NF_CT_EXT_HELPER); + if (help && rcu_access_pointer(help->helper) != p->helper) + goto drop_ct; + } /* Force conntrack entry direction. */ - if (force && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { + if ((p->ct_action & TCA_CT_ACT_FORCE) && + CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { if (nf_ct_is_confirmed(ct)) nf_ct_kill(ct); @@ -832,101 +840,29 @@ out_free: return err; } -static void tcf_ct_params_free(struct rcu_head *head) +static void tcf_ct_params_free(struct tcf_ct_params *params) { - struct tcf_ct_params *params = container_of(head, - struct tcf_ct_params, rcu); - - tcf_ct_flow_table_put(params); - + if (params->helper) { +#if IS_ENABLED(CONFIG_NF_NAT) + if (params->ct_action & TCA_CT_ACT_NAT) + nf_nat_helper_put(params->helper); +#endif + nf_conntrack_helper_put(params->helper); + } + if (params->ct_ft) + tcf_ct_flow_table_put(params->ct_ft); if (params->tmpl) nf_ct_put(params->tmpl); kfree(params); } -#if IS_ENABLED(CONFIG_NF_NAT) -/* Modelled after nf_nat_ipv[46]_fn(). - * range is only used for new, uninitialized NAT state. - * Returns either NF_ACCEPT or NF_DROP. - */ -static int ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, - enum ip_conntrack_info ctinfo, - const struct nf_nat_range2 *range, - enum nf_nat_manip_type maniptype) +static void tcf_ct_params_free_rcu(struct rcu_head *head) { - __be16 proto = skb_protocol(skb, true); - int hooknum, err = NF_ACCEPT; - - /* See HOOK2MANIP(). */ - if (maniptype == NF_NAT_MANIP_SRC) - hooknum = NF_INET_LOCAL_IN; /* Source NAT */ - else - hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */ - - switch (ctinfo) { - case IP_CT_RELATED: - case IP_CT_RELATED_REPLY: - if (proto == htons(ETH_P_IP) && - ip_hdr(skb)->protocol == IPPROTO_ICMP) { - if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, - hooknum)) - err = NF_DROP; - goto out; - } else if (IS_ENABLED(CONFIG_IPV6) && proto == htons(ETH_P_IPV6)) { - __be16 frag_off; - u8 nexthdr = ipv6_hdr(skb)->nexthdr; - int hdrlen = ipv6_skip_exthdr(skb, - sizeof(struct ipv6hdr), - &nexthdr, &frag_off); - - if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { - if (!nf_nat_icmpv6_reply_translation(skb, ct, - ctinfo, - hooknum, - hdrlen)) - err = NF_DROP; - goto out; - } - } - /* Non-ICMP, fall thru to initialize if needed. */ - fallthrough; - case IP_CT_NEW: - /* Seen it before? This can happen for loopback, retrans, - * or local packets. - */ - if (!nf_nat_initialized(ct, maniptype)) { - /* Initialize according to the NAT action. */ - err = (range && range->flags & NF_NAT_RANGE_MAP_IPS) - /* Action is set up to establish a new - * mapping. - */ - ? nf_nat_setup_info(ct, range, maniptype) - : nf_nat_alloc_null_binding(ct, hooknum); - if (err != NF_ACCEPT) - goto out; - } - break; - - case IP_CT_ESTABLISHED: - case IP_CT_ESTABLISHED_REPLY: - break; - - default: - err = NF_DROP; - goto out; - } + struct tcf_ct_params *params; - err = nf_nat_packet(ct, ctinfo, hooknum, skb); - if (err == NF_ACCEPT) { - if (maniptype == NF_NAT_MANIP_SRC) - tc_skb_cb(skb)->post_ct_snat = 1; - if (maniptype == NF_NAT_MANIP_DST) - tc_skb_cb(skb)->post_ct_dnat = 1; - } -out: - return err; + params = container_of(head, struct tcf_ct_params, rcu); + tcf_ct_params_free(params); } -#endif /* CONFIG_NF_NAT */ static void tcf_ct_act_set_mark(struct nf_conn *ct, u32 mark, u32 mask) { @@ -936,9 +872,9 @@ static void tcf_ct_act_set_mark(struct nf_conn *ct, u32 mark, u32 mask) if (!mask) return; - new_mark = mark | (ct->mark & ~(mask)); - if (ct->mark != new_mark) { - ct->mark = new_mark; + new_mark = mark | (READ_ONCE(ct->mark) & ~(mask)); + if (READ_ONCE(ct->mark) != new_mark) { + WRITE_ONCE(ct->mark, new_mark); if (nf_ct_is_confirmed(ct)) nf_conntrack_event_cache(IPCT_MARK, ct); } @@ -967,69 +903,40 @@ static int tcf_ct_act_nat(struct sk_buff *skb, bool commit) { #if IS_ENABLED(CONFIG_NF_NAT) - int err; - enum nf_nat_manip_type maniptype; + int err, action = 0; if (!(ct_action & TCA_CT_ACT_NAT)) return NF_ACCEPT; + if (ct_action & TCA_CT_ACT_NAT_SRC) + action |= BIT(NF_NAT_MANIP_SRC); + if (ct_action & TCA_CT_ACT_NAT_DST) + action |= BIT(NF_NAT_MANIP_DST); - /* Add NAT extension if not confirmed yet. */ - if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct)) - return NF_DROP; /* Can't NAT. */ - - if (ctinfo != IP_CT_NEW && (ct->status & IPS_NAT_MASK) && - (ctinfo != IP_CT_RELATED || commit)) { - /* NAT an established or related connection like before. */ - if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) - /* This is the REPLY direction for a connection - * for which NAT was applied in the forward - * direction. Do the reverse NAT. - */ - maniptype = ct->status & IPS_SRC_NAT - ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC; - else - maniptype = ct->status & IPS_SRC_NAT - ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST; - } else if (ct_action & TCA_CT_ACT_NAT_SRC) { - maniptype = NF_NAT_MANIP_SRC; - } else if (ct_action & TCA_CT_ACT_NAT_DST) { - maniptype = NF_NAT_MANIP_DST; - } else { - return NF_ACCEPT; - } + err = nf_ct_nat(skb, ct, ctinfo, &action, range, commit); + + if (action & BIT(NF_NAT_MANIP_SRC)) + tc_skb_cb(skb)->post_ct_snat = 1; + if (action & BIT(NF_NAT_MANIP_DST)) + tc_skb_cb(skb)->post_ct_dnat = 1; - err = ct_nat_execute(skb, ct, ctinfo, range, maniptype); - if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) { - if (ct->status & IPS_SRC_NAT) { - if (maniptype == NF_NAT_MANIP_SRC) - maniptype = NF_NAT_MANIP_DST; - else - maniptype = NF_NAT_MANIP_SRC; - - err = ct_nat_execute(skb, ct, ctinfo, range, - maniptype); - } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { - err = ct_nat_execute(skb, ct, ctinfo, NULL, - NF_NAT_MANIP_SRC); - } - } return err; #else return NF_ACCEPT; #endif } -static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) { struct net *net = dev_net(skb->dev); - bool cached, commit, clear, force; enum ip_conntrack_info ctinfo; struct tcf_ct *c = to_ct(a); struct nf_conn *tmpl = NULL; struct nf_hook_state state; + bool cached, commit, clear; int nh_ofs, err, retval; struct tcf_ct_params *p; + bool add_helper = false; bool skip_add = false; bool defrag = false; struct nf_conn *ct; @@ -1040,7 +947,6 @@ static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a, retval = READ_ONCE(c->tcf_action); commit = p->ct_action & TCA_CT_ACT_COMMIT; clear = p->ct_action & TCA_CT_ACT_CLEAR; - force = p->ct_action & TCA_CT_ACT_FORCE; tmpl = p->tmpl; tcf_lastuse_update(&c->tcf_tm); @@ -1083,7 +989,7 @@ static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a, * actually run the packet through conntrack twice unless it's for a * different zone. */ - cached = tcf_ct_skb_nfct_cached(net, skb, p->zone, force); + cached = tcf_ct_skb_nfct_cached(net, skb, p); if (!cached) { if (tcf_ct_flow_table_lookup(p, skb, family)) { skip_add = true; @@ -1116,6 +1022,22 @@ do_nat: if (err != NF_ACCEPT) goto drop; + if (!nf_ct_is_confirmed(ct) && commit && p->helper && !nfct_help(ct)) { + err = __nf_ct_try_assign_helper(ct, p->tmpl, GFP_ATOMIC); + if (err) + goto drop; + add_helper = true; + if (p->ct_action & TCA_CT_ACT_NAT && !nfct_seqadj(ct)) { + if (!nfct_seqadj_ext_add(ct)) + goto drop; + } + } + + if (nf_ct_is_confirmed(ct) ? ((!cached && !skip_add) || add_helper) : commit) { + if (nf_ct_helper(skb, ct, ctinfo, family) != NF_ACCEPT) + goto drop; + } + if (commit) { tcf_ct_act_set_mark(ct, p->mark, p->mark_mask); tcf_ct_act_set_labels(ct, p->labels, p->labels_mask); @@ -1164,6 +1086,9 @@ static const struct nla_policy ct_policy[TCA_CT_MAX + 1] = { [TCA_CT_NAT_IPV6_MAX] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), [TCA_CT_NAT_PORT_MIN] = { .type = NLA_U16 }, [TCA_CT_NAT_PORT_MAX] = { .type = NLA_U16 }, + [TCA_CT_HELPER_NAME] = { .type = NLA_STRING, .len = NF_CT_HELPER_NAME_LEN }, + [TCA_CT_HELPER_FAMILY] = { .type = NLA_U8 }, + [TCA_CT_HELPER_PROTO] = { .type = NLA_U8 }, }; static int tcf_ct_fill_params_nat(struct tcf_ct_params *p, @@ -1253,8 +1178,9 @@ static int tcf_ct_fill_params(struct net *net, { struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id); struct nf_conntrack_zone zone; + int err, family, proto, len; struct nf_conn *tmpl; - int err; + char *name; p->zone = NF_CT_DEFAULT_ZONE_ID; @@ -1315,10 +1241,31 @@ static int tcf_ct_fill_params(struct net *net, NL_SET_ERR_MSG_MOD(extack, "Failed to allocate conntrack template"); return -ENOMEM; } - __set_bit(IPS_CONFIRMED_BIT, &tmpl->status); p->tmpl = tmpl; + if (tb[TCA_CT_HELPER_NAME]) { + name = nla_data(tb[TCA_CT_HELPER_NAME]); + len = nla_len(tb[TCA_CT_HELPER_NAME]); + if (len > 16 || name[len - 1] != '\0') { + NL_SET_ERR_MSG_MOD(extack, "Failed to parse helper name."); + err = -EINVAL; + goto err; + } + family = tb[TCA_CT_HELPER_FAMILY] ? nla_get_u8(tb[TCA_CT_HELPER_FAMILY]) : AF_INET; + proto = tb[TCA_CT_HELPER_PROTO] ? nla_get_u8(tb[TCA_CT_HELPER_PROTO]) : IPPROTO_TCP; + err = nf_ct_add_helper(tmpl, name, family, proto, + p->ct_action & TCA_CT_ACT_NAT, &p->helper); + if (err) { + NL_SET_ERR_MSG_MOD(extack, "Failed to add helper"); + goto err; + } + } + __set_bit(IPS_CONFIRMED_BIT, &tmpl->status); return 0; +err: + nf_ct_put(p->tmpl); + p->tmpl = NULL; + return err; } static int tcf_ct_init(struct net *net, struct nlattr *nla, @@ -1390,7 +1337,7 @@ static int tcf_ct_init(struct net *net, struct nlattr *nla, err = tcf_ct_flow_table_get(net, params); if (err) - goto cleanup_params; + goto cleanup; spin_lock_bh(&c->tcf_lock); goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); @@ -1401,17 +1348,15 @@ static int tcf_ct_init(struct net *net, struct nlattr *nla, if (goto_ch) tcf_chain_put_by_act(goto_ch); if (params) - call_rcu(¶ms->rcu, tcf_ct_params_free); + call_rcu(¶ms->rcu, tcf_ct_params_free_rcu); return res; -cleanup_params: - if (params->tmpl) - nf_ct_put(params->tmpl); cleanup: if (goto_ch) tcf_chain_put_by_act(goto_ch); - kfree(params); + if (params) + tcf_ct_params_free(params); tcf_idr_release(*a, bind); return err; } @@ -1423,7 +1368,7 @@ static void tcf_ct_cleanup(struct tc_action *a) params = rcu_dereference_protected(c->params, 1); if (params) - call_rcu(¶ms->rcu, tcf_ct_params_free); + call_rcu(¶ms->rcu, tcf_ct_params_free_rcu); } static int tcf_ct_dump_key_val(struct sk_buff *skb, @@ -1489,6 +1434,19 @@ static int tcf_ct_dump_nat(struct sk_buff *skb, struct tcf_ct_params *p) return 0; } +static int tcf_ct_dump_helper(struct sk_buff *skb, struct nf_conntrack_helper *helper) +{ + if (!helper) + return 0; + + if (nla_put_string(skb, TCA_CT_HELPER_NAME, helper->name) || + nla_put_u8(skb, TCA_CT_HELPER_FAMILY, helper->tuple.src.l3num) || + nla_put_u8(skb, TCA_CT_HELPER_PROTO, helper->tuple.dst.protonum)) + return -1; + + return 0; +} + static inline int tcf_ct_dump(struct sk_buff *skb, struct tc_action *a, int bind, int ref) { @@ -1541,6 +1499,9 @@ static inline int tcf_ct_dump(struct sk_buff *skb, struct tc_action *a, if (tcf_ct_dump_nat(skb, p)) goto nla_put_failure; + if (tcf_ct_dump_helper(skb, p->helper)) + goto nla_put_failure; + skip_dump: if (nla_put(skb, TCA_CT_PARMS, sizeof(opt), &opt)) goto nla_put_failure; diff --git a/net/sched/act_ctinfo.c b/net/sched/act_ctinfo.c index d4102f0a9abd..4b1b59da5c0b 100644 --- a/net/sched/act_ctinfo.c +++ b/net/sched/act_ctinfo.c @@ -18,6 +18,7 @@ #include <net/pkt_cls.h> #include <uapi/linux/tc_act/tc_ctinfo.h> #include <net/tc_act/tc_ctinfo.h> +#include <net/tc_wrapper.h> #include <net/netfilter/nf_conntrack.h> #include <net/netfilter/nf_conntrack_core.h> @@ -32,7 +33,7 @@ static void tcf_ctinfo_dscp_set(struct nf_conn *ct, struct tcf_ctinfo *ca, { u8 dscp, newdscp; - newdscp = (((ct->mark & cp->dscpmask) >> cp->dscpmaskshift) << 2) & + newdscp = (((READ_ONCE(ct->mark) & cp->dscpmask) >> cp->dscpmaskshift) << 2) & ~INET_ECN_MASK; switch (proto) { @@ -72,11 +73,12 @@ static void tcf_ctinfo_cpmark_set(struct nf_conn *ct, struct tcf_ctinfo *ca, struct sk_buff *skb) { ca->stats_cpmark_set++; - skb->mark = ct->mark & cp->cpmarkmask; + skb->mark = READ_ONCE(ct->mark) & cp->cpmarkmask; } -static int tcf_ctinfo_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_ctinfo_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { const struct nf_conntrack_tuple_hash *thash = NULL; struct tcf_ctinfo *ca = to_ctinfo(a); @@ -130,7 +132,7 @@ static int tcf_ctinfo_act(struct sk_buff *skb, const struct tc_action *a, } if (cp->mode & CTINFO_MODE_DSCP) - if (!cp->dscpstatemask || (ct->mark & cp->dscpstatemask)) + if (!cp->dscpstatemask || (READ_ONCE(ct->mark) & cp->dscpstatemask)) tcf_ctinfo_dscp_set(ct, ca, cp, skb, wlen, proto); if (cp->mode & CTINFO_MODE_CPMARK) diff --git a/net/sched/act_gact.c b/net/sched/act_gact.c index 62d682b96b88..904ab3d457ef 100644 --- a/net/sched/act_gact.c +++ b/net/sched/act_gact.c @@ -18,6 +18,7 @@ #include <net/pkt_cls.h> #include <linux/tc_act/tc_gact.h> #include <net/tc_act/tc_gact.h> +#include <net/tc_wrapper.h> static struct tc_action_ops act_gact_ops; @@ -25,7 +26,7 @@ static struct tc_action_ops act_gact_ops; static int gact_net_rand(struct tcf_gact *gact) { smp_rmb(); /* coupled with smp_wmb() in tcf_gact_init() */ - if (prandom_u32_max(gact->tcfg_pval)) + if (get_random_u32_below(gact->tcfg_pval)) return gact->tcf_action; return gact->tcfg_paction; } @@ -145,8 +146,9 @@ release_idr: return err; } -static int tcf_gact_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_gact_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_gact *gact = to_gact(a); int action = READ_ONCE(gact->tcf_action); diff --git a/net/sched/act_gate.c b/net/sched/act_gate.c index 3049878e7315..9b8def0be41e 100644 --- a/net/sched/act_gate.c +++ b/net/sched/act_gate.c @@ -14,6 +14,7 @@ #include <net/netlink.h> #include <net/pkt_cls.h> #include <net/tc_act/tc_gate.h> +#include <net/tc_wrapper.h> static struct tc_action_ops act_gate_ops; @@ -113,8 +114,9 @@ static enum hrtimer_restart gate_timer_func(struct hrtimer *timer) return HRTIMER_RESTART; } -static int tcf_gate_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_gate_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_gate *gact = to_gate(a); diff --git a/net/sched/act_ife.c b/net/sched/act_ife.c index 41d63b33461d..bc7611b0744c 100644 --- a/net/sched/act_ife.c +++ b/net/sched/act_ife.c @@ -29,6 +29,7 @@ #include <net/tc_act/tc_ife.h> #include <linux/etherdevice.h> #include <net/ife.h> +#include <net/tc_wrapper.h> static int max_metacnt = IFE_META_MAX + 1; static struct tc_action_ops act_ife_ops; @@ -861,8 +862,9 @@ static int tcf_ife_encode(struct sk_buff *skb, const struct tc_action *a, return action; } -static int tcf_ife_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_ife_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_ife_info *ife = to_ife(a); struct tcf_ife_params *p; diff --git a/net/sched/act_ipt.c b/net/sched/act_ipt.c index 1625e1037416..5d96ffebd40f 100644 --- a/net/sched/act_ipt.c +++ b/net/sched/act_ipt.c @@ -20,6 +20,7 @@ #include <net/pkt_sched.h> #include <linux/tc_act/tc_ipt.h> #include <net/tc_act/tc_ipt.h> +#include <net/tc_wrapper.h> #include <linux/netfilter_ipv4/ip_tables.h> @@ -216,8 +217,9 @@ static int tcf_xt_init(struct net *net, struct nlattr *nla, a, &act_xt_ops, tp, flags); } -static int tcf_ipt_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_ipt_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { int ret = 0, result = 0; struct tcf_ipt *ipt = to_ipt(a); diff --git a/net/sched/act_mirred.c b/net/sched/act_mirred.c index b8ad6ae282c0..7284bcea7b0b 100644 --- a/net/sched/act_mirred.c +++ b/net/sched/act_mirred.c @@ -24,6 +24,7 @@ #include <net/pkt_cls.h> #include <linux/tc_act/tc_mirred.h> #include <net/tc_act/tc_mirred.h> +#include <net/tc_wrapper.h> static LIST_HEAD(mirred_list); static DEFINE_SPINLOCK(mirred_list_lock); @@ -217,8 +218,9 @@ static int tcf_mirred_forward(bool want_ingress, struct sk_buff *skb) return err; } -static int tcf_mirred_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_mirred_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_mirred *m = to_mirred(a); struct sk_buff *skb2 = skb; diff --git a/net/sched/act_mpls.c b/net/sched/act_mpls.c index 8ad25cc8ccd5..6b26bdb999d7 100644 --- a/net/sched/act_mpls.c +++ b/net/sched/act_mpls.c @@ -14,6 +14,7 @@ #include <net/pkt_sched.h> #include <net/pkt_cls.h> #include <net/tc_act/tc_mpls.h> +#include <net/tc_wrapper.h> static struct tc_action_ops act_mpls_ops; @@ -49,8 +50,9 @@ static __be32 tcf_mpls_get_lse(struct mpls_shim_hdr *lse, return cpu_to_be32(new_lse); } -static int tcf_mpls_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_mpls_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_mpls *m = to_mpls(a); struct tcf_mpls_params *p; @@ -132,6 +134,11 @@ static int valid_label(const struct nlattr *attr, { const u32 *label = nla_data(attr); + if (nla_len(attr) != sizeof(*label)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid MPLS label length"); + return -EINVAL; + } + if (*label & ~MPLS_LABEL_MASK || *label == MPLS_LABEL_IMPLNULL) { NL_SET_ERR_MSG_MOD(extack, "MPLS label out of range"); return -EINVAL; @@ -143,7 +150,8 @@ static int valid_label(const struct nlattr *attr, static const struct nla_policy mpls_policy[TCA_MPLS_MAX + 1] = { [TCA_MPLS_PARMS] = NLA_POLICY_EXACT_LEN(sizeof(struct tc_mpls)), [TCA_MPLS_PROTO] = { .type = NLA_U16 }, - [TCA_MPLS_LABEL] = NLA_POLICY_VALIDATE_FN(NLA_U32, valid_label), + [TCA_MPLS_LABEL] = NLA_POLICY_VALIDATE_FN(NLA_BINARY, + valid_label), [TCA_MPLS_TC] = NLA_POLICY_RANGE(NLA_U8, 0, 7), [TCA_MPLS_TTL] = NLA_POLICY_MIN(NLA_U8, 1), [TCA_MPLS_BOS] = NLA_POLICY_RANGE(NLA_U8, 0, 1), diff --git a/net/sched/act_nat.c b/net/sched/act_nat.c index 9265145f1040..74c74be33048 100644 --- a/net/sched/act_nat.c +++ b/net/sched/act_nat.c @@ -24,7 +24,7 @@ #include <net/tc_act/tc_nat.h> #include <net/tcp.h> #include <net/udp.h> - +#include <net/tc_wrapper.h> static struct tc_action_ops act_nat_ops; @@ -98,8 +98,9 @@ release_idr: return err; } -static int tcf_nat_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_nat_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_nat *p = to_tcf_nat(a); struct iphdr *iph; diff --git a/net/sched/act_pedit.c b/net/sched/act_pedit.c index 94ed5857ce67..a0378e9f0121 100644 --- a/net/sched/act_pedit.c +++ b/net/sched/act_pedit.c @@ -20,6 +20,7 @@ #include <net/tc_act/tc_pedit.h> #include <uapi/linux/tc_act/tc_pedit.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> static struct tc_action_ops act_pedit_ops; @@ -319,8 +320,9 @@ static int pedit_skb_hdr_offset(struct sk_buff *skb, return ret; } -static int tcf_pedit_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_pedit_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_pedit *p = to_pedit(a); u32 max_offset; diff --git a/net/sched/act_police.c b/net/sched/act_police.c index 0adb26e366a7..227cba58ce9f 100644 --- a/net/sched/act_police.c +++ b/net/sched/act_police.c @@ -19,6 +19,7 @@ #include <net/netlink.h> #include <net/pkt_cls.h> #include <net/tc_act/tc_police.h> +#include <net/tc_wrapper.h> /* Each policer is serialized by its individual spinlock */ @@ -242,8 +243,9 @@ static bool tcf_police_mtu_check(struct sk_buff *skb, u32 limit) return len <= limit; } -static int tcf_police_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_police_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_police *police = to_police(a); s64 now, toks, ppstoks = 0, ptoks = 0; diff --git a/net/sched/act_sample.c b/net/sched/act_sample.c index 7a25477f5d99..f7416b5598e0 100644 --- a/net/sched/act_sample.c +++ b/net/sched/act_sample.c @@ -20,6 +20,7 @@ #include <net/tc_act/tc_sample.h> #include <net/psample.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/if_arp.h> @@ -153,8 +154,9 @@ static bool tcf_sample_dev_ok_push(struct net_device *dev) } } -static int tcf_sample_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_sample_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_sample *s = to_sample(a); struct psample_group *psample_group; @@ -168,7 +170,7 @@ static int tcf_sample_act(struct sk_buff *skb, const struct tc_action *a, psample_group = rcu_dereference_bh(s->psample_group); /* randomly sample packets according to rate */ - if (psample_group && (prandom_u32_max(s->rate) == 0)) { + if (psample_group && (get_random_u32_below(s->rate) == 0)) { if (!skb_at_tc_ingress(skb)) { md.in_ifindex = skb->skb_iif; md.out_ifindex = skb->dev->ifindex; diff --git a/net/sched/act_simple.c b/net/sched/act_simple.c index 18d376135461..4b84514534f3 100644 --- a/net/sched/act_simple.c +++ b/net/sched/act_simple.c @@ -14,6 +14,7 @@ #include <net/netlink.h> #include <net/pkt_sched.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/tc_act/tc_defact.h> #include <net/tc_act/tc_defact.h> @@ -21,8 +22,9 @@ static struct tc_action_ops act_simp_ops; #define SIMP_MAX_DATA 32 -static int tcf_simp_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_simp_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_defact *d = to_defact(a); diff --git a/net/sched/act_skbedit.c b/net/sched/act_skbedit.c index 7f598784fd30..ce7008cf291c 100644 --- a/net/sched/act_skbedit.c +++ b/net/sched/act_skbedit.c @@ -16,6 +16,7 @@ #include <net/ipv6.h> #include <net/dsfield.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/tc_act/tc_skbedit.h> #include <net/tc_act/tc_skbedit.h> @@ -36,8 +37,9 @@ static u16 tcf_skbedit_hash(struct tcf_skbedit_params *params, return netdev_cap_txqueue(skb->dev, queue_mapping); } -static int tcf_skbedit_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_skbedit_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_skbedit *d = to_skbedit(a); struct tcf_skbedit_params *params; @@ -148,6 +150,11 @@ static int tcf_skbedit_init(struct net *net, struct nlattr *nla, } if (tb[TCA_SKBEDIT_QUEUE_MAPPING] != NULL) { + if (is_tcf_skbedit_ingress(act_flags) && + !(act_flags & TCA_ACT_FLAGS_SKIP_SW)) { + NL_SET_ERR_MSG_MOD(extack, "\"queue_mapping\" option on receive side is hardware only, use skip_sw"); + return -EOPNOTSUPP; + } flags |= SKBEDIT_F_QUEUE_MAPPING; queue_mapping = nla_data(tb[TCA_SKBEDIT_QUEUE_MAPPING]); } @@ -374,9 +381,12 @@ static int tcf_skbedit_offload_act_setup(struct tc_action *act, void *entry_data } else if (is_tcf_skbedit_priority(act)) { entry->id = FLOW_ACTION_PRIORITY; entry->priority = tcf_skbedit_priority(act); - } else if (is_tcf_skbedit_queue_mapping(act)) { - NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"queue_mapping\" option is used"); + } else if (is_tcf_skbedit_tx_queue_mapping(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"queue_mapping\" option is used on transmit side"); return -EOPNOTSUPP; + } else if (is_tcf_skbedit_rx_queue_mapping(act)) { + entry->id = FLOW_ACTION_RX_QUEUE_MAPPING; + entry->rx_queue = tcf_skbedit_rx_queue_mapping(act); } else if (is_tcf_skbedit_inheritdsfield(act)) { NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"inheritdsfield\" option is used"); return -EOPNOTSUPP; @@ -394,6 +404,8 @@ static int tcf_skbedit_offload_act_setup(struct tc_action *act, void *entry_data fl_action->id = FLOW_ACTION_PTYPE; else if (is_tcf_skbedit_priority(act)) fl_action->id = FLOW_ACTION_PRIORITY; + else if (is_tcf_skbedit_rx_queue_mapping(act)) + fl_action->id = FLOW_ACTION_RX_QUEUE_MAPPING; else return -EOPNOTSUPP; } diff --git a/net/sched/act_skbmod.c b/net/sched/act_skbmod.c index d98758a63934..dffa990a9629 100644 --- a/net/sched/act_skbmod.c +++ b/net/sched/act_skbmod.c @@ -15,14 +15,16 @@ #include <net/netlink.h> #include <net/pkt_sched.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/tc_act/tc_skbmod.h> #include <net/tc_act/tc_skbmod.h> static struct tc_action_ops act_skbmod_ops; -static int tcf_skbmod_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_skbmod_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_skbmod *d = to_skbmod(a); int action, max_edit_len, err; diff --git a/net/sched/act_tunnel_key.c b/net/sched/act_tunnel_key.c index 2691a3d8e451..2d12d2626415 100644 --- a/net/sched/act_tunnel_key.c +++ b/net/sched/act_tunnel_key.c @@ -16,14 +16,16 @@ #include <net/pkt_sched.h> #include <net/dst.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/tc_act/tc_tunnel_key.h> #include <net/tc_act/tc_tunnel_key.h> static struct tc_action_ops act_tunnel_key_ops; -static int tunnel_key_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tunnel_key_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_tunnel_key *t = to_tunnel_key(a); struct tcf_tunnel_key_params *params; diff --git a/net/sched/act_vlan.c b/net/sched/act_vlan.c index 7b24e898a3e6..0251442f5f29 100644 --- a/net/sched/act_vlan.c +++ b/net/sched/act_vlan.c @@ -12,14 +12,16 @@ #include <net/netlink.h> #include <net/pkt_sched.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <linux/tc_act/tc_vlan.h> #include <net/tc_act/tc_vlan.h> static struct tc_action_ops act_vlan_ops; -static int tcf_vlan_act(struct sk_buff *skb, const struct tc_action *a, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcf_vlan_act(struct sk_buff *skb, + const struct tc_action *a, + struct tcf_result *res) { struct tcf_vlan *v = to_vlan(a); struct tcf_vlan_params *p; diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index 50566db45949..668130f08903 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -40,6 +40,7 @@ #include <net/tc_act/tc_mpls.h> #include <net/tc_act/tc_gate.h> #include <net/flow_offload.h> +#include <net/tc_wrapper.h> extern const struct nla_policy rtm_tca_policy[TCA_MAX + 1]; @@ -1564,7 +1565,7 @@ reclassify: tp->protocol != htons(ETH_P_ALL)) continue; - err = tp->classify(skb, tp, res); + err = tc_classify(skb, tp, res); #ifdef CONFIG_NET_CLS_ACT if (unlikely(err == TC_ACT_RECLASSIFY && !compat_mode)) { first_tp = orig_tp; @@ -1953,6 +1954,11 @@ static void tfilter_put(struct tcf_proto *tp, void *fh) tp->ops->put(tp, fh); } +static bool is_qdisc_ingress(__u32 classid) +{ + return (TC_H_MIN(classid) == TC_H_MIN(TC_H_MIN_INGRESS)); +} + static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, struct netlink_ext_ack *extack) { @@ -2144,6 +2150,8 @@ replay: flags |= TCA_ACT_FLAGS_REPLACE; if (!rtnl_held) flags |= TCA_ACT_FLAGS_NO_RTNL; + if (is_qdisc_ingress(parent)) + flags |= TCA_ACT_FLAGS_AT_INGRESS; err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh, flags, extack); if (err == 0) { diff --git a/net/sched/cls_basic.c b/net/sched/cls_basic.c index d229ce99e554..1b92c33b5f81 100644 --- a/net/sched/cls_basic.c +++ b/net/sched/cls_basic.c @@ -18,6 +18,7 @@ #include <net/netlink.h> #include <net/act_api.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> struct basic_head { struct list_head flist; @@ -36,8 +37,9 @@ struct basic_filter { struct rcu_work rwork; }; -static int basic_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int basic_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { int r; struct basic_head *head = rcu_dereference_bh(tp->root); diff --git a/net/sched/cls_bpf.c b/net/sched/cls_bpf.c index bc317b3eac12..466c26df853a 100644 --- a/net/sched/cls_bpf.c +++ b/net/sched/cls_bpf.c @@ -19,6 +19,7 @@ #include <net/rtnetlink.h> #include <net/pkt_cls.h> #include <net/sock.h> +#include <net/tc_wrapper.h> MODULE_LICENSE("GPL"); MODULE_AUTHOR("Daniel Borkmann <dborkman@redhat.com>"); @@ -77,8 +78,9 @@ static int cls_bpf_exec_opcode(int code) } } -static int cls_bpf_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int cls_bpf_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct cls_bpf_head *head = rcu_dereference_bh(tp->root); bool at_ingress = skb_at_tc_ingress(skb); diff --git a/net/sched/cls_cgroup.c b/net/sched/cls_cgroup.c index ed00001b528a..bd9322d71910 100644 --- a/net/sched/cls_cgroup.c +++ b/net/sched/cls_cgroup.c @@ -13,6 +13,7 @@ #include <net/pkt_cls.h> #include <net/sock.h> #include <net/cls_cgroup.h> +#include <net/tc_wrapper.h> struct cls_cgroup_head { u32 handle; @@ -22,8 +23,9 @@ struct cls_cgroup_head { struct rcu_work rwork; }; -static int cls_cgroup_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int cls_cgroup_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct cls_cgroup_head *head = rcu_dereference_bh(tp->root); u32 classid = task_get_classid(skb); diff --git a/net/sched/cls_flow.c b/net/sched/cls_flow.c index 014cd3de7b5d..6ab317b48d6c 100644 --- a/net/sched/cls_flow.c +++ b/net/sched/cls_flow.c @@ -24,6 +24,7 @@ #include <net/ip.h> #include <net/route.h> #include <net/flow_dissector.h> +#include <net/tc_wrapper.h> #if IS_ENABLED(CONFIG_NF_CONNTRACK) #include <net/netfilter/nf_conntrack.h> @@ -292,8 +293,9 @@ static u32 flow_key_get(struct sk_buff *skb, int key, struct flow_keys *flow) (1 << FLOW_KEY_NFCT_PROTO_SRC) | \ (1 << FLOW_KEY_NFCT_PROTO_DST)) -static int flow_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int flow_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct flow_head *head = rcu_dereference_bh(tp->root); struct flow_filter *f; @@ -367,7 +369,7 @@ static const struct nla_policy flow_policy[TCA_FLOW_MAX + 1] = { static void __flow_destroy_filter(struct flow_filter *f) { - del_timer_sync(&f->perturb_timer); + timer_shutdown_sync(&f->perturb_timer); tcf_exts_destroy(&f->exts); tcf_em_tree_destroy(&f->ematches); tcf_exts_put_net(&f->exts); diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c index 25bc57ee6ea1..0b15698b3531 100644 --- a/net/sched/cls_flower.c +++ b/net/sched/cls_flower.c @@ -27,6 +27,7 @@ #include <net/vxlan.h> #include <net/erspan.h> #include <net/gtp.h> +#include <net/tc_wrapper.h> #include <net/dst.h> #include <net/dst_metadata.h> @@ -305,8 +306,9 @@ static u16 fl_ct_info_to_flower_map[] = { TCA_FLOWER_KEY_CT_FLAGS_NEW, }; -static int fl_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int fl_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct cls_fl_head *head = rcu_dereference_bh(tp->root); bool post_ct = tc_skb_cb(skb)->post_ct; diff --git a/net/sched/cls_fw.c b/net/sched/cls_fw.c index a32351da968c..ae9439a6c56c 100644 --- a/net/sched/cls_fw.c +++ b/net/sched/cls_fw.c @@ -21,6 +21,7 @@ #include <net/act_api.h> #include <net/pkt_cls.h> #include <net/sch_generic.h> +#include <net/tc_wrapper.h> #define HTSIZE 256 @@ -47,8 +48,9 @@ static u32 fw_hash(u32 handle) return handle % HTSIZE; } -static int fw_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int fw_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct fw_head *head = rcu_dereference_bh(tp->root); struct fw_filter *f; diff --git a/net/sched/cls_matchall.c b/net/sched/cls_matchall.c index 39a5d9c170de..705f63da2c21 100644 --- a/net/sched/cls_matchall.c +++ b/net/sched/cls_matchall.c @@ -12,6 +12,7 @@ #include <net/sch_generic.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> struct cls_mall_head { struct tcf_exts exts; @@ -24,8 +25,9 @@ struct cls_mall_head { bool deleting; }; -static int mall_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int mall_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct cls_mall_head *head = rcu_dereference_bh(tp->root); diff --git a/net/sched/cls_route.c b/net/sched/cls_route.c index 9e43b929d4ca..d0c53724d3e8 100644 --- a/net/sched/cls_route.c +++ b/net/sched/cls_route.c @@ -17,6 +17,7 @@ #include <net/netlink.h> #include <net/act_api.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> /* * 1. For now we assume that route tags < 256. @@ -121,8 +122,9 @@ static inline int route4_hash_wild(void) return 0; \ } -static int route4_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int route4_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct route4_head *head = rcu_dereference_bh(tp->root); struct dst_entry *dst; diff --git a/net/sched/cls_rsvp.c b/net/sched/cls_rsvp.c index de1c1d4da597..03d8619bd9c6 100644 --- a/net/sched/cls_rsvp.c +++ b/net/sched/cls_rsvp.c @@ -15,10 +15,12 @@ #include <net/netlink.h> #include <net/act_api.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #define RSVP_DST_LEN 1 #define RSVP_ID "rsvp" #define RSVP_OPS cls_rsvp_ops +#define RSVP_CLS rsvp_classify #include "cls_rsvp.h" MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_rsvp.h b/net/sched/cls_rsvp.h index b00a7dbd0587..869efba9f834 100644 --- a/net/sched/cls_rsvp.h +++ b/net/sched/cls_rsvp.h @@ -124,8 +124,8 @@ static inline unsigned int hash_src(__be32 *src) return r; \ } -static int rsvp_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int RSVP_CLS(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) { struct rsvp_head *head = rcu_dereference_bh(tp->root); struct rsvp_session *s; @@ -738,7 +738,7 @@ static void rsvp_bind_class(void *fh, u32 classid, unsigned long cl, void *q, static struct tcf_proto_ops RSVP_OPS __read_mostly = { .kind = RSVP_ID, - .classify = rsvp_classify, + .classify = RSVP_CLS, .init = rsvp_init, .destroy = rsvp_destroy, .get = rsvp_get, diff --git a/net/sched/cls_rsvp6.c b/net/sched/cls_rsvp6.c index 64078846000e..e627cc32d633 100644 --- a/net/sched/cls_rsvp6.c +++ b/net/sched/cls_rsvp6.c @@ -15,10 +15,12 @@ #include <net/act_api.h> #include <net/pkt_cls.h> #include <net/netlink.h> +#include <net/tc_wrapper.h> #define RSVP_DST_LEN 4 #define RSVP_ID "rsvp6" #define RSVP_OPS cls_rsvp6_ops +#define RSVP_CLS rsvp6_classify #include "cls_rsvp.h" MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_tcindex.c b/net/sched/cls_tcindex.c index 1c9eeb98d826..ee2a050c887b 100644 --- a/net/sched/cls_tcindex.c +++ b/net/sched/cls_tcindex.c @@ -16,6 +16,7 @@ #include <net/netlink.h> #include <net/pkt_cls.h> #include <net/sch_generic.h> +#include <net/tc_wrapper.h> /* * Passing parameters to the root seems to be done more awkwardly than really @@ -98,9 +99,9 @@ static struct tcindex_filter_result *tcindex_lookup(struct tcindex_data *p, return NULL; } - -static int tcindex_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int tcindex_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct tcindex_data *p = rcu_dereference_bh(tp->root); struct tcindex_filter_result *f; @@ -332,7 +333,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, struct tcindex_filter_result *r, struct nlattr **tb, struct nlattr *est, u32 flags, struct netlink_ext_ack *extack) { - struct tcindex_filter_result new_filter_result, *old_r = r; + struct tcindex_filter_result new_filter_result; struct tcindex_data *cp = NULL, *oldp; struct tcindex_filter *f = NULL; /* make gcc behave */ struct tcf_result cr = {}; @@ -401,7 +402,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, err = tcindex_filter_result_init(&new_filter_result, cp, net); if (err < 0) goto errout_alloc; - if (old_r) + if (r) cr = r->res; err = -EBUSY; @@ -478,14 +479,6 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, tcf_bind_filter(tp, &cr, base); } - if (old_r && old_r != r) { - err = tcindex_filter_result_init(old_r, cp, net); - if (err < 0) { - kfree(f); - goto errout_alloc; - } - } - oldp = p; r->res = cr; tcf_exts_change(&r->exts, &e); diff --git a/net/sched/cls_u32.c b/net/sched/cls_u32.c index 34d25f7a0687..4e2e269f121f 100644 --- a/net/sched/cls_u32.c +++ b/net/sched/cls_u32.c @@ -39,6 +39,7 @@ #include <net/act_api.h> #include <net/pkt_cls.h> #include <linux/idr.h> +#include <net/tc_wrapper.h> struct tc_u_knode { struct tc_u_knode __rcu *next; @@ -100,8 +101,9 @@ static inline unsigned int u32_hash_fold(__be32 key, return h; } -static int u32_classify(struct sk_buff *skb, const struct tcf_proto *tp, - struct tcf_result *res) +TC_INDIRECT_SCOPE int u32_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + struct tcf_result *res) { struct { struct tc_u_knode *knode; diff --git a/net/sched/ematch.c b/net/sched/ematch.c index 4ce681361851..5c1235e6076a 100644 --- a/net/sched/ematch.c +++ b/net/sched/ematch.c @@ -255,6 +255,8 @@ static int tcf_em_validate(struct tcf_proto *tp, * the value carried. */ if (em_hdr->flags & TCF_EM_SIMPLE) { + if (em->ops->datalen > 0) + goto errout; if (data_len < sizeof(u32)) goto errout; em->data = *(u32 *) data; diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c index 4a27dfb1ba0f..72d2c204d5f3 100644 --- a/net/sched/sch_api.c +++ b/net/sched/sch_api.c @@ -31,6 +31,7 @@ #include <net/netlink.h> #include <net/pkt_sched.h> #include <net/pkt_cls.h> +#include <net/tc_wrapper.h> #include <trace/events/qdisc.h> @@ -1132,6 +1133,11 @@ skip: return -ENOENT; } + if (new && new->ops == &noqueue_qdisc_ops) { + NL_SET_ERR_MSG(extack, "Cannot assign noqueue to a class"); + return -EINVAL; + } + err = cops->graft(parent, cl, new, &old, extack); if (err) return err; @@ -2273,6 +2279,8 @@ static struct pernet_operations psched_net_ops = { .exit = psched_net_exit, }; +DEFINE_STATIC_KEY_FALSE(tc_skip_wrapper); + static int __init pktsched_init(void) { int err; @@ -2300,6 +2308,8 @@ static int __init pktsched_init(void) rtnl_register(PF_UNSPEC, RTM_GETTCLASS, tc_ctl_tclass, tc_dump_tclass, 0); + tc_wrapper_init(); + return 0; } diff --git a/net/sched/sch_atm.c b/net/sched/sch_atm.c index f52255fea652..4a981ca90b0b 100644 --- a/net/sched/sch_atm.c +++ b/net/sched/sch_atm.c @@ -393,10 +393,13 @@ static int atm_tc_enqueue(struct sk_buff *skb, struct Qdisc *sch, result = tcf_classify(skb, NULL, fl, &res, true); if (result < 0) continue; + if (result == TC_ACT_SHOT) + goto done; + flow = (struct atm_flow_data *)res.class; if (!flow) flow = lookup_flow(sch, res.classid); - goto done; + goto drop; } } flow = NULL; diff --git a/net/sched/sch_cbq.c b/net/sched/sch_cbq.c index 6568e17c4c63..36db5f6782f2 100644 --- a/net/sched/sch_cbq.c +++ b/net/sched/sch_cbq.c @@ -230,6 +230,8 @@ cbq_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) result = tcf_classify(skb, NULL, fl, &res, true); if (!fl || result < 0) goto fallback; + if (result == TC_ACT_SHOT) + return NULL; cl = (void *)res.class; if (!cl) { @@ -250,8 +252,6 @@ cbq_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) case TC_ACT_TRAP: *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; fallthrough; - case TC_ACT_SHOT: - return NULL; case TC_ACT_RECLASSIFY: return cbq_reclassify(skb, cl); } diff --git a/net/sched/sch_choke.c b/net/sched/sch_choke.c index 3ac3e5c80b6f..19c851125901 100644 --- a/net/sched/sch_choke.c +++ b/net/sched/sch_choke.c @@ -183,7 +183,7 @@ static struct sk_buff *choke_peek_random(const struct choke_sched_data *q, int retrys = 3; do { - *pidx = (q->head + prandom_u32_max(choke_len(q))) & q->tab_mask; + *pidx = (q->head + get_random_u32_below(choke_len(q))) & q->tab_mask; skb = q->tab[*pidx]; if (skb) return skb; diff --git a/net/sched/sch_gred.c b/net/sched/sch_gred.c index a661b062cca8..872d127c9db4 100644 --- a/net/sched/sch_gred.c +++ b/net/sched/sch_gred.c @@ -377,6 +377,7 @@ static int gred_offload_dump_stats(struct Qdisc *sch) /* Even if driver returns failure adjust the stats - in case offload * ended but driver still wants to adjust the values. */ + sch_tree_lock(sch); for (i = 0; i < MAX_DPs; i++) { if (!table->tab[i]) continue; @@ -393,6 +394,7 @@ static int gred_offload_dump_stats(struct Qdisc *sch) sch->qstats.overlimits += hw_stats->stats.qstats[i].overlimits; } _bstats_update(&sch->bstats, bytes, packets); + sch_tree_unlock(sch); kfree(hw_stats); return ret; diff --git a/net/sched/sch_htb.c b/net/sched/sch_htb.c index e5b4bbf3ce3d..cc28e41fb745 100644 --- a/net/sched/sch_htb.c +++ b/net/sched/sch_htb.c @@ -199,8 +199,14 @@ static unsigned long htb_search(struct Qdisc *sch, u32 handle) { return (unsigned long)htb_find(handle, sch); } + +#define HTB_DIRECT ((struct htb_class *)-1L) + /** * htb_classify - classify a packet into class + * @skb: the socket buffer + * @sch: the active queue discipline + * @qerr: pointer for returned status code * * It returns NULL if the packet should be dropped or -1 if the packet * should be passed directly thru. In all other cases leaf class is returned. @@ -211,8 +217,6 @@ static unsigned long htb_search(struct Qdisc *sch, u32 handle) * have no valid leaf we try to use MAJOR:default leaf. It still unsuccessful * then finish and return direct queue. */ -#define HTB_DIRECT ((struct htb_class *)-1L) - static struct htb_class *htb_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) { @@ -427,7 +431,10 @@ static void htb_activate_prios(struct htb_sched *q, struct htb_class *cl) while (cl->cmode == HTB_MAY_BORROW && p && mask) { m = mask; while (m) { - int prio = ffz(~m); + unsigned int prio = ffz(~m); + + if (WARN_ON_ONCE(prio > ARRAY_SIZE(p->inner.clprio))) + break; m &= ~(1 << prio); if (p->inner.clprio[prio].feed.rb_node) @@ -1545,7 +1552,7 @@ static int htb_destroy_class_offload(struct Qdisc *sch, struct htb_class *cl, struct tc_htb_qopt_offload offload_opt; struct netdev_queue *dev_queue; struct Qdisc *q = cl->leaf.q; - struct Qdisc *old = NULL; + struct Qdisc *old; int err; if (cl->level) @@ -1553,14 +1560,17 @@ static int htb_destroy_class_offload(struct Qdisc *sch, struct htb_class *cl, WARN_ON(!q); dev_queue = htb_offload_get_queue(cl); - old = htb_graft_helper(dev_queue, NULL); - if (destroying) - /* Before HTB is destroyed, the kernel grafts noop_qdisc to - * all queues. + /* When destroying, caller qdisc_graft grafts the new qdisc and invokes + * qdisc_put for the qdisc being destroyed. htb_destroy_class_offload + * does not need to graft or qdisc_put the qdisc being destroyed. + */ + if (!destroying) { + old = htb_graft_helper(dev_queue, NULL); + /* Last qdisc grafted should be the same as cl->leaf.q when + * calling htb_delete. */ - WARN_ON(!(old->flags & TCQ_F_BUILTIN)); - else WARN_ON(old != q); + } if (cl->parent) { _bstats_update(&cl->parent->bstats_bias, @@ -1577,10 +1587,12 @@ static int htb_destroy_class_offload(struct Qdisc *sch, struct htb_class *cl, }; err = htb_offload(qdisc_dev(sch), &offload_opt); - if (!err || destroying) - qdisc_put(old); - else - htb_graft_helper(dev_queue, old); + if (!destroying) { + if (!err) + qdisc_put(old); + else + htb_graft_helper(dev_queue, old); + } if (last_child) return err; diff --git a/net/sched/sch_netem.c b/net/sched/sch_netem.c index fb00ac40ecb7..6ef3021e1169 100644 --- a/net/sched/sch_netem.c +++ b/net/sched/sch_netem.c @@ -513,8 +513,8 @@ static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, goto finish_segs; } - skb->data[prandom_u32_max(skb_headlen(skb))] ^= - 1<<prandom_u32_max(8); + skb->data[get_random_u32_below(skb_headlen(skb))] ^= + 1<<get_random_u32_below(8); } if (unlikely(sch->q.qlen >= sch->limit)) { diff --git a/net/sched/sch_taprio.c b/net/sched/sch_taprio.c index 570389f6cdd7..c322a61eaeea 100644 --- a/net/sched/sch_taprio.c +++ b/net/sched/sch_taprio.c @@ -1700,6 +1700,7 @@ static void taprio_reset(struct Qdisc *sch) int i; hrtimer_cancel(&q->advance_timer); + if (q->qdiscs) { for (i = 0; i < dev->num_tx_queues; i++) if (q->qdiscs[i]) @@ -1720,6 +1721,7 @@ static void taprio_destroy(struct Qdisc *sch) * happens in qdisc_create(), after taprio_init() has been called. */ hrtimer_cancel(&q->advance_timer); + qdisc_synchronize(sch); taprio_disable_offload(dev, q, NULL); diff --git a/net/sctp/associola.c b/net/sctp/associola.c index 3460abceba44..63ba5551c13f 100644 --- a/net/sctp/associola.c +++ b/net/sctp/associola.c @@ -226,8 +226,7 @@ static struct sctp_association *sctp_association_init( /* Create an output queue. */ sctp_outq_init(asoc, &asoc->outqueue); - if (!sctp_ulpq_init(&asoc->ulpq, asoc)) - goto fail_init; + sctp_ulpq_init(&asoc->ulpq, asoc); if (sctp_stream_init(&asoc->stream, asoc->c.sinit_num_ostreams, 0, gfp)) goto stream_free; @@ -277,7 +276,6 @@ static struct sctp_association *sctp_association_init( stream_free: sctp_stream_free(&asoc->stream); -fail_init: sock_put(asoc->base.sk); sctp_endpoint_put(asoc->ep); return NULL; diff --git a/net/sctp/bind_addr.c b/net/sctp/bind_addr.c index 59e653b528b1..6b95d3ba8fe1 100644 --- a/net/sctp/bind_addr.c +++ b/net/sctp/bind_addr.c @@ -73,6 +73,12 @@ int sctp_bind_addr_copy(struct net *net, struct sctp_bind_addr *dest, } } + /* If somehow no addresses were found that can be used with this + * scope, it's an error. + */ + if (list_empty(&dest->address_list)) + error = -ENETUNREACH; + out: if (error) sctp_bind_addr_clean(dest); diff --git a/net/sctp/diag.c b/net/sctp/diag.c index d9c6d8f30f09..a557009e9832 100644 --- a/net/sctp/diag.c +++ b/net/sctp/diag.c @@ -426,6 +426,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb, struct net *net = sock_net(skb->sk); const struct nlmsghdr *nlh = cb->nlh; union sctp_addr laddr, paddr; + int dif = req->id.idiag_if; struct sctp_comm_param commp = { .skb = skb, .r = req, @@ -454,7 +455,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb, } return sctp_transport_lookup_process(sctp_sock_dump_one, - net, &laddr, &paddr, &commp); + net, &laddr, &paddr, &commp, dif); } static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, diff --git a/net/sctp/endpointola.c b/net/sctp/endpointola.c index efffde7f2328..7e77b450697c 100644 --- a/net/sctp/endpointola.c +++ b/net/sctp/endpointola.c @@ -246,12 +246,15 @@ void sctp_endpoint_put(struct sctp_endpoint *ep) /* Is this the endpoint we are looking for? */ struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep, struct net *net, - const union sctp_addr *laddr) + const union sctp_addr *laddr, + int dif, int sdif) { + int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if); struct sctp_endpoint *retval = NULL; - if ((htons(ep->base.bind_addr.port) == laddr->v4.sin_port) && - net_eq(ep->base.net, net)) { + if (net_eq(ep->base.net, net) && + sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) && + (htons(ep->base.bind_addr.port) == laddr->v4.sin_port)) { if (sctp_bind_addr_match(&ep->base.bind_addr, laddr, sctp_sk(ep->base.sk))) retval = ep; @@ -298,6 +301,7 @@ out: bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, const union sctp_addr *paddr) { + int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if); struct sctp_sockaddr_entry *addr; struct net *net = ep->base.net; struct sctp_bind_addr *bp; @@ -307,7 +311,8 @@ bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, * so the address_list can not change. */ list_for_each_entry(addr, &bp->address_list, list) { - if (sctp_has_association(net, &addr->a, paddr)) + if (sctp_has_association(net, &addr->a, paddr, + bound_dev_if, bound_dev_if)) return true; } diff --git a/net/sctp/input.c b/net/sctp/input.c index 4f43afa8678f..bf70371301ff 100644 --- a/net/sctp/input.c +++ b/net/sctp/input.c @@ -50,16 +50,19 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *paddr, const union sctp_addr *laddr, - struct sctp_transport **transportp); + struct sctp_transport **transportp, + int dif, int sdif); static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - const union sctp_addr *daddr); + const union sctp_addr *daddr, + int dif, int sdif); static struct sctp_association *__sctp_lookup_association( struct net *net, const union sctp_addr *local, const union sctp_addr *peer, - struct sctp_transport **pt); + struct sctp_transport **pt, + int dif, int sdif); static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb); @@ -92,11 +95,11 @@ int sctp_rcv(struct sk_buff *skb) struct sctp_chunk *chunk; union sctp_addr src; union sctp_addr dest; - int bound_dev_if; int family; struct sctp_af *af; struct net *net = dev_net(skb->dev); bool is_gso = skb_is_gso(skb) && skb_is_gso_sctp(skb); + int dif, sdif; if (skb->pkt_type != PACKET_HOST) goto discard_it; @@ -141,6 +144,8 @@ int sctp_rcv(struct sk_buff *skb) /* Initialize local addresses for lookups. */ af->from_skb(&src, skb, 1); af->from_skb(&dest, skb, 0); + dif = af->skb_iif(skb); + sdif = af->skb_sdif(skb); /* If the packet is to or from a non-unicast address, * silently discard the packet. @@ -157,36 +162,16 @@ int sctp_rcv(struct sk_buff *skb) !af->addr_valid(&dest, NULL, skb)) goto discard_it; - asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport); + asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport, dif, sdif); if (!asoc) - ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src); + ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src, dif, sdif); /* Retrieve the common input handling substructure. */ rcvr = asoc ? &asoc->base : &ep->base; sk = rcvr->sk; /* - * If a frame arrives on an interface and the receiving socket is - * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB - */ - bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); - if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) { - if (transport) { - sctp_transport_put(transport); - asoc = NULL; - transport = NULL; - } else { - sctp_endpoint_put(ep); - ep = NULL; - } - sk = net->sctp.ctl_sock; - ep = sctp_sk(sk)->ep; - sctp_endpoint_hold(ep); - rcvr = &ep->base; - } - - /* * RFC 2960, 8.4 - Handle "Out of the blue" Packets. * An SCTP packet is called an "out of the blue" (OOTB) * packet if it is correctly formed, i.e., passed the @@ -485,6 +470,8 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb, struct sctp_association *asoc; struct sctp_transport *transport = NULL; __u32 vtag = ntohl(sctphdr->vtag); + int sdif = inet_sdif(skb); + int dif = inet_iif(skb); *app = NULL; *tpp = NULL; @@ -500,7 +487,7 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb, /* Look for an association that matches the incoming ICMP error * packet. */ - asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport); + asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport, dif, sdif); if (!asoc) return NULL; @@ -850,7 +837,8 @@ static inline __u32 sctp_hashfn(const struct net *net, __be16 lport, static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct sctp_hashbucket *head; struct sctp_endpoint *ep; @@ -863,7 +851,7 @@ static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( head = &sctp_ep_hashtable[hash]; read_lock(&head->lock); sctp_for_each_hentry(ep, &head->chain) { - if (sctp_endpoint_is_match(ep, net, laddr)) + if (sctp_endpoint_is_match(ep, net, laddr, dif, sdif)) goto hit; } @@ -990,14 +978,26 @@ void sctp_unhash_transport(struct sctp_transport *t) sctp_hash_params); } +bool sctp_sk_bound_dev_eq(struct net *net, int bound_dev_if, int dif, int sdif) +{ + bool l3mdev_accept = true; + +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + l3mdev_accept = !!READ_ONCE(net->sctp.l3mdev_accept); +#endif + return inet_bound_dev_eq(l3mdev_accept, bound_dev_if, dif, sdif); +} + /* return a transport with holding it */ struct sctp_transport *sctp_addrs_lookup_transport( struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct rhlist_head *tmp, *list; struct sctp_transport *t; + int bound_dev_if; struct sctp_hash_cmp_arg arg = { .paddr = paddr, .net = net, @@ -1011,7 +1011,9 @@ struct sctp_transport *sctp_addrs_lookup_transport( if (!sctp_transport_hold(t)) continue; - if (sctp_bind_addr_match(&t->asoc->base.bind_addr, + bound_dev_if = READ_ONCE(t->asoc->base.sk->sk_bound_dev_if); + if (sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) && + sctp_bind_addr_match(&t->asoc->base.bind_addr, laddr, sctp_sk(t->asoc->base.sk))) return t; sctp_transport_put(t); @@ -1048,12 +1050,13 @@ static struct sctp_association *__sctp_lookup_association( struct net *net, const union sctp_addr *local, const union sctp_addr *peer, - struct sctp_transport **pt) + struct sctp_transport **pt, + int dif, int sdif) { struct sctp_transport *t; struct sctp_association *asoc = NULL; - t = sctp_addrs_lookup_transport(net, local, peer); + t = sctp_addrs_lookup_transport(net, local, peer, dif, sdif); if (!t) goto out; @@ -1069,12 +1072,13 @@ static struct sctp_association *sctp_lookup_association(struct net *net, const union sctp_addr *laddr, const union sctp_addr *paddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; rcu_read_lock(); - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); rcu_read_unlock(); return asoc; @@ -1083,11 +1087,12 @@ struct sctp_association *sctp_lookup_association(struct net *net, /* Is there an association matching the given local and peer addresses? */ bool sctp_has_association(struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr) + const union sctp_addr *paddr, + int dif, int sdif) { struct sctp_transport *transport; - if (sctp_lookup_association(net, laddr, paddr, &transport)) { + if (sctp_lookup_association(net, laddr, paddr, &transport, dif, sdif)) { sctp_transport_put(transport); return true; } @@ -1115,7 +1120,8 @@ bool sctp_has_association(struct net *net, */ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net, struct sk_buff *skb, - const union sctp_addr *laddr, struct sctp_transport **transportp) + const union sctp_addr *laddr, struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; union sctp_addr addr; @@ -1154,7 +1160,7 @@ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net, if (!af->from_addr_param(paddr, params.addr, sh->source, 0)) continue; - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); if (asoc) return asoc; } @@ -1181,7 +1187,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( struct sctp_chunkhdr *ch, const union sctp_addr *laddr, __be16 peer_port, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_addip_chunk *asconf = (struct sctp_addip_chunk *)ch; struct sctp_af *af; @@ -1201,7 +1208,7 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( if (!af->from_addr_param(&paddr, param, peer_port, 0)) return NULL; - return __sctp_lookup_association(net, laddr, &paddr, transportp); + return __sctp_lookup_association(net, laddr, &paddr, transportp, dif, sdif); } @@ -1217,7 +1224,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup( static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc = NULL; struct sctp_chunkhdr *ch; @@ -1260,7 +1268,7 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, asoc = __sctp_rcv_asconf_lookup( net, ch, laddr, sctp_hdr(skb)->source, - transportp); + transportp, dif, sdif); break; default: break; @@ -1285,7 +1293,8 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net, struct sk_buff *skb, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_chunkhdr *ch; @@ -1309,9 +1318,9 @@ static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net, /* If this is INIT/INIT-ACK look inside the chunk too. */ if (ch->type == SCTP_CID_INIT || ch->type == SCTP_CID_INIT_ACK) - return __sctp_rcv_init_lookup(net, skb, laddr, transportp); + return __sctp_rcv_init_lookup(net, skb, laddr, transportp, dif, sdif); - return __sctp_rcv_walk_lookup(net, skb, laddr, transportp); + return __sctp_rcv_walk_lookup(net, skb, laddr, transportp, dif, sdif); } /* Lookup an association for an inbound skb. */ @@ -1319,11 +1328,12 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, struct sk_buff *skb, const union sctp_addr *paddr, const union sctp_addr *laddr, - struct sctp_transport **transportp) + struct sctp_transport **transportp, + int dif, int sdif) { struct sctp_association *asoc; - asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif); if (asoc) goto out; @@ -1331,7 +1341,7 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net, * SCTP Implementors Guide, 2.18 Handling of address * parameters within the INIT or INIT-ACK. */ - asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp); + asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp, dif, sdif); if (asoc) goto out; diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c index d081858c2d07..097bd60ce964 100644 --- a/net/sctp/ipv6.c +++ b/net/sctp/ipv6.c @@ -680,9 +680,11 @@ static int sctp_v6_is_any(const union sctp_addr *addr) /* Should this be available for binding? */ static int sctp_v6_available(union sctp_addr *addr, struct sctp_sock *sp) { - int type; - struct net *net = sock_net(&sp->inet.sk); const struct in6_addr *in6 = (const struct in6_addr *)&addr->v6.sin6_addr; + struct sock *sk = &sp->inet.sk; + struct net *net = sock_net(sk); + struct net_device *dev = NULL; + int type; type = ipv6_addr_type(in6); if (IPV6_ADDR_ANY == type) @@ -696,8 +698,14 @@ static int sctp_v6_available(union sctp_addr *addr, struct sctp_sock *sp) if (!(type & IPV6_ADDR_UNICAST)) return 0; + if (sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(net, sk->sk_bound_dev_if); + if (!dev) + return 0; + } + return ipv6_can_nonlocal_bind(net, &sp->inet) || - ipv6_chk_addr(net, in6, NULL, 0); + ipv6_chk_addr(net, in6, dev, 0); } /* This function checks if the address is a valid address to be used for @@ -834,7 +842,12 @@ static int sctp_v6_addr_to_user(struct sctp_sock *sp, union sctp_addr *addr) /* Where did this skb come from? */ static int sctp_v6_skb_iif(const struct sk_buff *skb) { - return IP6CB(skb)->iif; + return inet6_iif(skb); +} + +static int sctp_v6_skb_sdif(const struct sk_buff *skb) +{ + return inet6_sdif(skb); } /* Was this packet marked by Explicit Congestion Notification? */ @@ -1134,6 +1147,7 @@ static struct sctp_af sctp_af_inet6 = { .is_any = sctp_v6_is_any, .available = sctp_v6_available, .skb_iif = sctp_v6_skb_iif, + .skb_sdif = sctp_v6_skb_sdif, .is_ce = sctp_v6_is_ce, .seq_dump_addr = sctp_v6_seq_dump_addr, .ecn_capable = sctp_v6_ecn_capable, diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c index bcd3384ab07a..909a89a1cff4 100644 --- a/net/sctp/protocol.c +++ b/net/sctp/protocol.c @@ -351,10 +351,13 @@ static int sctp_v4_addr_valid(union sctp_addr *addr, /* Should this be available for binding? */ static int sctp_v4_available(union sctp_addr *addr, struct sctp_sock *sp) { - struct net *net = sock_net(&sp->inet.sk); - int ret = inet_addr_type(net, addr->v4.sin_addr.s_addr); - + struct sock *sk = &sp->inet.sk; + struct net *net = sock_net(sk); + int tb_id = RT_TABLE_LOCAL; + int ret; + tb_id = l3mdev_fib_table_by_index(net, sk->sk_bound_dev_if) ?: tb_id; + ret = inet_addr_type_table(net, addr->v4.sin_addr.s_addr, tb_id); if (addr->v4.sin_addr.s_addr != htonl(INADDR_ANY) && ret != RTN_LOCAL && !sp->inet.freebind && @@ -564,6 +567,11 @@ static int sctp_v4_skb_iif(const struct sk_buff *skb) return inet_iif(skb); } +static int sctp_v4_skb_sdif(const struct sk_buff *skb) +{ + return inet_sdif(skb); +} + /* Was this packet marked by Explicit Congestion Notification? */ static int sctp_v4_is_ce(const struct sk_buff *skb) { @@ -1182,6 +1190,7 @@ static struct sctp_af sctp_af_inet = { .available = sctp_v4_available, .scope = sctp_v4_scope, .skb_iif = sctp_v4_skb_iif, + .skb_sdif = sctp_v4_skb_sdif, .is_ce = sctp_v4_is_ce, .seq_dump_addr = sctp_v4_seq_dump_addr, .ecn_capable = sctp_v4_ecn_capable, @@ -1385,6 +1394,10 @@ static int __net_init sctp_defaults_init(struct net *net) /* Initialize maximum autoclose timeout. */ net->sctp.max_autoclose = INT_MAX / HZ; +#ifdef CONFIG_NET_L3_MASTER_DEV + net->sctp.l3mdev_accept = 1; +#endif + status = sctp_sysctl_net_register(net); if (status) goto err_sysctl_register; diff --git a/net/sctp/sm_statefuns.c b/net/sctp/sm_statefuns.c index f6ee7f4040c1..ce5426171206 100644 --- a/net/sctp/sm_statefuns.c +++ b/net/sctp/sm_statefuns.c @@ -4044,7 +4044,7 @@ enum sctp_disposition sctp_sf_do_asconf_ack(struct net *net, (void *)err_param, commands); if (last_asconf) { - addip_hdr = (struct sctp_addiphdr *)last_asconf->subh.addip_hdr; + addip_hdr = last_asconf->subh.addip_hdr; sent_serial = ntohl(addip_hdr->serial); } else { sent_serial = asoc->addip_serial - 1; diff --git a/net/sctp/socket.c b/net/sctp/socket.c index 83628c347744..84021a6c4f9d 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -5098,13 +5098,17 @@ static void sctp_destroy_sock(struct sock *sk) } /* Triggered when there are no references on the socket anymore */ -static void sctp_destruct_sock(struct sock *sk) +static void sctp_destruct_common(struct sock *sk) { struct sctp_sock *sp = sctp_sk(sk); /* Free up the HMAC transform. */ crypto_free_shash(sp->hmac); +} +static void sctp_destruct_sock(struct sock *sk) +{ + sctp_destruct_common(sk); inet_sock_destruct(sk); } @@ -5311,14 +5315,14 @@ EXPORT_SYMBOL_GPL(sctp_for_each_endpoint); int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net, const union sctp_addr *laddr, - const union sctp_addr *paddr, void *p) + const union sctp_addr *paddr, void *p, int dif) { struct sctp_transport *transport; struct sctp_endpoint *ep; int err = -ENOENT; rcu_read_lock(); - transport = sctp_addrs_lookup_transport(net, laddr, paddr); + transport = sctp_addrs_lookup_transport(net, laddr, paddr, dif, dif); if (!transport) { rcu_read_unlock(); return err; @@ -8319,7 +8323,7 @@ static int sctp_get_port_local(struct sock *sk, union sctp_addr *addr) inet_get_local_port_range(net, &low, &high); remaining = (high - low) + 1; - rover = prandom_u32_max(remaining) + low; + rover = get_random_u32_below(remaining) + low; do { rover++; @@ -8394,6 +8398,7 @@ pp_found: * in an endpoint. */ sk_for_each_bound(sk2, &pp->owner) { + int bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if); struct sctp_sock *sp2 = sctp_sk(sk2); struct sctp_endpoint *ep2 = sp2->ep; @@ -8404,7 +8409,9 @@ pp_found: uid_eq(uid, sock_i_uid(sk2)))) continue; - if (sctp_bind_addr_conflict(&ep2->base.bind_addr, + if ((!sk->sk_bound_dev_if || !bound_dev_if2 || + sk->sk_bound_dev_if == bound_dev_if2) && + sctp_bind_addr_conflict(&ep2->base.bind_addr, addr, sp2, sp)) { ret = 1; goto fail_unlock; @@ -9427,7 +9434,7 @@ void sctp_copy_sock(struct sock *newsk, struct sock *sk, sctp_sk(newsk)->reuse = sp->reuse; newsk->sk_shutdown = sk->sk_shutdown; - newsk->sk_destruct = sctp_destruct_sock; + newsk->sk_destruct = sk->sk_destruct; newsk->sk_family = sk->sk_family; newsk->sk_protocol = IPPROTO_SCTP; newsk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; @@ -9662,11 +9669,20 @@ struct proto sctp_prot = { #if IS_ENABLED(CONFIG_IPV6) -#include <net/transp_v6.h> -static void sctp_v6_destroy_sock(struct sock *sk) +static void sctp_v6_destruct_sock(struct sock *sk) +{ + sctp_destruct_common(sk); + inet6_sock_destruct(sk); +} + +static int sctp_v6_init_sock(struct sock *sk) { - sctp_destroy_sock(sk); - inet6_destroy_sock(sk); + int ret = sctp_init_sock(sk); + + if (!ret) + sk->sk_destruct = sctp_v6_destruct_sock; + + return ret; } struct proto sctpv6_prot = { @@ -9676,8 +9692,8 @@ struct proto sctpv6_prot = { .disconnect = sctp_disconnect, .accept = sctp_accept, .ioctl = sctp_ioctl, - .init = sctp_init_sock, - .destroy = sctp_v6_destroy_sock, + .init = sctp_v6_init_sock, + .destroy = sctp_destroy_sock, .shutdown = sctp_shutdown, .setsockopt = sctp_setsockopt, .getsockopt = sctp_getsockopt, diff --git a/net/sctp/stream.c b/net/sctp/stream.c index ef9fceadef8d..ee6514af830f 100644 --- a/net/sctp/stream.c +++ b/net/sctp/stream.c @@ -52,6 +52,19 @@ static void sctp_stream_shrink_out(struct sctp_stream *stream, __u16 outcnt) } } +static void sctp_stream_free_ext(struct sctp_stream *stream, __u16 sid) +{ + struct sctp_sched_ops *sched; + + if (!SCTP_SO(stream, sid)->ext) + return; + + sched = sctp_sched_ops_from_stream(stream); + sched->free_sid(stream, sid); + kfree(SCTP_SO(stream, sid)->ext); + SCTP_SO(stream, sid)->ext = NULL; +} + /* Migrates chunks from stream queues to new stream queues if needed, * but not across associations. Also, removes those chunks to streams * higher than the new max. @@ -70,16 +83,14 @@ static void sctp_stream_outq_migrate(struct sctp_stream *stream, * sctp_stream_update will swap ->out pointers. */ for (i = 0; i < outcnt; i++) { - kfree(SCTP_SO(new, i)->ext); + sctp_stream_free_ext(new, i); SCTP_SO(new, i)->ext = SCTP_SO(stream, i)->ext; SCTP_SO(stream, i)->ext = NULL; } } - for (i = outcnt; i < stream->outcnt; i++) { - kfree(SCTP_SO(stream, i)->ext); - SCTP_SO(stream, i)->ext = NULL; - } + for (i = outcnt; i < stream->outcnt; i++) + sctp_stream_free_ext(stream, i); } static int sctp_stream_alloc_out(struct sctp_stream *stream, __u16 outcnt, @@ -174,9 +185,9 @@ void sctp_stream_free(struct sctp_stream *stream) struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); int i; - sched->free(stream); + sched->unsched_all(stream); for (i = 0; i < stream->outcnt; i++) - kfree(SCTP_SO(stream, i)->ext); + sctp_stream_free_ext(stream, i); genradix_free(&stream->out); genradix_free(&stream->in); } diff --git a/net/sctp/stream_interleave.c b/net/sctp/stream_interleave.c index bb22b71df7a3..94727feb07b3 100644 --- a/net/sctp/stream_interleave.c +++ b/net/sctp/stream_interleave.c @@ -490,11 +490,8 @@ static int sctp_enqueue_event(struct sctp_ulpq *ulpq, if (!sctp_ulpevent_is_enabled(event, ulpq->asoc->subscribe)) goto out_free; - if (skb_list) - skb_queue_splice_tail_init(skb_list, - &sk->sk_receive_queue); - else - __skb_queue_tail(&sk->sk_receive_queue, skb); + skb_queue_splice_tail_init(skb_list, + &sk->sk_receive_queue); if (!sp->data_ready_signalled) { sp->data_ready_signalled = 1; @@ -504,10 +501,7 @@ static int sctp_enqueue_event(struct sctp_ulpq *ulpq, return 1; out_free: - if (skb_list) - sctp_queue_purge_ulpevents(skb_list); - else - sctp_ulpevent_free(event); + sctp_queue_purge_ulpevents(skb_list); return 0; } diff --git a/net/sctp/stream_sched.c b/net/sctp/stream_sched.c index 1ad565ed5627..330067002deb 100644 --- a/net/sctp/stream_sched.c +++ b/net/sctp/stream_sched.c @@ -46,7 +46,7 @@ static int sctp_sched_fcfs_init_sid(struct sctp_stream *stream, __u16 sid, return 0; } -static void sctp_sched_fcfs_free(struct sctp_stream *stream) +static void sctp_sched_fcfs_free_sid(struct sctp_stream *stream, __u16 sid) { } @@ -96,7 +96,7 @@ static struct sctp_sched_ops sctp_sched_fcfs = { .get = sctp_sched_fcfs_get, .init = sctp_sched_fcfs_init, .init_sid = sctp_sched_fcfs_init_sid, - .free = sctp_sched_fcfs_free, + .free_sid = sctp_sched_fcfs_free_sid, .enqueue = sctp_sched_fcfs_enqueue, .dequeue = sctp_sched_fcfs_dequeue, .dequeue_done = sctp_sched_fcfs_dequeue_done, @@ -126,6 +126,23 @@ void sctp_sched_ops_init(void) sctp_sched_ops_rr_init(); } +static void sctp_sched_free_sched(struct sctp_stream *stream) +{ + struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); + struct sctp_stream_out_ext *soute; + int i; + + sched->unsched_all(stream); + for (i = 0; i < stream->outcnt; i++) { + soute = SCTP_SO(stream, i)->ext; + if (!soute) + continue; + sched->free_sid(stream, i); + /* Give the next scheduler a clean slate. */ + memset_after(soute, 0, outq); + } +} + int sctp_sched_set_sched(struct sctp_association *asoc, enum sctp_sched_type sched) { @@ -141,18 +158,8 @@ int sctp_sched_set_sched(struct sctp_association *asoc, if (sched > SCTP_SS_MAX) return -EINVAL; - if (old) { - old->free(&asoc->stream); - - /* Give the next scheduler a clean slate. */ - for (i = 0; i < asoc->stream.outcnt; i++) { - struct sctp_stream_out_ext *ext = SCTP_SO(&asoc->stream, i)->ext; - - if (!ext) - continue; - memset_after(ext, 0, outq); - } - } + if (old) + sctp_sched_free_sched(&asoc->stream); asoc->outqueue.sched = n; n->init(&asoc->stream); @@ -176,7 +183,7 @@ int sctp_sched_set_sched(struct sctp_association *asoc, return ret; err: - n->free(&asoc->stream); + sctp_sched_free_sched(&asoc->stream); asoc->outqueue.sched = &sctp_sched_fcfs; /* Always safe */ return ret; diff --git a/net/sctp/stream_sched_prio.c b/net/sctp/stream_sched_prio.c index 80b5a2c4cbc7..42d4800f263d 100644 --- a/net/sctp/stream_sched_prio.c +++ b/net/sctp/stream_sched_prio.c @@ -204,30 +204,22 @@ static int sctp_sched_prio_init_sid(struct sctp_stream *stream, __u16 sid, return sctp_sched_prio_set(stream, sid, 0, gfp); } -static void sctp_sched_prio_free(struct sctp_stream *stream) +static void sctp_sched_prio_free_sid(struct sctp_stream *stream, __u16 sid) { - struct sctp_stream_priorities *prio, *n; - LIST_HEAD(list); + struct sctp_stream_priorities *prio = SCTP_SO(stream, sid)->ext->prio_head; int i; - /* As we don't keep a list of priorities, to avoid multiple - * frees we have to do it in 3 steps: - * 1. unsched everyone, so the lists are free to use in 2. - * 2. build the list of the priorities - * 3. free the list - */ - sctp_sched_prio_unsched_all(stream); + if (!prio) + return; + + SCTP_SO(stream, sid)->ext->prio_head = NULL; for (i = 0; i < stream->outcnt; i++) { - if (!SCTP_SO(stream, i)->ext) - continue; - prio = SCTP_SO(stream, i)->ext->prio_head; - if (prio && list_empty(&prio->prio_sched)) - list_add(&prio->prio_sched, &list); - } - list_for_each_entry_safe(prio, n, &list, prio_sched) { - list_del_init(&prio->prio_sched); - kfree(prio); + if (SCTP_SO(stream, i)->ext && + SCTP_SO(stream, i)->ext->prio_head == prio) + return; } + + kfree(prio); } static void sctp_sched_prio_enqueue(struct sctp_outq *q, @@ -323,7 +315,7 @@ static struct sctp_sched_ops sctp_sched_prio = { .get = sctp_sched_prio_get, .init = sctp_sched_prio_init, .init_sid = sctp_sched_prio_init_sid, - .free = sctp_sched_prio_free, + .free_sid = sctp_sched_prio_free_sid, .enqueue = sctp_sched_prio_enqueue, .dequeue = sctp_sched_prio_dequeue, .dequeue_done = sctp_sched_prio_dequeue_done, diff --git a/net/sctp/stream_sched_rr.c b/net/sctp/stream_sched_rr.c index ff425aed62c7..1f235e7f643a 100644 --- a/net/sctp/stream_sched_rr.c +++ b/net/sctp/stream_sched_rr.c @@ -90,9 +90,8 @@ static int sctp_sched_rr_init_sid(struct sctp_stream *stream, __u16 sid, return 0; } -static void sctp_sched_rr_free(struct sctp_stream *stream) +static void sctp_sched_rr_free_sid(struct sctp_stream *stream, __u16 sid) { - sctp_sched_rr_unsched_all(stream); } static void sctp_sched_rr_enqueue(struct sctp_outq *q, @@ -177,7 +176,7 @@ static struct sctp_sched_ops sctp_sched_rr = { .get = sctp_sched_rr_get, .init = sctp_sched_rr_init, .init_sid = sctp_sched_rr_init_sid, - .free = sctp_sched_rr_free, + .free_sid = sctp_sched_rr_free_sid, .enqueue = sctp_sched_rr_enqueue, .dequeue = sctp_sched_rr_dequeue, .dequeue_done = sctp_sched_rr_dequeue_done, diff --git a/net/sctp/sysctl.c b/net/sctp/sysctl.c index b46a416787ec..a7a9136198fd 100644 --- a/net/sctp/sysctl.c +++ b/net/sctp/sysctl.c @@ -84,17 +84,18 @@ static struct ctl_table sctp_table[] = { { /* sentinel */ } }; +/* The following index defines are used in sctp_sysctl_net_register(). + * If you add new items to the sctp_net_table, please ensure that + * the index values of these defines hold the same meaning indicated by + * their macro names when they appear in sctp_net_table. + */ +#define SCTP_RTO_MIN_IDX 0 +#define SCTP_RTO_MAX_IDX 1 +#define SCTP_PF_RETRANS_IDX 2 +#define SCTP_PS_RETRANS_IDX 3 + static struct ctl_table sctp_net_table[] = { - { - .procname = "rto_initial", - .data = &init_net.sctp.rto_initial, - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_minmax, - .extra1 = SYSCTL_ONE, - .extra2 = &timer_max - }, - { + [SCTP_RTO_MIN_IDX] = { .procname = "rto_min", .data = &init_net.sctp.rto_min, .maxlen = sizeof(unsigned int), @@ -103,7 +104,7 @@ static struct ctl_table sctp_net_table[] = { .extra1 = SYSCTL_ONE, .extra2 = &init_net.sctp.rto_max }, - { + [SCTP_RTO_MAX_IDX] = { .procname = "rto_max", .data = &init_net.sctp.rto_max, .maxlen = sizeof(unsigned int), @@ -112,6 +113,33 @@ static struct ctl_table sctp_net_table[] = { .extra1 = &init_net.sctp.rto_min, .extra2 = &timer_max }, + [SCTP_PF_RETRANS_IDX] = { + .procname = "pf_retrans", + .data = &init_net.sctp.pf_retrans, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &init_net.sctp.ps_retrans, + }, + [SCTP_PS_RETRANS_IDX] = { + .procname = "ps_retrans", + .data = &init_net.sctp.ps_retrans, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &init_net.sctp.pf_retrans, + .extra2 = &ps_retrans_max, + }, + { + .procname = "rto_initial", + .data = &init_net.sctp.rto_initial, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &timer_max + }, { .procname = "rto_alpha_exp_divisor", .data = &init_net.sctp.rto_alpha, @@ -208,24 +236,6 @@ static struct ctl_table sctp_net_table[] = { .extra2 = SYSCTL_INT_MAX, }, { - .procname = "pf_retrans", - .data = &init_net.sctp.pf_retrans, - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec_minmax, - .extra1 = SYSCTL_ZERO, - .extra2 = &init_net.sctp.ps_retrans, - }, - { - .procname = "ps_retrans", - .data = &init_net.sctp.ps_retrans, - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec_minmax, - .extra1 = &init_net.sctp.pf_retrans, - .extra2 = &ps_retrans_max, - }, - { .procname = "sndbuf_policy", .data = &init_net.sctp.sndbuf_policy, .maxlen = sizeof(int), @@ -347,6 +357,17 @@ static struct ctl_table sctp_net_table[] = { .extra1 = &max_autoclose_min, .extra2 = &max_autoclose_max, }, +#ifdef CONFIG_NET_L3_MASTER_DEV + { + .procname = "l3mdev_accept", + .data = &init_net.sctp.l3mdev_accept, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif { .procname = "pf_enable", .data = &init_net.sctp.pf_enable, @@ -586,6 +607,11 @@ int sctp_sysctl_net_register(struct net *net) for (i = 0; table[i].data; i++) table[i].data += (char *)(&net->sctp) - (char *)&init_net.sctp; + table[SCTP_RTO_MIN_IDX].extra2 = &net->sctp.rto_max; + table[SCTP_RTO_MAX_IDX].extra1 = &net->sctp.rto_min; + table[SCTP_PF_RETRANS_IDX].extra2 = &net->sctp.ps_retrans; + table[SCTP_PS_RETRANS_IDX].extra1 = &net->sctp.pf_retrans; + net->sctp.sysctl_header = register_net_sysctl(net, "net/sctp", table); if (net->sctp.sysctl_header == NULL) { kfree(table); diff --git a/net/sctp/transport.c b/net/sctp/transport.c index f8fd98784977..2f66a2006517 100644 --- a/net/sctp/transport.c +++ b/net/sctp/transport.c @@ -196,10 +196,8 @@ void sctp_transport_reset_hb_timer(struct sctp_transport *transport) /* When a data chunk is sent, reset the heartbeat interval. */ expires = jiffies + sctp_transport_timeout(transport); - if ((time_before(transport->hb_timer.expires, expires) || - !timer_pending(&transport->hb_timer)) && - !mod_timer(&transport->hb_timer, - expires + prandom_u32_max(transport->rto))) + if (!mod_timer(&transport->hb_timer, + expires + get_random_u32_below(transport->rto))) sctp_transport_hold(transport); } diff --git a/net/sctp/ulpqueue.c b/net/sctp/ulpqueue.c index 0a8510a0c5e6..b05daafd369a 100644 --- a/net/sctp/ulpqueue.c +++ b/net/sctp/ulpqueue.c @@ -38,8 +38,7 @@ static void sctp_ulpq_reasm_drain(struct sctp_ulpq *ulpq); /* 1st Level Abstractions */ /* Initialize a ULP queue from a block of memory. */ -struct sctp_ulpq *sctp_ulpq_init(struct sctp_ulpq *ulpq, - struct sctp_association *asoc) +void sctp_ulpq_init(struct sctp_ulpq *ulpq, struct sctp_association *asoc) { memset(ulpq, 0, sizeof(struct sctp_ulpq)); @@ -48,8 +47,6 @@ struct sctp_ulpq *sctp_ulpq_init(struct sctp_ulpq *ulpq, skb_queue_head_init(&ulpq->reasm_uo); skb_queue_head_init(&ulpq->lobby); ulpq->pd_mode = 0; - - return ulpq; } @@ -259,10 +256,7 @@ int sctp_ulpq_tail_event(struct sctp_ulpq *ulpq, struct sk_buff_head *skb_list) return 1; out_free: - if (skb_list) - sctp_queue_purge_ulpevents(skb_list); - else - sctp_ulpevent_free(event); + sctp_queue_purge_ulpevents(skb_list); return 0; } diff --git a/net/smc/smc_clc.c b/net/smc/smc_clc.c index 1472f31480d8..dfb9797f7bc6 100644 --- a/net/smc/smc_clc.c +++ b/net/smc/smc_clc.c @@ -673,7 +673,7 @@ int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, */ krflags = MSG_PEEK | MSG_WAITALL; clc_sk->sk_rcvtimeo = timeout; - iov_iter_kvec(&msg.msg_iter, READ, &vec, 1, + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, sizeof(struct smc_clc_msg_hdr)); len = sock_recvmsg(smc->clcsock, &msg, krflags); if (signal_pending(current)) { @@ -720,7 +720,7 @@ int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, } else { recvlen = datlen; } - iov_iter_kvec(&msg.msg_iter, READ, &vec, 1, recvlen); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, recvlen); krflags = MSG_WAITALL; len = sock_recvmsg(smc->clcsock, &msg, krflags); if (len < recvlen || !smc_clc_msg_hdr_valid(clcm, check_trl)) { @@ -737,7 +737,7 @@ int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, /* receive remaining proposal message */ recvlen = datlen > SMC_CLC_RECV_BUF_LEN ? SMC_CLC_RECV_BUF_LEN : datlen; - iov_iter_kvec(&msg.msg_iter, READ, &vec, 1, recvlen); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, recvlen); len = sock_recvmsg(smc->clcsock, &msg, krflags); datlen -= len; } diff --git a/net/smc/smc_tx.c b/net/smc/smc_tx.c index 64dedffe9d26..f4b6a71ac488 100644 --- a/net/smc/smc_tx.c +++ b/net/smc/smc_tx.c @@ -308,7 +308,7 @@ int smc_tx_sendpage(struct smc_sock *smc, struct page *page, int offset, iov.iov_base = kaddr + offset; iov.iov_len = size; - iov_iter_kvec(&msg.msg_iter, WRITE, &iov, 1, size); + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, size); rc = smc_tx_sendmsg(smc, &msg, size); kunmap(page); return rc; diff --git a/net/socket.c b/net/socket.c index 00da9ce3dba0..888cd618a968 100644 --- a/net/socket.c +++ b/net/socket.c @@ -750,7 +750,7 @@ EXPORT_SYMBOL(sock_sendmsg); int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t num, size_t size) { - iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); return sock_sendmsg(sock, msg); } EXPORT_SYMBOL(kernel_sendmsg); @@ -776,7 +776,7 @@ int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg, if (!sock->ops->sendmsg_locked) return sock_no_sendmsg_locked(sk, msg, size); - iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg)); } @@ -1034,7 +1034,7 @@ int kernel_recvmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t num, size_t size, int flags) { msg->msg_control_is_user = false; - iov_iter_kvec(&msg->msg_iter, READ, vec, num, size); + iov_iter_kvec(&msg->msg_iter, ITER_DEST, vec, num, size); return sock_recvmsg(sock, msg, flags); } EXPORT_SYMBOL(kernel_recvmsg); @@ -2092,7 +2092,7 @@ int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags, struct iovec iov; int fput_needed; - err = import_single_range(WRITE, buff, len, &iov, &msg.msg_iter); + err = import_single_range(ITER_SOURCE, buff, len, &iov, &msg.msg_iter); if (unlikely(err)) return err; sock = sockfd_lookup_light(fd, &err, &fput_needed); @@ -2157,7 +2157,7 @@ int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags, int err, err2; int fput_needed; - err = import_single_range(READ, ubuf, size, &iov, &msg.msg_iter); + err = import_single_range(ITER_DEST, ubuf, size, &iov, &msg.msg_iter); if (unlikely(err)) return err; sock = sockfd_lookup_light(fd, &err, &fput_needed); @@ -2199,13 +2199,7 @@ SYSCALL_DEFINE4(recv, int, fd, void __user *, ubuf, size_t, size, static bool sock_use_custom_sol_socket(const struct socket *sock) { - const struct sock *sk = sock->sk; - - /* Use sock->ops->setsockopt() for MPTCP */ - return IS_ENABLED(CONFIG_MPTCP) && - sk->sk_protocol == IPPROTO_MPTCP && - sk->sk_type == SOCK_STREAM && - (sk->sk_family == AF_INET || sk->sk_family == AF_INET6); + return test_bit(SOCK_CUSTOM_SOCKOPT, &sock->flags); } /* @@ -2417,7 +2411,7 @@ static int copy_msghdr_from_user(struct msghdr *kmsg, if (err) return err; - err = import_iovec(save_addr ? READ : WRITE, + err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV, iov, &kmsg->msg_iter); return err < 0 ? err : 0; diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index 7bb247c51e2f..2d7b1e03110a 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -302,7 +302,7 @@ __gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth list_for_each_entry(pos, &pipe->in_downcall, list) { if (!uid_eq(pos->uid, uid)) continue; - if (auth && pos->auth->service != auth->service) + if (pos->auth->service != auth->service) continue; refcount_inc(&pos->count); return pos; @@ -686,6 +686,21 @@ out: return err; } +static struct gss_upcall_msg * +gss_find_downcall(struct rpc_pipe *pipe, kuid_t uid) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) + continue; + if (!rpc_msg_is_inflight(&pos->msg)) + continue; + refcount_inc(&pos->count); + return pos; + } + return NULL; +} + #define MSG_BUF_MAXSIZE 1024 static ssize_t @@ -732,7 +747,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) err = -ENOENT; /* Find a matching upcall */ spin_lock(&pipe->lock); - gss_msg = __gss_find_upcall(pipe, uid, NULL); + gss_msg = gss_find_downcall(pipe, uid); if (gss_msg == NULL) { spin_unlock(&pipe->lock); goto err_put_ctx; diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index bcd74dddbe2d..acb822b23af1 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -49,11 +49,36 @@ #include <linux/sunrpc/svcauth.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/cache.h> +#include <linux/sunrpc/gss_krb5.h> #include <trace/events/rpcgss.h> #include "gss_rpc_upcall.h" +/* + * Unfortunately there isn't a maximum checksum size exported via the + * GSS API. Manufacture one based on GSS mechanisms supported by this + * implementation. + */ +#define GSS_MAX_CKSUMSIZE (GSS_KRB5_TOK_HDR_LEN + GSS_KRB5_MAX_CKSUM_LEN) + +/* + * This value may be increased in the future to accommodate other + * usage of the scratch buffer. + */ +#define GSS_SCRATCH_SIZE GSS_MAX_CKSUMSIZE + +struct gss_svc_data { + /* decoded gss client cred: */ + struct rpc_gss_wire_cred clcred; + /* save a pointer to the beginning of the encoded verifier, + * for use in encryption/checksumming in svcauth_gss_release: */ + __be32 *verf_start; + struct rsc *rsci; + + /* for temporary results */ + u8 gsd_scratch[GSS_SCRATCH_SIZE]; +}; /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests * into replies. @@ -887,20 +912,18 @@ read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) static int unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) { + struct gss_svc_data *gsd = rqstp->rq_auth_data; u32 integ_len, rseqno, maj_stat; - int stat = -EINVAL; struct xdr_netobj mic; struct xdr_buf integ_buf; - mic.data = NULL; - /* NFS READ normally uses splice to send data in-place. However * the data in cache can change after the reply's MIC is computed * but before the RPC reply is sent. To prevent the client from * rejecting the server-computed MIC in this somewhat rare case, * do not use splice with the GSS integrity service. */ - __clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); /* Did we already verify the signature on the original pass through? */ if (rqstp->rq_deferred) @@ -917,11 +940,9 @@ unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct g /* copy out mic... */ if (read_u32_from_xdr_buf(buf, integ_len, &mic.len)) goto unwrap_failed; - if (mic.len > RPC_MAX_AUTH_SIZE) - goto unwrap_failed; - mic.data = kmalloc(mic.len, GFP_KERNEL); - if (!mic.data) + if (mic.len > sizeof(gsd->gsd_scratch)) goto unwrap_failed; + mic.data = gsd->gsd_scratch; if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len)) goto unwrap_failed; maj_stat = gss_verify_mic(ctx, &integ_buf, &mic); @@ -932,20 +953,17 @@ unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct g goto bad_seqno; /* trim off the mic and padding at the end before returning */ xdr_buf_trim(buf, round_up_to_quad(mic.len) + 4); - stat = 0; -out: - kfree(mic.data); - return stat; + return 0; unwrap_failed: trace_rpcgss_svc_unwrap_failed(rqstp); - goto out; + return -EINVAL; bad_seqno: trace_rpcgss_svc_seqno_bad(rqstp, seq, rseqno); - goto out; + return -EINVAL; bad_mic: trace_rpcgss_svc_mic(rqstp, maj_stat); - goto out; + return -EINVAL; } static inline int @@ -972,7 +990,7 @@ unwrap_priv_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gs int pad, remaining_len, offset; u32 rseqno; - __clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); priv_len = svc_getnl(&buf->head[0]); if (rqstp->rq_deferred) { @@ -1023,15 +1041,6 @@ bad_unwrap: return -EINVAL; } -struct gss_svc_data { - /* decoded gss client cred: */ - struct rpc_gss_wire_cred clcred; - /* save a pointer to the beginning of the encoded verifier, - * for use in encryption/checksumming in svcauth_gss_release: */ - __be32 *verf_start; - struct rsc *rsci; -}; - static int svcauth_gss_set_client(struct svc_rqst *rqstp) { @@ -1162,18 +1171,23 @@ static int gss_read_proxy_verf(struct svc_rqst *rqstp, return res; inlen = svc_getnl(argv); - if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) + if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) { + kfree(in_handle->data); return SVC_DENIED; + } pages = DIV_ROUND_UP(inlen, PAGE_SIZE); in_token->pages = kcalloc(pages, sizeof(struct page *), GFP_KERNEL); - if (!in_token->pages) + if (!in_token->pages) { + kfree(in_handle->data); return SVC_DENIED; + } in_token->page_base = 0; in_token->page_len = inlen; for (i = 0; i < pages; i++) { in_token->pages[i] = alloc_page(GFP_KERNEL); if (!in_token->pages[i]) { + kfree(in_handle->data); gss_free_in_token_pages(in_token); return SVC_DENIED; } diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index f075a9fb5ccc..95ff74706104 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -677,7 +677,7 @@ static void cache_limit_defers(void) /* Consider removing either the first or the last */ if (cache_defer_cnt > DFR_MAX) { - if (prandom_u32_max(2)) + if (get_random_u32_below(2)) discard = list_entry(cache_defer_list.next, struct cache_deferred_req, recent); else diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 993acf38af87..0b0b9f1eed46 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -1442,7 +1442,7 @@ static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, break; default: err = -EAFNOSUPPORT; - goto out; + goto out_release; } if (err < 0) { dprintk("RPC: can't bind UDP socket (%d)\n", err); diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c index 71ba4cf513bc..1b2b84feeec6 100644 --- a/net/sunrpc/socklib.c +++ b/net/sunrpc/socklib.c @@ -214,14 +214,14 @@ static inline int xprt_sendmsg(struct socket *sock, struct msghdr *msg, static int xprt_send_kvec(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t seek) { - iov_iter_kvec(&msg->msg_iter, WRITE, vec, 1, vec->iov_len); + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, 1, vec->iov_len); return xprt_sendmsg(sock, msg, seek); } static int xprt_send_pagedata(struct socket *sock, struct msghdr *msg, struct xdr_buf *xdr, size_t base) { - iov_iter_bvec(&msg->msg_iter, WRITE, xdr->bvec, xdr_buf_pagecount(xdr), + iov_iter_bvec(&msg->msg_iter, ITER_SOURCE, xdr->bvec, xdr_buf_pagecount(xdr), xdr->page_len + xdr->page_base); return xprt_sendmsg(sock, msg, base + xdr->page_base); } @@ -244,7 +244,7 @@ static int xprt_send_rm_and_kvec(struct socket *sock, struct msghdr *msg, }; size_t len = iov[0].iov_len + iov[1].iov_len; - iov_iter_kvec(&msg->msg_iter, WRITE, iov, 2, len); + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, iov, 2, len); return xprt_sendmsg(sock, msg, base); } diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index 149171774bc6..f06622814a95 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -567,7 +567,7 @@ svc_destroy(struct kref *ref) struct svc_serv *serv = container_of(ref, struct svc_serv, sv_refcnt); dprintk("svc: svc_destroy(%s)\n", serv->sv_program->pg_name); - del_timer_sync(&serv->sv_temptimer); + timer_shutdown_sync(&serv->sv_temptimer); /* * The last user is gone and thus all sockets have to be destroyed to @@ -638,7 +638,6 @@ svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) return rqstp; __set_bit(RQ_BUSY, &rqstp->rq_flags); - spin_lock_init(&rqstp->rq_lock); rqstp->rq_server = serv; rqstp->rq_pool = pool; @@ -1244,10 +1243,10 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) goto err_short_len; /* Will be turned off by GSS integrity and privacy services */ - __set_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + set_bit(RQ_SPLICE_OK, &rqstp->rq_flags); /* Will be turned off only when NFSv4 Sessions are used */ - __set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); - __clear_bit(RQ_DROPME, &rqstp->rq_flags); + set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); + clear_bit(RQ_DROPME, &rqstp->rq_flags); svc_putu32(resv, rqstp->rq_xid); @@ -1281,8 +1280,7 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) /* Also give the program a chance to reject this call: */ if (auth_res == SVC_OK && progp) auth_res = progp->pg_authenticate(rqstp); - if (auth_res != SVC_OK) - trace_svc_authenticate(rqstp, auth_res); + trace_svc_authenticate(rqstp, auth_res); switch (auth_res) { case SVC_OK: break; diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index 2106003645a7..c2ce12538008 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -1238,7 +1238,7 @@ static struct cache_deferred_req *svc_defer(struct cache_req *req) trace_svc_defer(rqstp); svc_xprt_get(rqstp->rq_xprt); dr->xprt = rqstp->rq_xprt; - __set_bit(RQ_DROPME, &rqstp->rq_flags); + set_bit(RQ_DROPME, &rqstp->rq_flags); dr->handle.revisit = svc_revisit; return &dr->handle; diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 2fc98fea59b4..815baf308236 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -260,7 +260,7 @@ static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen, rqstp->rq_respages = &rqstp->rq_pages[i]; rqstp->rq_next_page = rqstp->rq_respages + 1; - iov_iter_bvec(&msg.msg_iter, READ, bvec, i, buflen); + iov_iter_bvec(&msg.msg_iter, ITER_DEST, bvec, i, buflen); if (seek) { iov_iter_advance(&msg.msg_iter, seek); buflen -= seek; @@ -298,9 +298,9 @@ static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs) static void svc_sock_secure_port(struct svc_rqst *rqstp) { if (svc_port_is_privileged(svc_addr(rqstp))) - __set_bit(RQ_SECURE, &rqstp->rq_flags); + set_bit(RQ_SECURE, &rqstp->rq_flags); else - __clear_bit(RQ_SECURE, &rqstp->rq_flags); + clear_bit(RQ_SECURE, &rqstp->rq_flags); } /* @@ -874,7 +874,7 @@ static ssize_t svc_tcp_read_marker(struct svc_sock *svsk, want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen; iov.iov_len = want; - iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, want); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want); len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT); if (len < 0) return len; @@ -1008,9 +1008,9 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) rqstp->rq_xprt_ctxt = NULL; rqstp->rq_prot = IPPROTO_TCP; if (test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags)) - __set_bit(RQ_LOCAL, &rqstp->rq_flags); + set_bit(RQ_LOCAL, &rqstp->rq_flags); else - __clear_bit(RQ_LOCAL, &rqstp->rq_flags); + clear_bit(RQ_LOCAL, &rqstp->rq_flags); p = (__be32 *)rqstp->rq_arg.head[0].iov_base; calldir = p[1]; diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c index c1f559892ae8..1e05a2d723f4 100644 --- a/net/sunrpc/sysfs.c +++ b/net/sunrpc/sysfs.c @@ -31,7 +31,7 @@ static void rpc_sysfs_object_release(struct kobject *kobj) } static const struct kobj_ns_type_operations * -rpc_sysfs_object_child_ns_type(struct kobject *kobj) +rpc_sysfs_object_child_ns_type(const struct kobject *kobj) { return &net_ns_type_operations; } @@ -381,17 +381,17 @@ static void rpc_sysfs_xprt_release(struct kobject *kobj) kfree(xprt); } -static const void *rpc_sysfs_client_namespace(struct kobject *kobj) +static const void *rpc_sysfs_client_namespace(const struct kobject *kobj) { return container_of(kobj, struct rpc_sysfs_client, kobject)->net; } -static const void *rpc_sysfs_xprt_switch_namespace(struct kobject *kobj) +static const void *rpc_sysfs_xprt_switch_namespace(const struct kobject *kobj) { return container_of(kobj, struct rpc_sysfs_xprt_switch, kobject)->net; } -static const void *rpc_sysfs_xprt_namespace(struct kobject *kobj) +static const void *rpc_sysfs_xprt_namespace(const struct kobject *kobj) { return container_of(kobj, struct rpc_sysfs_xprt, kobject)->xprt->xprt_net; diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index 336a7c7833e4..f7767bf22406 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -1224,30 +1224,34 @@ EXPORT_SYMBOL(xdr_restrict_buflen); /** * xdr_write_pages - Insert a list of pages into an XDR buffer for sending * @xdr: pointer to xdr_stream - * @pages: list of pages - * @base: offset of first byte - * @len: length of data in bytes + * @pages: array of pages to insert + * @base: starting offset of first data byte in @pages + * @len: number of data bytes in @pages to insert * + * After the @pages are added, the tail iovec is instantiated pointing to + * end of the head buffer, and the stream is set up to encode subsequent + * items into the tail. */ void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base, unsigned int len) { struct xdr_buf *buf = xdr->buf; - struct kvec *iov = buf->tail; + struct kvec *tail = buf->tail; + buf->pages = pages; buf->page_base = base; buf->page_len = len; - iov->iov_base = (char *)xdr->p; - iov->iov_len = 0; - xdr->iov = iov; + tail->iov_base = xdr->p; + tail->iov_len = 0; + xdr->iov = tail; if (len & 3) { unsigned int pad = 4 - (len & 3); BUG_ON(xdr->p >= xdr->end); - iov->iov_base = (char *)xdr->p + (len & 3); - iov->iov_len += pad; + tail->iov_base = (char *)xdr->p + (len & 3); + tail->iov_len += pad; len += pad; *xdr->p++ = 0; } diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index 656cec208371..ab453ede54f0 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -1164,7 +1164,7 @@ xprt_request_enqueue_receive(struct rpc_task *task) spin_unlock(&xprt->queue_lock); /* Turn off autodisconnect */ - del_singleshot_timer_sync(&xprt->timer); + del_timer_sync(&xprt->timer); return 0; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 199fa012f18a..94b20fb47135 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -602,7 +602,7 @@ static int svc_rdma_has_wspace(struct svc_xprt *xprt) static void svc_rdma_secure_port(struct svc_rqst *rqstp) { - __set_bit(RQ_SECURE, &rqstp->rq_flags); + set_bit(RQ_SECURE, &rqstp->rq_flags); } static void svc_rdma_kill_temp_xprt(struct svc_xprt *xprt) diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 44b87e4274b4..b098fde373ab 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -831,7 +831,7 @@ struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt, return req; out3: - kfree(req->rl_sendbuf); + rpcrdma_regbuf_free(req->rl_sendbuf); out2: kfree(req); out1: diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 915b9902f673..aaa5b2741b79 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -364,7 +364,7 @@ static ssize_t xs_read_kvec(struct socket *sock, struct msghdr *msg, int flags, struct kvec *kvec, size_t count, size_t seek) { - iov_iter_kvec(&msg->msg_iter, READ, kvec, 1, count); + iov_iter_kvec(&msg->msg_iter, ITER_DEST, kvec, 1, count); return xs_sock_recvmsg(sock, msg, flags, seek); } @@ -373,7 +373,7 @@ xs_read_bvec(struct socket *sock, struct msghdr *msg, int flags, struct bio_vec *bvec, unsigned long nr, size_t count, size_t seek) { - iov_iter_bvec(&msg->msg_iter, READ, bvec, nr, count); + iov_iter_bvec(&msg->msg_iter, ITER_DEST, bvec, nr, count); return xs_sock_recvmsg(sock, msg, flags, seek); } @@ -381,7 +381,7 @@ static ssize_t xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, size_t count) { - iov_iter_discard(&msg->msg_iter, READ, count); + iov_iter_discard(&msg->msg_iter, ITER_DEST, count); return sock_recvmsg(sock, msg, flags); } @@ -1619,7 +1619,7 @@ static int xs_get_random_port(void) if (max < min) return -EADDRINUSE; range = max - min + 1; - rand = prandom_u32_max(range); + rand = get_random_u32_below(range); return rand + min; } @@ -1882,6 +1882,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, sk->sk_write_space = xs_udp_write_space; sk->sk_state_change = xs_local_state_change; sk->sk_error_report = xs_error_report; + sk->sk_use_task_frag = false; xprt_clear_connected(xprt); @@ -2082,6 +2083,7 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) sk->sk_user_data = xprt; sk->sk_data_ready = xs_data_ready; sk->sk_write_space = xs_udp_write_space; + sk->sk_use_task_frag = false; xprt_set_connected(xprt); @@ -2249,6 +2251,7 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) sk->sk_state_change = xs_tcp_state_change; sk->sk_write_space = xs_tcp_write_space; sk->sk_error_report = xs_error_report; + sk->sk_use_task_frag = false; /* socket options */ sock_reset_flag(sk, SOCK_LINGER); diff --git a/net/tipc/crypto.c b/net/tipc/crypto.c index f09316a9035f..d67440de011e 100644 --- a/net/tipc/crypto.c +++ b/net/tipc/crypto.c @@ -1971,6 +1971,9 @@ rcv: /* Ok, everything's fine, try to synch own keys according to peers' */ tipc_crypto_key_synch(rx, *skb); + /* Re-fetch skb cb as skb might be changed in tipc_msg_validate */ + skb_cb = TIPC_SKB_CB(*skb); + /* Mark skb decrypted */ skb_cb->decrypted = 1; diff --git a/net/tipc/discover.c b/net/tipc/discover.c index e8630707901e..685389d4b245 100644 --- a/net/tipc/discover.c +++ b/net/tipc/discover.c @@ -211,7 +211,10 @@ void tipc_disc_rcv(struct net *net, struct sk_buff *skb, u32 self; int err; - skb_linearize(skb); + if (skb_linearize(skb)) { + kfree_skb(skb); + return; + } hdr = buf_msg(skb); if (caps & TIPC_NODE_ID128) @@ -385,7 +388,7 @@ int tipc_disc_create(struct net *net, struct tipc_bearer *b, */ void tipc_disc_delete(struct tipc_discoverer *d) { - del_timer_sync(&d->timer); + timer_shutdown_sync(&d->timer); kfree_skb(d->skb); kfree(d); } diff --git a/net/tipc/link.c b/net/tipc/link.c index e260c0d557f5..b3ce24823f50 100644 --- a/net/tipc/link.c +++ b/net/tipc/link.c @@ -2224,7 +2224,9 @@ static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, if (tipc_own_addr(l->net) > msg_prevnode(hdr)) l->net_plane = msg_net_plane(hdr); - skb_linearize(skb); + if (skb_linearize(skb)) + goto exit; + hdr = buf_msg(skb); data = msg_data(hdr); diff --git a/net/tipc/monitor.c b/net/tipc/monitor.c index 9618e4429f0f..77a3d016cade 100644 --- a/net/tipc/monitor.c +++ b/net/tipc/monitor.c @@ -700,7 +700,7 @@ void tipc_mon_delete(struct net *net, int bearer_id) } mon->self = NULL; write_unlock_bh(&mon->lock); - del_timer_sync(&mon->timer); + timer_shutdown_sync(&mon->timer); kfree(self->domain); kfree(self); kfree(mon); diff --git a/net/tipc/node.c b/net/tipc/node.c index b48d97cbbe29..5e000fde8067 100644 --- a/net/tipc/node.c +++ b/net/tipc/node.c @@ -1179,8 +1179,9 @@ void tipc_node_check_dest(struct net *net, u32 addr, bool addr_match = false; bool sign_match = false; bool link_up = false; + bool link_is_reset = false; bool accept_addr = false; - bool reset = true; + bool reset = false; char *if_name; unsigned long intv; u16 session; @@ -1200,14 +1201,14 @@ void tipc_node_check_dest(struct net *net, u32 addr, /* Prepare to validate requesting node's signature and media address */ l = le->link; link_up = l && tipc_link_is_up(l); + link_is_reset = l && tipc_link_is_reset(l); addr_match = l && !memcmp(&le->maddr, maddr, sizeof(*maddr)); sign_match = (signature == n->signature); /* These three flags give us eight permutations: */ if (sign_match && addr_match && link_up) { - /* All is fine. Do nothing. */ - reset = false; + /* All is fine. Ignore requests. */ /* Peer node is not a container/local namespace */ if (!n->peer_hash_mix) n->peer_hash_mix = hash_mixes; @@ -1232,6 +1233,7 @@ void tipc_node_check_dest(struct net *net, u32 addr, */ accept_addr = true; *respond = true; + reset = true; } else if (!sign_match && addr_match && link_up) { /* Peer node rebooted. Two possibilities: * - Delayed re-discovery; this link endpoint has already @@ -1263,6 +1265,7 @@ void tipc_node_check_dest(struct net *net, u32 addr, n->signature = signature; accept_addr = true; *respond = true; + reset = true; } if (!accept_addr) @@ -1291,6 +1294,7 @@ void tipc_node_check_dest(struct net *net, u32 addr, tipc_link_fsm_evt(l, LINK_RESET_EVT); if (n->state == NODE_FAILINGOVER) tipc_link_fsm_evt(l, LINK_FAILOVER_BEGIN_EVT); + link_is_reset = tipc_link_is_reset(l); le->link = l; n->link_cnt++; tipc_node_calculate_timer(n, l); @@ -1303,7 +1307,7 @@ void tipc_node_check_dest(struct net *net, u32 addr, memcpy(&le->maddr, maddr, sizeof(*maddr)); exit: tipc_node_write_unlock(n); - if (reset && l && !tipc_link_is_reset(l)) + if (reset && !link_is_reset) tipc_node_link_down(n, b->identity, false); tipc_node_put(n); } @@ -1689,6 +1693,7 @@ int tipc_node_xmit(struct net *net, struct sk_buff_head *list, struct tipc_node *n; struct sk_buff_head xmitq; bool node_up = false; + struct net *peer_net; int bearer_id; int rc; @@ -1705,18 +1710,23 @@ int tipc_node_xmit(struct net *net, struct sk_buff_head *list, return -EHOSTUNREACH; } + rcu_read_lock(); tipc_node_read_lock(n); node_up = node_is_up(n); - if (node_up && n->peer_net && check_net(n->peer_net)) { + peer_net = n->peer_net; + tipc_node_read_unlock(n); + if (node_up && peer_net && check_net(peer_net)) { /* xmit inner linux container */ - tipc_lxc_xmit(n->peer_net, list); + tipc_lxc_xmit(peer_net, list); if (likely(skb_queue_empty(list))) { - tipc_node_read_unlock(n); + rcu_read_unlock(); tipc_node_put(n); return 0; } } + rcu_read_unlock(); + tipc_node_read_lock(n); bearer_id = n->active_links[selector & 1]; if (unlikely(bearer_id == INVALID_BEARER_ID)) { tipc_node_read_unlock(n); diff --git a/net/tipc/socket.c b/net/tipc/socket.c index e902b01ea3cb..b35c8701876a 100644 --- a/net/tipc/socket.c +++ b/net/tipc/socket.c @@ -3010,7 +3010,7 @@ static int tipc_sk_insert(struct tipc_sock *tsk) struct net *net = sock_net(sk); struct tipc_net *tn = net_generic(net, tipc_net_id); u32 remaining = (TIPC_MAX_PORT - TIPC_MIN_PORT) + 1; - u32 portid = prandom_u32_max(remaining) + TIPC_MIN_PORT; + u32 portid = get_random_u32_below(remaining) + TIPC_MIN_PORT; while (remaining--) { portid++; diff --git a/net/tipc/topsrv.c b/net/tipc/topsrv.c index d92ec92f0b71..69c88cc03887 100644 --- a/net/tipc/topsrv.c +++ b/net/tipc/topsrv.c @@ -176,7 +176,7 @@ static void tipc_conn_close(struct tipc_conn *con) conn_put(con); } -static struct tipc_conn *tipc_conn_alloc(struct tipc_topsrv *s) +static struct tipc_conn *tipc_conn_alloc(struct tipc_topsrv *s, struct socket *sock) { struct tipc_conn *con; int ret; @@ -202,10 +202,12 @@ static struct tipc_conn *tipc_conn_alloc(struct tipc_topsrv *s) } con->conid = ret; s->idr_in_use++; - spin_unlock_bh(&s->idr_lock); set_bit(CF_CONNECTED, &con->flags); con->server = s; + con->sock = sock; + conn_get(con); + spin_unlock_bh(&s->idr_lock); return con; } @@ -394,7 +396,7 @@ static int tipc_conn_rcv_from_sock(struct tipc_conn *con) iov.iov_base = &s; iov.iov_len = sizeof(s); msg.msg_name = NULL; - iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, iov.iov_len); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, iov.iov_len); ret = sock_recvmsg(con->sock, &msg, MSG_DONTWAIT); if (ret == -EWOULDBLOCK) return -EWOULDBLOCK; @@ -467,7 +469,7 @@ static void tipc_topsrv_accept(struct work_struct *work) ret = kernel_accept(lsock, &newsock, O_NONBLOCK); if (ret < 0) return; - con = tipc_conn_alloc(srv); + con = tipc_conn_alloc(srv, newsock); if (IS_ERR(con)) { ret = PTR_ERR(con); sock_release(newsock); @@ -479,11 +481,11 @@ static void tipc_topsrv_accept(struct work_struct *work) newsk->sk_data_ready = tipc_conn_data_ready; newsk->sk_write_space = tipc_conn_write_space; newsk->sk_user_data = con; - con->sock = newsock; write_unlock_bh(&newsk->sk_callback_lock); /* Wake up receive process in case of 'SYN+' message */ newsk->sk_data_ready(newsk); + conn_put(con); } } @@ -577,17 +579,17 @@ bool tipc_topsrv_kern_subscr(struct net *net, u32 port, u32 type, u32 lower, sub.filter = filter; *(u64 *)&sub.usr_handle = (u64)port; - con = tipc_conn_alloc(tipc_topsrv(net)); + con = tipc_conn_alloc(tipc_topsrv(net), NULL); if (IS_ERR(con)) return false; *conid = con->conid; - con->sock = NULL; rc = tipc_conn_rcv_sub(tipc_topsrv(net), con, &sub); - if (rc >= 0) - return true; + if (rc) + conn_put(con); + conn_put(con); - return false; + return !rc; } void tipc_topsrv_kern_unsubscr(struct net *net, int conid) diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index a03d66046ca3..6c593788dc25 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -620,7 +620,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page, kaddr = kmap(page); iov.iov_base = kaddr + offset; iov.iov_len = size; - iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size); + iov_iter_kvec(&msg_iter, ITER_SOURCE, &iov, 1, size); iter_offset.msg_iter = &msg_iter; rc = tls_push_data(sk, iter_offset, size, flags, TLS_RECORD_TYPE_DATA, NULL); @@ -697,7 +697,7 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) union tls_iter_offset iter; struct iov_iter msg_iter; - iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0); + iov_iter_kvec(&msg_iter, ITER_SOURCE, NULL, 0, 0); iter.msg_iter = &msg_iter; return tls_push_data(sk, iter, 0, flags, TLS_RECORD_TYPE_DATA, NULL); } diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index cdb391a8754b..7fbb1d0b69b3 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -346,7 +346,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, salt = tls_ctx->crypto_send.aes_gcm_256.salt; break; default: - return NULL; + goto free_req; } cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type]; buf_len = cipher_sz->salt + cipher_sz->iv + TLS_AAD_SPACE_SIZE + @@ -492,7 +492,8 @@ int tls_sw_fallback_init(struct sock *sk, key = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->key; break; default: - return -EINVAL; + rc = -EINVAL; + goto free_aead; } cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type]; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 264cf367e265..a83d2b4275fa 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -792,7 +792,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, struct sk_psock *psock; struct sock *sk_redir; struct tls_rec *rec; - bool enospc, policy; + bool enospc, policy, redir_ingress; int err = 0, send; u32 delta = 0; @@ -837,6 +837,7 @@ more_data: } break; case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; sk_redir = psock->sk_redir; memcpy(&msg_redir, msg, sizeof(*msg)); if (msg->apply_bytes < send) @@ -846,7 +847,8 @@ more_data: sk_msg_return_zero(sk, msg, send); msg->sg.size -= send; release_sock(sk); - err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags); + err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + &msg_redir, send, flags); lock_sock(sk); if (err < 0) { *copied -= sk_msg_free_nocharge(sk, &msg_redir); @@ -2425,7 +2427,7 @@ static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) { struct tls_rec *rec; - rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); + rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list); if (!rec) return false; diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index b3545fc68097..f0c2293f1d3b 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1999,13 +1999,20 @@ restart_locked: unix_state_lock(sk); err = 0; - if (unix_peer(sk) == other) { + if (sk->sk_type == SOCK_SEQPACKET) { + /* We are here only when racing with unix_release_sock() + * is clearing @other. Never change state to TCP_CLOSE + * unlike SOCK_DGRAM wants. + */ + unix_state_unlock(sk); + err = -EPIPE; + } else if (unix_peer(sk) == other) { unix_peer(sk) = NULL; unix_dgram_peer_wake_disconnect_wakeup(sk, other); + sk->sk_state = TCP_CLOSE; unix_state_unlock(sk); - sk->sk_state = TCP_CLOSE; unix_dgram_disconnected(sk, other); sock_put(other); err = -ECONNREFUSED; @@ -3738,6 +3745,7 @@ static int __init af_unix_init(void) rc = proto_register(&unix_stream_proto, 1); if (rc != 0) { pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); + proto_unregister(&unix_dgram_proto); goto out; } diff --git a/net/unix/diag.c b/net/unix/diag.c index 105f522a89fe..616b55c5b890 100644 --- a/net/unix/diag.c +++ b/net/unix/diag.c @@ -114,14 +114,16 @@ static int sk_diag_show_rqlen(struct sock *sk, struct sk_buff *nlskb) return nla_put(nlskb, UNIX_DIAG_RQLEN, sizeof(rql), &rql); } -static int sk_diag_dump_uid(struct sock *sk, struct sk_buff *nlskb) +static int sk_diag_dump_uid(struct sock *sk, struct sk_buff *nlskb, + struct user_namespace *user_ns) { - uid_t uid = from_kuid_munged(sk_user_ns(nlskb->sk), sock_i_uid(sk)); + uid_t uid = from_kuid_munged(user_ns, sock_i_uid(sk)); return nla_put(nlskb, UNIX_DIAG_UID, sizeof(uid_t), &uid); } static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, - u32 portid, u32 seq, u32 flags, int sk_ino) + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) { struct nlmsghdr *nlh; struct unix_diag_msg *rep; @@ -167,7 +169,7 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_r goto out_nlmsg_trim; if ((req->udiag_show & UDIAG_SHOW_UID) && - sk_diag_dump_uid(sk, skb)) + sk_diag_dump_uid(sk, skb, user_ns)) goto out_nlmsg_trim; nlmsg_end(skb, nlh); @@ -179,7 +181,8 @@ out_nlmsg_trim: } static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, - u32 portid, u32 seq, u32 flags) + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags) { int sk_ino; @@ -190,7 +193,7 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_r if (!sk_ino) return 0; - return sk_diag_fill(sk, skb, req, portid, seq, flags, sk_ino); + return sk_diag_fill(sk, skb, req, user_ns, portid, seq, flags, sk_ino); } static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) @@ -214,7 +217,7 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) goto next; if (!(req->udiag_states & (1 << sk->sk_state))) goto next; - if (sk_diag_dump(sk, skb, req, + if (sk_diag_dump(sk, skb, req, sk_user_ns(skb->sk), NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI) < 0) { @@ -282,7 +285,8 @@ again: if (!rep) goto out; - err = sk_diag_fill(sk, rep, req, NETLINK_CB(in_skb).portid, + err = sk_diag_fill(sk, rep, req, sk_user_ns(NETLINK_CB(in_skb).sk), + NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, 0, req->udiag_ino); if (err < 0) { nlmsg_free(rep); diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 884eca7f6743..d593d5b6d4b1 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -626,8 +626,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk, struct sockaddr_vm new_addr; if (!port) - port = LAST_RESERVED_PORT + 1 + - prandom_u32_max(U32_MAX - LAST_RESERVED_PORT); + port = get_random_u32_above(LAST_RESERVED_PORT); vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 842c94286d31..36eb16a40745 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -1711,7 +1711,11 @@ static int vmci_transport_dgram_enqueue( if (!dg) return -ENOMEM; - memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + err = memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + if (err) { + kfree(dg); + return err; + } dg->dst = vmci_make_handle(remote_addr->svm_cid, remote_addr->svm_port); diff --git a/net/wireless/core.h b/net/wireless/core.h index 775e16cb99ed..af85d8909935 100644 --- a/net/wireless/core.h +++ b/net/wireless/core.h @@ -271,6 +271,8 @@ struct cfg80211_event { } ij; struct { u8 bssid[ETH_ALEN]; + const u8 *td_bitmap; + u8 td_bitmap_len; } pa; }; }; @@ -409,7 +411,8 @@ int cfg80211_disconnect(struct cfg80211_registered_device *rdev, bool wextev); void __cfg80211_roamed(struct wireless_dev *wdev, struct cfg80211_roam_info *info); -void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid); +void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid, + const u8 *td_bitmap, u8 td_bitmap_len); int cfg80211_mgd_wext_connect(struct cfg80211_registered_device *rdev, struct wireless_dev *wdev); void cfg80211_autodisconnect_wk(struct work_struct *work); diff --git a/net/wireless/mlme.c b/net/wireless/mlme.c index 581df7f4c524..58e1fb18f85a 100644 --- a/net/wireless/mlme.c +++ b/net/wireless/mlme.c @@ -42,6 +42,10 @@ void cfg80211_rx_assoc_resp(struct net_device *dev, unsigned int link_id; for (link_id = 0; link_id < ARRAY_SIZE(data->links); link_id++) { + cr.links[link_id].status = data->links[link_id].status; + WARN_ON_ONCE(cr.links[link_id].status != WLAN_STATUS_SUCCESS && + (!cr.ap_mld_addr || !cr.links[link_id].bss)); + cr.links[link_id].bss = data->links[link_id].bss; if (!cr.links[link_id].bss) continue; diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c index 597c52236514..33a82ecab9d5 100644 --- a/net/wireless/nl80211.c +++ b/net/wireless/nl80211.c @@ -3868,6 +3868,9 @@ static int nl80211_send_iface(struct sk_buff *msg, u32 portid, u32 seq, int flag struct cfg80211_chan_def chandef = {}; int ret; + if (!link) + goto nla_put_failure; + if (nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, link_id)) goto nla_put_failure; if (nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, @@ -7780,6 +7783,7 @@ static int nl80211_set_bss(struct sk_buff *skb, struct genl_info *info) int err; memset(¶ms, 0, sizeof(params)); + params.link_id = nl80211_link_id_or_invalid(info->attrs); /* default to not changing parameters */ params.use_cts_prot = -1; params.use_short_preamble = -1; @@ -16139,7 +16143,8 @@ static u32 nl80211_internal_flags[] = { #undef SELECTOR }; -static int nl80211_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, +static int nl80211_pre_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { struct cfg80211_registered_device *rdev = NULL; @@ -16240,7 +16245,8 @@ out_unlock: return err; } -static void nl80211_post_doit(const struct genl_ops *ops, struct sk_buff *skb, +static void nl80211_post_doit(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) { u32 internal_flags = nl80211_internal_flags[ops->internal_flags]; @@ -16566,7 +16572,8 @@ static const struct genl_small_ops nl80211_small_ops[] = { .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, .doit = nl80211_set_bss, .flags = GENL_UNS_ADMIN_PERM, - .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), }, { .cmd = NL80211_CMD_GET_REG, @@ -17747,6 +17754,7 @@ void nl80211_send_connect_result(struct cfg80211_registered_device *rdev, link_info_size += (cr->links[link].bssid || cr->links[link].bss) ? nla_total_size(ETH_ALEN) : 0; + link_info_size += nla_total_size(sizeof(u16)); } } @@ -17815,7 +17823,9 @@ void nl80211_send_connect_result(struct cfg80211_registered_device *rdev, nla_put(msg, NL80211_ATTR_BSSID, ETH_ALEN, bssid)) || (cr->links[link].addr && nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, - cr->links[link].addr))) + cr->links[link].addr)) || + nla_put_u16(msg, NL80211_ATTR_STATUS_CODE, + cr->links[link].status)) goto nla_put_failure; nla_nest_end(msg, nested_mlo_links); @@ -17939,7 +17949,8 @@ void nl80211_send_roamed(struct cfg80211_registered_device *rdev, } void nl80211_send_port_authorized(struct cfg80211_registered_device *rdev, - struct net_device *netdev, const u8 *bssid) + struct net_device *netdev, const u8 *bssid, + const u8 *td_bitmap, u8 td_bitmap_len) { struct sk_buff *msg; void *hdr; @@ -17959,6 +17970,11 @@ void nl80211_send_port_authorized(struct cfg80211_registered_device *rdev, nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, bssid)) goto nla_put_failure; + if ((td_bitmap_len > 0) && td_bitmap) + if (nla_put(msg, NL80211_ATTR_TD_BITMAP, + td_bitmap_len, td_bitmap)) + goto nla_put_failure; + genlmsg_end(msg, hdr); genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, diff --git a/net/wireless/nl80211.h b/net/wireless/nl80211.h index 855d540ddfb9..ba9457e94c43 100644 --- a/net/wireless/nl80211.h +++ b/net/wireless/nl80211.h @@ -83,7 +83,8 @@ void nl80211_send_roamed(struct cfg80211_registered_device *rdev, struct net_device *netdev, struct cfg80211_roam_info *info, gfp_t gfp); void nl80211_send_port_authorized(struct cfg80211_registered_device *rdev, - struct net_device *netdev, const u8 *bssid); + struct net_device *netdev, const u8 *bssid, + const u8 *td_bitmap, u8 td_bitmap_len); void nl80211_send_disconnected(struct cfg80211_registered_device *rdev, struct net_device *netdev, u16 reason, const u8 *ie, size_t ie_len, bool from_ap); diff --git a/net/wireless/reg.c b/net/wireless/reg.c index c3d950d29432..4f3f31244e8b 100644 --- a/net/wireless/reg.c +++ b/net/wireless/reg.c @@ -4311,8 +4311,10 @@ static int __init regulatory_init_db(void) return -EINVAL; err = load_builtin_regdb_keys(); - if (err) + if (err) { + platform_device_unregister(reg_pdev); return err; + } /* We always try to get an update for the static regdomain */ err = regulatory_hint_core(cfg80211_world_regdom->alpha2); diff --git a/net/wireless/scan.c b/net/wireless/scan.c index da752b0cc752..790bc31cf82e 100644 --- a/net/wireless/scan.c +++ b/net/wireless/scan.c @@ -158,9 +158,8 @@ static inline void bss_ref_put(struct cfg80211_registered_device *rdev, if (bss->pub.hidden_beacon_bss) { struct cfg80211_internal_bss *hbss; - hbss = container_of(bss->pub.hidden_beacon_bss, - struct cfg80211_internal_bss, - pub); + + hbss = bss_from_pub(bss->pub.hidden_beacon_bss); hbss->refcount--; if (hbss->refcount == 0) bss_free(hbss); @@ -169,9 +168,7 @@ static inline void bss_ref_put(struct cfg80211_registered_device *rdev, if (bss->pub.transmitted_bss) { struct cfg80211_internal_bss *tbss; - tbss = container_of(bss->pub.transmitted_bss, - struct cfg80211_internal_bss, - pub); + tbss = bss_from_pub(bss->pub.transmitted_bss); tbss->refcount--; if (tbss->refcount == 0) bss_free(tbss); @@ -330,7 +327,8 @@ static size_t cfg80211_gen_new_ie(const u8 *ie, size_t ielen, * determine if they are the same ie. */ if (tmp_old[0] == WLAN_EID_VENDOR_SPECIFIC) { - if (!memcmp(tmp_old + 2, tmp + 2, 5)) { + if (tmp_old[1] >= 5 && tmp[1] >= 5 && + !memcmp(tmp_old + 2, tmp + 2, 5)) { /* same vendor ie, copy from * subelement */ @@ -1289,7 +1287,8 @@ static int cmp_bss(struct cfg80211_bss *a, int i, r; if (a->channel != b->channel) - return b->channel->center_freq - a->channel->center_freq; + return (b->channel->center_freq * 1000 + b->channel->freq_offset) - + (a->channel->center_freq * 1000 + a->channel->freq_offset); a_ies = rcu_access_pointer(a->ies); if (!a_ies) @@ -1790,13 +1789,8 @@ cfg80211_bss_update(struct cfg80211_registered_device *rdev, /* This must be before the call to bss_ref_get */ if (tmp->pub.transmitted_bss) { - struct cfg80211_internal_bss *pbss = - container_of(tmp->pub.transmitted_bss, - struct cfg80211_internal_bss, - pub); - new->pub.transmitted_bss = tmp->pub.transmitted_bss; - bss_ref_get(rdev, pbss); + bss_ref_get(rdev, bss_from_pub(tmp->pub.transmitted_bss)); } list_add_tail(&new->list, &rdev->bss_list); @@ -2526,10 +2520,15 @@ cfg80211_inform_bss_frame_data(struct wiphy *wiphy, const struct cfg80211_bss_ies *ies1, *ies2; size_t ielen = len - offsetof(struct ieee80211_mgmt, u.probe_resp.variable); - struct cfg80211_non_tx_bss non_tx_data; + struct cfg80211_non_tx_bss non_tx_data = {}; res = cfg80211_inform_single_bss_frame_data(wiphy, data, mgmt, len, gfp); + + /* don't do any further MBSSID handling for S1G */ + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) + return res; + if (!res || !wiphy->support_mbssid || !cfg80211_find_elem(WLAN_EID_MULTIPLE_BSSID, ie, ielen)) return res; @@ -2569,15 +2568,12 @@ EXPORT_SYMBOL(cfg80211_inform_bss_frame_data); void cfg80211_ref_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) { struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); - struct cfg80211_internal_bss *bss; if (!pub) return; - bss = container_of(pub, struct cfg80211_internal_bss, pub); - spin_lock_bh(&rdev->bss_lock); - bss_ref_get(rdev, bss); + bss_ref_get(rdev, bss_from_pub(pub)); spin_unlock_bh(&rdev->bss_lock); } EXPORT_SYMBOL(cfg80211_ref_bss); @@ -2585,15 +2581,12 @@ EXPORT_SYMBOL(cfg80211_ref_bss); void cfg80211_put_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) { struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); - struct cfg80211_internal_bss *bss; if (!pub) return; - bss = container_of(pub, struct cfg80211_internal_bss, pub); - spin_lock_bh(&rdev->bss_lock); - bss_ref_put(rdev, bss); + bss_ref_put(rdev, bss_from_pub(pub)); spin_unlock_bh(&rdev->bss_lock); } EXPORT_SYMBOL(cfg80211_put_bss); @@ -2607,7 +2600,7 @@ void cfg80211_unlink_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) if (WARN_ON(!pub)) return; - bss = container_of(pub, struct cfg80211_internal_bss, pub); + bss = bss_from_pub(pub); spin_lock_bh(&rdev->bss_lock); if (list_empty(&bss->list)) @@ -2616,8 +2609,7 @@ void cfg80211_unlink_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) list_for_each_entry_safe(nontrans_bss, tmp, &pub->nontrans_list, nontrans_list) { - tmp1 = container_of(nontrans_bss, - struct cfg80211_internal_bss, pub); + tmp1 = bss_from_pub(nontrans_bss); if (__cfg80211_unlink_bss(rdev, tmp1)) rdev->bss_generation++; } @@ -2674,9 +2666,7 @@ void cfg80211_update_assoc_bss_entry(struct wireless_dev *wdev, /* use transmitting bss */ if (cbss->pub.transmitted_bss) - cbss = container_of(cbss->pub.transmitted_bss, - struct cfg80211_internal_bss, - pub); + cbss = bss_from_pub(cbss->pub.transmitted_bss); cbss->pub.channel = chan; @@ -2705,8 +2695,7 @@ void cfg80211_update_assoc_bss_entry(struct wireless_dev *wdev, list_for_each_entry_safe(nontrans_bss, tmp, &new->pub.nontrans_list, nontrans_list) { - bss = container_of(nontrans_bss, - struct cfg80211_internal_bss, pub); + bss = bss_from_pub(nontrans_bss); if (__cfg80211_unlink_bss(rdev, bss)) rdev->bss_generation++; } @@ -2723,8 +2712,7 @@ void cfg80211_update_assoc_bss_entry(struct wireless_dev *wdev, list_for_each_entry_safe(nontrans_bss, tmp, &cbss->pub.nontrans_list, nontrans_list) { - bss = container_of(nontrans_bss, - struct cfg80211_internal_bss, pub); + bss = bss_from_pub(nontrans_bss); bss->pub.channel = chan; rb_erase(&bss->rbn, &rdev->bss_tree); rb_insert_bss(rdev, bss); @@ -3231,8 +3219,9 @@ static int ieee80211_scan_results(struct cfg80211_registered_device *rdev, int cfg80211_wext_giwscan(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_point *data = &wrqu->data; struct cfg80211_registered_device *rdev; int res; diff --git a/net/wireless/sme.c b/net/wireless/sme.c index d513536617bd..4b5b6ee0fe01 100644 --- a/net/wireless/sme.c +++ b/net/wireless/sme.c @@ -793,6 +793,10 @@ void __cfg80211_connect_result(struct net_device *dev, } for_each_valid_link(cr, link) { + /* don't do extra lookups for failures */ + if (cr->links[link].status != WLAN_STATUS_SUCCESS) + continue; + if (cr->links[link].bss) continue; @@ -829,6 +833,16 @@ void __cfg80211_connect_result(struct net_device *dev, } memset(wdev->links, 0, sizeof(wdev->links)); + for_each_valid_link(cr, link) { + if (cr->links[link].status == WLAN_STATUS_SUCCESS) + continue; + cr->valid_links &= ~BIT(link); + /* don't require bss pointer for failed links */ + if (!cr->links[link].bss) + continue; + cfg80211_unhold_bss(bss_from_pub(cr->links[link].bss)); + cfg80211_put_bss(wdev->wiphy, cr->links[link].bss); + } wdev->valid_links = cr->valid_links; for_each_valid_link(cr, link) wdev->links[link].client.current_bss = @@ -1237,7 +1251,8 @@ out: } EXPORT_SYMBOL(cfg80211_roamed); -void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid) +void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid, + const u8 *td_bitmap, u8 td_bitmap_len) { ASSERT_WDEV_LOCK(wdev); @@ -1250,11 +1265,11 @@ void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid) return; nl80211_send_port_authorized(wiphy_to_rdev(wdev->wiphy), wdev->netdev, - bssid); + bssid, td_bitmap, td_bitmap_len); } void cfg80211_port_authorized(struct net_device *dev, const u8 *bssid, - gfp_t gfp) + const u8 *td_bitmap, u8 td_bitmap_len, gfp_t gfp) { struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); @@ -1264,12 +1279,15 @@ void cfg80211_port_authorized(struct net_device *dev, const u8 *bssid, if (WARN_ON(!bssid)) return; - ev = kzalloc(sizeof(*ev), gfp); + ev = kzalloc(sizeof(*ev) + td_bitmap_len, gfp); if (!ev) return; ev->type = EVENT_PORT_AUTHORIZED; memcpy(ev->pa.bssid, bssid, ETH_ALEN); + ev->pa.td_bitmap = ((u8 *)ev) + sizeof(*ev); + ev->pa.td_bitmap_len = td_bitmap_len; + memcpy((void *)ev->pa.td_bitmap, td_bitmap, td_bitmap_len); /* * Use the wdev event list so that if there are pending diff --git a/net/wireless/sysfs.c b/net/wireless/sysfs.c index 0c3f05c9be27..cdb638647e0b 100644 --- a/net/wireless/sysfs.c +++ b/net/wireless/sysfs.c @@ -148,7 +148,7 @@ static SIMPLE_DEV_PM_OPS(wiphy_pm_ops, wiphy_suspend, wiphy_resume); #define WIPHY_PM_OPS NULL #endif -static const void *wiphy_namespace(struct device *d) +static const void *wiphy_namespace(const struct device *d) { struct wiphy *wiphy = container_of(d, struct wiphy, dev); diff --git a/net/wireless/util.c b/net/wireless/util.c index 39680e7bad45..8f403f9fe816 100644 --- a/net/wireless/util.c +++ b/net/wireless/util.c @@ -990,7 +990,9 @@ void cfg80211_process_wdev_events(struct wireless_dev *wdev) __cfg80211_leave(wiphy_to_rdev(wdev->wiphy), wdev); break; case EVENT_PORT_AUTHORIZED: - __cfg80211_port_authorized(wdev, ev->pa.bssid); + __cfg80211_port_authorized(wdev, ev->pa.bssid, + ev->pa.td_bitmap, + ev->pa.td_bitmap_len); break; } wdev_unlock(wdev); diff --git a/net/wireless/wext-compat.c b/net/wireless/wext-compat.c index ddf340bfa07a..8a24dfca75af 100644 --- a/net/wireless/wext-compat.c +++ b/net/wireless/wext-compat.c @@ -25,16 +25,17 @@ int cfg80211_wext_giwname(struct net_device *dev, struct iw_request_info *info, - char *name, char *extra) + union iwreq_data *wrqu, char *extra) { - strcpy(name, "IEEE 802.11"); + strcpy(wrqu->name, "IEEE 802.11"); return 0; } EXPORT_WEXT_HANDLER(cfg80211_wext_giwname); int cfg80211_wext_siwmode(struct net_device *dev, struct iw_request_info *info, - u32 *mode, char *extra) + union iwreq_data *wrqu, char *extra) { + __u32 *mode = &wrqu->mode; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev; struct vif_params vifparams; @@ -71,8 +72,9 @@ int cfg80211_wext_siwmode(struct net_device *dev, struct iw_request_info *info, EXPORT_WEXT_HANDLER(cfg80211_wext_siwmode); int cfg80211_wext_giwmode(struct net_device *dev, struct iw_request_info *info, - u32 *mode, char *extra) + union iwreq_data *wrqu, char *extra) { + __u32 *mode = &wrqu->mode; struct wireless_dev *wdev = dev->ieee80211_ptr; if (!wdev) @@ -108,8 +110,9 @@ EXPORT_WEXT_HANDLER(cfg80211_wext_giwmode); int cfg80211_wext_giwrange(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_point *data = &wrqu->data; struct wireless_dev *wdev = dev->ieee80211_ptr; struct iw_range *range = (struct iw_range *) extra; enum nl80211_band band; @@ -251,8 +254,9 @@ int cfg80211_wext_freq(struct iw_freq *freq) int cfg80211_wext_siwrts(struct net_device *dev, struct iw_request_info *info, - struct iw_param *rts, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *rts = &wrqu->rts; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); u32 orts = wdev->wiphy->rts_threshold; @@ -281,8 +285,9 @@ EXPORT_WEXT_HANDLER(cfg80211_wext_siwrts); int cfg80211_wext_giwrts(struct net_device *dev, struct iw_request_info *info, - struct iw_param *rts, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *rts = &wrqu->rts; struct wireless_dev *wdev = dev->ieee80211_ptr; rts->value = wdev->wiphy->rts_threshold; @@ -295,8 +300,9 @@ EXPORT_WEXT_HANDLER(cfg80211_wext_giwrts); int cfg80211_wext_siwfrag(struct net_device *dev, struct iw_request_info *info, - struct iw_param *frag, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *frag = &wrqu->frag; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); u32 ofrag = wdev->wiphy->frag_threshold; @@ -325,8 +331,9 @@ EXPORT_WEXT_HANDLER(cfg80211_wext_siwfrag); int cfg80211_wext_giwfrag(struct net_device *dev, struct iw_request_info *info, - struct iw_param *frag, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *frag = &wrqu->frag; struct wireless_dev *wdev = dev->ieee80211_ptr; frag->value = wdev->wiphy->frag_threshold; @@ -339,8 +346,9 @@ EXPORT_WEXT_HANDLER(cfg80211_wext_giwfrag); static int cfg80211_wext_siwretry(struct net_device *dev, struct iw_request_info *info, - struct iw_param *retry, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *retry = &wrqu->retry; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); u32 changed = 0; @@ -378,8 +386,9 @@ static int cfg80211_wext_siwretry(struct net_device *dev, int cfg80211_wext_giwretry(struct net_device *dev, struct iw_request_info *info, - struct iw_param *retry, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *retry = &wrqu->retry; struct wireless_dev *wdev = dev->ieee80211_ptr; retry->disabled = 0; @@ -588,8 +597,9 @@ static int cfg80211_set_encryption(struct cfg80211_registered_device *rdev, static int cfg80211_wext_siwencode(struct net_device *dev, struct iw_request_info *info, - struct iw_point *erq, char *keybuf) + union iwreq_data *wrqu, char *keybuf) { + struct iw_point *erq = &wrqu->encoding; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); int idx, err; @@ -664,8 +674,9 @@ out: static int cfg80211_wext_siwencodeext(struct net_device *dev, struct iw_request_info *info, - struct iw_point *erq, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_point *erq = &wrqu->encoding; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct iw_encode_ext *ext = (struct iw_encode_ext *) extra; @@ -767,8 +778,9 @@ static int cfg80211_wext_siwencodeext(struct net_device *dev, static int cfg80211_wext_giwencode(struct net_device *dev, struct iw_request_info *info, - struct iw_point *erq, char *keybuf) + union iwreq_data *wrqu, char *keybuf) { + struct iw_point *erq = &wrqu->encoding; struct wireless_dev *wdev = dev->ieee80211_ptr; int idx; @@ -804,8 +816,9 @@ static int cfg80211_wext_giwencode(struct net_device *dev, static int cfg80211_wext_siwfreq(struct net_device *dev, struct iw_request_info *info, - struct iw_freq *wextfreq, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_freq *wextfreq = &wrqu->freq; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct cfg80211_chan_def chandef = { @@ -870,8 +883,9 @@ static int cfg80211_wext_siwfreq(struct net_device *dev, static int cfg80211_wext_giwfreq(struct net_device *dev, struct iw_request_info *info, - struct iw_freq *freq, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_freq *freq = &wrqu->freq; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct cfg80211_chan_def chandef = {}; @@ -1147,8 +1161,9 @@ static int cfg80211_set_key_mgt(struct wireless_dev *wdev, u32 key_mgt) static int cfg80211_wext_siwauth(struct net_device *dev, struct iw_request_info *info, - struct iw_param *data, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *data = &wrqu->param; struct wireless_dev *wdev = dev->ieee80211_ptr; if (wdev->iftype != NL80211_IFTYPE_STATION) @@ -1180,7 +1195,7 @@ static int cfg80211_wext_siwauth(struct net_device *dev, static int cfg80211_wext_giwauth(struct net_device *dev, struct iw_request_info *info, - struct iw_param *data, char *extra) + union iwreq_data *wrqu, char *extra) { /* XXX: what do we need? */ @@ -1189,8 +1204,9 @@ static int cfg80211_wext_giwauth(struct net_device *dev, static int cfg80211_wext_siwpower(struct net_device *dev, struct iw_request_info *info, - struct iw_param *wrq, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *wrq = &wrqu->power; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); bool ps; @@ -1238,8 +1254,9 @@ static int cfg80211_wext_siwpower(struct net_device *dev, static int cfg80211_wext_giwpower(struct net_device *dev, struct iw_request_info *info, - struct iw_param *wrq, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *wrq = &wrqu->power; struct wireless_dev *wdev = dev->ieee80211_ptr; wrq->disabled = !wdev->ps; @@ -1249,8 +1266,9 @@ static int cfg80211_wext_giwpower(struct net_device *dev, static int cfg80211_wext_siwrate(struct net_device *dev, struct iw_request_info *info, - struct iw_param *rate, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *rate = &wrqu->bitrate; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct cfg80211_bitrate_mask mask; @@ -1307,8 +1325,9 @@ static int cfg80211_wext_siwrate(struct net_device *dev, static int cfg80211_wext_giwrate(struct net_device *dev, struct iw_request_info *info, - struct iw_param *rate, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_param *rate = &wrqu->bitrate; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct station_info sinfo = {}; @@ -1430,8 +1449,9 @@ static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) static int cfg80211_wext_siwap(struct net_device *dev, struct iw_request_info *info, - struct sockaddr *ap_addr, char *extra) + union iwreq_data *wrqu, char *extra) { + struct sockaddr *ap_addr = &wrqu->ap_addr; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); int ret; @@ -1455,8 +1475,9 @@ static int cfg80211_wext_siwap(struct net_device *dev, static int cfg80211_wext_giwap(struct net_device *dev, struct iw_request_info *info, - struct sockaddr *ap_addr, char *extra) + union iwreq_data *wrqu, char *extra) { + struct sockaddr *ap_addr = &wrqu->ap_addr; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); int ret; @@ -1480,8 +1501,9 @@ static int cfg80211_wext_giwap(struct net_device *dev, static int cfg80211_wext_siwessid(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *ssid) + union iwreq_data *wrqu, char *ssid) { + struct iw_point *data = &wrqu->data; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); int ret; @@ -1505,8 +1527,9 @@ static int cfg80211_wext_siwessid(struct net_device *dev, static int cfg80211_wext_giwessid(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *ssid) + union iwreq_data *wrqu, char *ssid) { + struct iw_point *data = &wrqu->data; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); int ret; @@ -1533,7 +1556,7 @@ static int cfg80211_wext_giwessid(struct net_device *dev, static int cfg80211_wext_siwpmksa(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra) + union iwreq_data *wrqu, char *extra) { struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); @@ -1584,78 +1607,39 @@ static int cfg80211_wext_siwpmksa(struct net_device *dev, return ret; } -#define DEFINE_WEXT_COMPAT_STUB(func, type) \ - static int __ ## func(struct net_device *dev, \ - struct iw_request_info *info, \ - union iwreq_data *wrqu, \ - char *extra) \ - { \ - return func(dev, info, (type *)wrqu, extra); \ - } - -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwname, char) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwfreq, struct iw_freq) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwfreq, struct iw_freq) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwmode, u32) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwmode, u32) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrange, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwap, struct sockaddr) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwap, struct sockaddr) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwmlme, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwscan, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwessid, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwessid, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwrate, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrate, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwrts, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrts, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwfrag, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwfrag, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwretry, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwretry, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwencode, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwencode, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwpower, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwpower, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwgenie, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwauth, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwauth, struct iw_param) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwencodeext, struct iw_point) -DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwpmksa, struct iw_point) - static const iw_handler cfg80211_handlers[] = { - [IW_IOCTL_IDX(SIOCGIWNAME)] = __cfg80211_wext_giwname, - [IW_IOCTL_IDX(SIOCSIWFREQ)] = __cfg80211_wext_siwfreq, - [IW_IOCTL_IDX(SIOCGIWFREQ)] = __cfg80211_wext_giwfreq, - [IW_IOCTL_IDX(SIOCSIWMODE)] = __cfg80211_wext_siwmode, - [IW_IOCTL_IDX(SIOCGIWMODE)] = __cfg80211_wext_giwmode, - [IW_IOCTL_IDX(SIOCGIWRANGE)] = __cfg80211_wext_giwrange, - [IW_IOCTL_IDX(SIOCSIWAP)] = __cfg80211_wext_siwap, - [IW_IOCTL_IDX(SIOCGIWAP)] = __cfg80211_wext_giwap, - [IW_IOCTL_IDX(SIOCSIWMLME)] = __cfg80211_wext_siwmlme, - [IW_IOCTL_IDX(SIOCSIWSCAN)] = cfg80211_wext_siwscan, - [IW_IOCTL_IDX(SIOCGIWSCAN)] = __cfg80211_wext_giwscan, - [IW_IOCTL_IDX(SIOCSIWESSID)] = __cfg80211_wext_siwessid, - [IW_IOCTL_IDX(SIOCGIWESSID)] = __cfg80211_wext_giwessid, - [IW_IOCTL_IDX(SIOCSIWRATE)] = __cfg80211_wext_siwrate, - [IW_IOCTL_IDX(SIOCGIWRATE)] = __cfg80211_wext_giwrate, - [IW_IOCTL_IDX(SIOCSIWRTS)] = __cfg80211_wext_siwrts, - [IW_IOCTL_IDX(SIOCGIWRTS)] = __cfg80211_wext_giwrts, - [IW_IOCTL_IDX(SIOCSIWFRAG)] = __cfg80211_wext_siwfrag, - [IW_IOCTL_IDX(SIOCGIWFRAG)] = __cfg80211_wext_giwfrag, - [IW_IOCTL_IDX(SIOCSIWTXPOW)] = cfg80211_wext_siwtxpower, - [IW_IOCTL_IDX(SIOCGIWTXPOW)] = cfg80211_wext_giwtxpower, - [IW_IOCTL_IDX(SIOCSIWRETRY)] = __cfg80211_wext_siwretry, - [IW_IOCTL_IDX(SIOCGIWRETRY)] = __cfg80211_wext_giwretry, - [IW_IOCTL_IDX(SIOCSIWENCODE)] = __cfg80211_wext_siwencode, - [IW_IOCTL_IDX(SIOCGIWENCODE)] = __cfg80211_wext_giwencode, - [IW_IOCTL_IDX(SIOCSIWPOWER)] = __cfg80211_wext_siwpower, - [IW_IOCTL_IDX(SIOCGIWPOWER)] = __cfg80211_wext_giwpower, - [IW_IOCTL_IDX(SIOCSIWGENIE)] = __cfg80211_wext_siwgenie, - [IW_IOCTL_IDX(SIOCSIWAUTH)] = __cfg80211_wext_siwauth, - [IW_IOCTL_IDX(SIOCGIWAUTH)] = __cfg80211_wext_giwauth, - [IW_IOCTL_IDX(SIOCSIWENCODEEXT)]= __cfg80211_wext_siwencodeext, - [IW_IOCTL_IDX(SIOCSIWPMKSA)] = __cfg80211_wext_siwpmksa, + IW_HANDLER(SIOCGIWNAME, cfg80211_wext_giwname), + IW_HANDLER(SIOCSIWFREQ, cfg80211_wext_siwfreq), + IW_HANDLER(SIOCGIWFREQ, cfg80211_wext_giwfreq), + IW_HANDLER(SIOCSIWMODE, cfg80211_wext_siwmode), + IW_HANDLER(SIOCGIWMODE, cfg80211_wext_giwmode), + IW_HANDLER(SIOCGIWRANGE, cfg80211_wext_giwrange), + IW_HANDLER(SIOCSIWAP, cfg80211_wext_siwap), + IW_HANDLER(SIOCGIWAP, cfg80211_wext_giwap), + IW_HANDLER(SIOCSIWMLME, cfg80211_wext_siwmlme), + IW_HANDLER(SIOCSIWSCAN, cfg80211_wext_siwscan), + IW_HANDLER(SIOCGIWSCAN, cfg80211_wext_giwscan), + IW_HANDLER(SIOCSIWESSID, cfg80211_wext_siwessid), + IW_HANDLER(SIOCGIWESSID, cfg80211_wext_giwessid), + IW_HANDLER(SIOCSIWRATE, cfg80211_wext_siwrate), + IW_HANDLER(SIOCGIWRATE, cfg80211_wext_giwrate), + IW_HANDLER(SIOCSIWRTS, cfg80211_wext_siwrts), + IW_HANDLER(SIOCGIWRTS, cfg80211_wext_giwrts), + IW_HANDLER(SIOCSIWFRAG, cfg80211_wext_siwfrag), + IW_HANDLER(SIOCGIWFRAG, cfg80211_wext_giwfrag), + IW_HANDLER(SIOCSIWTXPOW, cfg80211_wext_siwtxpower), + IW_HANDLER(SIOCGIWTXPOW, cfg80211_wext_giwtxpower), + IW_HANDLER(SIOCSIWRETRY, cfg80211_wext_siwretry), + IW_HANDLER(SIOCGIWRETRY, cfg80211_wext_giwretry), + IW_HANDLER(SIOCSIWENCODE, cfg80211_wext_siwencode), + IW_HANDLER(SIOCGIWENCODE, cfg80211_wext_giwencode), + IW_HANDLER(SIOCSIWPOWER, cfg80211_wext_siwpower), + IW_HANDLER(SIOCGIWPOWER, cfg80211_wext_giwpower), + IW_HANDLER(SIOCSIWGENIE, cfg80211_wext_siwgenie), + IW_HANDLER(SIOCSIWAUTH, cfg80211_wext_siwauth), + IW_HANDLER(SIOCGIWAUTH, cfg80211_wext_giwauth), + IW_HANDLER(SIOCSIWENCODEEXT, cfg80211_wext_siwencodeext), + IW_HANDLER(SIOCSIWPMKSA, cfg80211_wext_siwpmksa), }; const struct iw_handler_def cfg80211_wext_handler = { diff --git a/net/wireless/wext-compat.h b/net/wireless/wext-compat.h index 8d3cc1552e2f..c02eb789e676 100644 --- a/net/wireless/wext-compat.h +++ b/net/wireless/wext-compat.h @@ -13,7 +13,7 @@ int cfg80211_ibss_wext_siwfreq(struct net_device *dev, struct iw_request_info *info, - struct iw_freq *freq, char *extra); + struct iw_freq *wextfreq, char *extra); int cfg80211_ibss_wext_giwfreq(struct net_device *dev, struct iw_request_info *info, struct iw_freq *freq, char *extra); @@ -32,7 +32,7 @@ int cfg80211_ibss_wext_giwessid(struct net_device *dev, int cfg80211_mgd_wext_siwfreq(struct net_device *dev, struct iw_request_info *info, - struct iw_freq *freq, char *extra); + struct iw_freq *wextfreq, char *extra); int cfg80211_mgd_wext_giwfreq(struct net_device *dev, struct iw_request_info *info, struct iw_freq *freq, char *extra); @@ -51,10 +51,10 @@ int cfg80211_mgd_wext_giwessid(struct net_device *dev, int cfg80211_wext_siwmlme(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra); + union iwreq_data *wrqu, char *extra); int cfg80211_wext_siwgenie(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra); + union iwreq_data *wrqu, char *extra); int cfg80211_wext_freq(struct iw_freq *freq); diff --git a/net/wireless/wext-sme.c b/net/wireless/wext-sme.c index 68f45afc352d..191c6d98c700 100644 --- a/net/wireless/wext-sme.c +++ b/net/wireless/wext-sme.c @@ -324,8 +324,9 @@ int cfg80211_mgd_wext_giwap(struct net_device *dev, int cfg80211_wext_siwgenie(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra) + union iwreq_data *wrqu, char *extra) { + struct iw_point *data = &wrqu->data; struct wireless_dev *wdev = dev->ieee80211_ptr; struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); u8 *ie = extra; @@ -374,7 +375,7 @@ int cfg80211_wext_siwgenie(struct net_device *dev, int cfg80211_wext_siwmlme(struct net_device *dev, struct iw_request_info *info, - struct iw_point *data, char *extra) + union iwreq_data *wrqu, char *extra) { struct wireless_dev *wdev = dev->ieee80211_ptr; struct iw_mlme *mlme = (struct iw_mlme *)extra; diff --git a/net/x25/af_x25.c b/net/x25/af_x25.c index 3b55502b2965..5c7ad301d742 100644 --- a/net/x25/af_x25.c +++ b/net/x25/af_x25.c @@ -482,6 +482,12 @@ static int x25_listen(struct socket *sock, int backlog) int rc = -EOPNOTSUPP; lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + rc = -EINVAL; + release_sock(sk); + return rc; + } + if (sk->sk_state != TCP_LISTEN) { memset(&x25_sk(sk)->dest_addr, 0, X25_ADDR_LEN); sk->sk_max_ack_backlog = backlog; diff --git a/net/x25/x25_dev.c b/net/x25/x25_dev.c index 5259ef8f5242..748d8630ab58 100644 --- a/net/x25/x25_dev.c +++ b/net/x25/x25_dev.c @@ -117,7 +117,7 @@ int x25_lapb_receive_frame(struct sk_buff *skb, struct net_device *dev, if (!pskb_may_pull(skb, 1)) { x25_neigh_put(nb); - return 0; + goto drop; } switch (skb->data[0]) { diff --git a/net/xdp/xskmap.c b/net/xdp/xskmap.c index acc8e52a4f5f..771d0fa90ef5 100644 --- a/net/xdp/xskmap.c +++ b/net/xdp/xskmap.c @@ -231,9 +231,9 @@ static int xsk_map_delete_elem(struct bpf_map *map, void *key) return 0; } -static int xsk_map_redirect(struct bpf_map *map, u32 ifindex, u64 flags) +static int xsk_map_redirect(struct bpf_map *map, u64 index, u64 flags) { - return __bpf_xdp_redirect_map(map, ifindex, flags, 0, + return __bpf_xdp_redirect_map(map, index, flags, 0, __xsk_map_lookup_elem); } diff --git a/net/xfrm/Makefile b/net/xfrm/Makefile index 494aa744bfb9..cd47f88921f5 100644 --- a/net/xfrm/Makefile +++ b/net/xfrm/Makefile @@ -3,6 +3,14 @@ # Makefile for the XFRM subsystem. # +xfrm_interface-$(CONFIG_XFRM_INTERFACE) += xfrm_interface_core.o + +ifeq ($(CONFIG_XFRM_INTERFACE),m) +xfrm_interface-$(CONFIG_DEBUG_INFO_BTF_MODULES) += xfrm_interface_bpf.o +else ifeq ($(CONFIG_XFRM_INTERFACE),y) +xfrm_interface-$(CONFIG_DEBUG_INFO_BTF) += xfrm_interface_bpf.o +endif + obj-$(CONFIG_XFRM) := xfrm_policy.o xfrm_state.o xfrm_hash.o \ xfrm_input.o xfrm_output.o \ xfrm_sysctl.o xfrm_replay.o xfrm_device.o diff --git a/net/xfrm/espintcp.c b/net/xfrm/espintcp.c index 29a540dcb5a7..74a54295c164 100644 --- a/net/xfrm/espintcp.c +++ b/net/xfrm/espintcp.c @@ -354,7 +354,7 @@ static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) *((__be16 *)buf) = cpu_to_be16(msglen); pfx_iov.iov_base = buf; pfx_iov.iov_len = sizeof(buf); - iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len); + iov_iter_kvec(&pfx_iter, ITER_SOURCE, &pfx_iov, 1, pfx_iov.iov_len); err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg, pfx_iov.iov_len); @@ -489,6 +489,7 @@ static int espintcp_init_sk(struct sock *sk) /* avoid using task_frag */ sk->sk_allocation = GFP_ATOMIC; + sk->sk_use_task_frag = false; return 0; diff --git a/net/xfrm/xfrm_device.c b/net/xfrm/xfrm_device.c index 5f5aafd418af..4aff76c6f12e 100644 --- a/net/xfrm/xfrm_device.c +++ b/net/xfrm/xfrm_device.c @@ -97,6 +97,18 @@ static void xfrm_outer_mode_prep(struct xfrm_state *x, struct sk_buff *skb) } } +static inline bool xmit_xfrm_check_overflow(struct sk_buff *skb) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + __u32 seq = xo->seq.low; + + seq += skb_shinfo(skb)->gso_segs; + if (unlikely(seq < xo->seq.low)) + return true; + + return false; +} + struct sk_buff *validate_xmit_xfrm(struct sk_buff *skb, netdev_features_t features, bool *again) { int err; @@ -120,6 +132,16 @@ struct sk_buff *validate_xmit_xfrm(struct sk_buff *skb, netdev_features_t featur if (xo->flags & XFRM_GRO || x->xso.dir == XFRM_DEV_OFFLOAD_IN) return skb; + /* The packet was sent to HW IPsec packet offload engine, + * but to wrong device. Drop the packet, so it won't skip + * XFRM stack. + */ + if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET && x->xso.dev != dev) { + kfree_skb(skb); + dev_core_stats_tx_dropped_inc(dev); + return NULL; + } + /* This skb was already validated on the upper/virtual dev */ if ((x->xso.dev != dev) && (x->xso.real_dev == dev)) return skb; @@ -134,7 +156,8 @@ struct sk_buff *validate_xmit_xfrm(struct sk_buff *skb, netdev_features_t featur return skb; } - if (skb_is_gso(skb) && unlikely(x->xso.dev != dev)) { + if (skb_is_gso(skb) && (unlikely(x->xso.dev != dev) || + unlikely(xmit_xfrm_check_overflow(skb)))) { struct sk_buff *segs; /* Packet got rerouted, fixup features and segment it. */ @@ -216,6 +239,7 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, struct xfrm_dev_offload *xso = &x->xso; xfrm_address_t *saddr; xfrm_address_t *daddr; + bool is_packet_offload; if (!x->type_offload) { NL_SET_ERR_MSG(extack, "Type doesn't support offload"); @@ -228,11 +252,13 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, return -EINVAL; } - if (xuo->flags & ~(XFRM_OFFLOAD_IPV6 | XFRM_OFFLOAD_INBOUND)) { + if (xuo->flags & + ~(XFRM_OFFLOAD_IPV6 | XFRM_OFFLOAD_INBOUND | XFRM_OFFLOAD_PACKET)) { NL_SET_ERR_MSG(extack, "Unrecognized flags in offload request"); return -EINVAL; } + is_packet_offload = xuo->flags & XFRM_OFFLOAD_PACKET; dev = dev_get_by_index(net, xuo->ifindex); if (!dev) { if (!(xuo->flags & XFRM_OFFLOAD_INBOUND)) { @@ -247,7 +273,7 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, x->props.family, xfrm_smark_get(0, x)); if (IS_ERR(dst)) - return 0; + return (is_packet_offload) ? -EINVAL : 0; dev = dst->dev; @@ -258,7 +284,7 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, if (!dev->xfrmdev_ops || !dev->xfrmdev_ops->xdo_dev_state_add) { xso->dev = NULL; dev_put(dev); - return 0; + return (is_packet_offload) ? -EINVAL : 0; } if (x->props.flags & XFRM_STATE_ESN && @@ -278,14 +304,28 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, else xso->dir = XFRM_DEV_OFFLOAD_OUT; + if (is_packet_offload) + xso->type = XFRM_DEV_OFFLOAD_PACKET; + else + xso->type = XFRM_DEV_OFFLOAD_CRYPTO; + err = dev->xfrmdev_ops->xdo_dev_state_add(x); if (err) { xso->dev = NULL; xso->dir = 0; xso->real_dev = NULL; netdev_put(dev, &xso->dev_tracker); - - if (err != -EOPNOTSUPP) { + xso->type = XFRM_DEV_OFFLOAD_UNSPECIFIED; + + /* User explicitly requested packet offload mode and configured + * policy in addition to the XFRM state. So be civil to users, + * and return an error instead of taking fallback path. + * + * This WARN_ON() can be seen as a documentation for driver + * authors to do not return -EOPNOTSUPP in packet offload mode. + */ + WARN_ON(err == -EOPNOTSUPP && is_packet_offload); + if (err != -EOPNOTSUPP || is_packet_offload) { NL_SET_ERR_MSG(extack, "Device failed to offload this state"); return err; } @@ -295,6 +335,69 @@ int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, } EXPORT_SYMBOL_GPL(xfrm_dev_state_add); +int xfrm_dev_policy_add(struct net *net, struct xfrm_policy *xp, + struct xfrm_user_offload *xuo, u8 dir, + struct netlink_ext_ack *extack) +{ + struct xfrm_dev_offload *xdo = &xp->xdo; + struct net_device *dev; + int err; + + if (!xuo->flags || xuo->flags & ~XFRM_OFFLOAD_PACKET) { + /* We support only packet offload mode and it means + * that user must set XFRM_OFFLOAD_PACKET bit. + */ + NL_SET_ERR_MSG(extack, "Unrecognized flags in offload request"); + return -EINVAL; + } + + dev = dev_get_by_index(net, xuo->ifindex); + if (!dev) + return -EINVAL; + + if (!dev->xfrmdev_ops || !dev->xfrmdev_ops->xdo_dev_policy_add) { + xdo->dev = NULL; + dev_put(dev); + NL_SET_ERR_MSG(extack, "Policy offload is not supported"); + return -EINVAL; + } + + xdo->dev = dev; + netdev_tracker_alloc(dev, &xdo->dev_tracker, GFP_ATOMIC); + xdo->real_dev = dev; + xdo->type = XFRM_DEV_OFFLOAD_PACKET; + switch (dir) { + case XFRM_POLICY_IN: + xdo->dir = XFRM_DEV_OFFLOAD_IN; + break; + case XFRM_POLICY_OUT: + xdo->dir = XFRM_DEV_OFFLOAD_OUT; + break; + case XFRM_POLICY_FWD: + xdo->dir = XFRM_DEV_OFFLOAD_FWD; + break; + default: + xdo->dev = NULL; + dev_put(dev); + NL_SET_ERR_MSG(extack, "Unrecognized offload direction"); + return -EINVAL; + } + + err = dev->xfrmdev_ops->xdo_dev_policy_add(xp); + if (err) { + xdo->dev = NULL; + xdo->real_dev = NULL; + xdo->type = XFRM_DEV_OFFLOAD_UNSPECIFIED; + xdo->dir = 0; + netdev_put(dev, &xdo->dev_tracker); + NL_SET_ERR_MSG(extack, "Device failed to offload this policy"); + return err; + } + + return 0; +} +EXPORT_SYMBOL_GPL(xfrm_dev_policy_add); + bool xfrm_dev_offload_ok(struct sk_buff *skb, struct xfrm_state *x) { int mtu; @@ -305,8 +408,9 @@ bool xfrm_dev_offload_ok(struct sk_buff *skb, struct xfrm_state *x) if (!x->type_offload || x->encap) return false; - if ((!dev || (dev == xfrm_dst_path(dst)->dev)) && - (!xdst->child->xfrm)) { + if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET || + ((!dev || (dev == xfrm_dst_path(dst)->dev)) && + !xdst->child->xfrm)) { mtu = xfrm_state_mtu(x, xdst->child_mtu_cached); if (skb->len <= mtu) goto ok; @@ -397,8 +501,10 @@ static int xfrm_api_check(struct net_device *dev) static int xfrm_dev_down(struct net_device *dev) { - if (dev->features & NETIF_F_HW_ESP) + if (dev->features & NETIF_F_HW_ESP) { xfrm_dev_state_flush(dev_net(dev), dev, true); + xfrm_dev_policy_flush(dev_net(dev), dev, true); + } return NOTIFY_DONE; } diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c index 97074f6f2bde..c06e54a10540 100644 --- a/net/xfrm/xfrm_input.c +++ b/net/xfrm/xfrm_input.c @@ -671,6 +671,7 @@ resume: x->curlft.bytes += skb->len; x->curlft.packets++; + x->lastused = ktime_get_real_seconds(); spin_unlock(&x->lock); diff --git a/net/xfrm/xfrm_interface_bpf.c b/net/xfrm/xfrm_interface_bpf.c new file mode 100644 index 000000000000..1ef2162cebcf --- /dev/null +++ b/net/xfrm/xfrm_interface_bpf.c @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Unstable XFRM Helpers for TC-BPF hook + * + * These are called from SCHED_CLS BPF programs. Note that it is + * allowed to break compatibility for these functions since the interface they + * are exposed through to BPF programs is explicitly unstable. + */ + +#include <linux/bpf.h> +#include <linux/btf_ids.h> + +#include <net/dst_metadata.h> +#include <net/xfrm.h> + +/* bpf_xfrm_info - XFRM metadata information + * + * Members: + * @if_id - XFRM if_id: + * Transmit: if_id to be used in policy and state lookups + * Receive: if_id of the state matched for the incoming packet + * @link - Underlying device ifindex: + * Transmit: used as the underlying device in VRF routing + * Receive: the device on which the packet had been received + */ +struct bpf_xfrm_info { + u32 if_id; + int link; +}; + +__diag_push(); +__diag_ignore_all("-Wmissing-prototypes", + "Global functions as their definitions will be in xfrm_interface BTF"); + +/* bpf_skb_get_xfrm_info - Get XFRM metadata + * + * Parameters: + * @skb_ctx - Pointer to ctx (__sk_buff) in TC program + * Cannot be NULL + * @to - Pointer to memory to which the metadata will be copied + * Cannot be NULL + */ +__used noinline +int bpf_skb_get_xfrm_info(struct __sk_buff *skb_ctx, struct bpf_xfrm_info *to) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct xfrm_md_info *info; + + info = skb_xfrm_md_info(skb); + if (!info) + return -EINVAL; + + to->if_id = info->if_id; + to->link = info->link; + return 0; +} + +/* bpf_skb_get_xfrm_info - Set XFRM metadata + * + * Parameters: + * @skb_ctx - Pointer to ctx (__sk_buff) in TC program + * Cannot be NULL + * @from - Pointer to memory from which the metadata will be copied + * Cannot be NULL + */ +__used noinline +int bpf_skb_set_xfrm_info(struct __sk_buff *skb_ctx, + const struct bpf_xfrm_info *from) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct metadata_dst *md_dst; + struct xfrm_md_info *info; + + if (unlikely(skb_metadata_dst(skb))) + return -EINVAL; + + if (!xfrm_bpf_md_dst) { + struct metadata_dst __percpu *tmp; + + tmp = metadata_dst_alloc_percpu(0, METADATA_XFRM, GFP_ATOMIC); + if (!tmp) + return -ENOMEM; + if (cmpxchg(&xfrm_bpf_md_dst, NULL, tmp)) + metadata_dst_free_percpu(tmp); + } + md_dst = this_cpu_ptr(xfrm_bpf_md_dst); + + info = &md_dst->u.xfrm_info; + + info->if_id = from->if_id; + info->link = from->link; + skb_dst_force(skb); + info->dst_orig = skb_dst(skb); + + dst_hold((struct dst_entry *)md_dst); + skb_dst_set(skb, (struct dst_entry *)md_dst); + return 0; +} + +__diag_pop() + +BTF_SET8_START(xfrm_ifc_kfunc_set) +BTF_ID_FLAGS(func, bpf_skb_get_xfrm_info) +BTF_ID_FLAGS(func, bpf_skb_set_xfrm_info) +BTF_SET8_END(xfrm_ifc_kfunc_set) + +static const struct btf_kfunc_id_set xfrm_interface_kfunc_set = { + .owner = THIS_MODULE, + .set = &xfrm_ifc_kfunc_set, +}; + +int __init register_xfrm_interface_bpf(void) +{ + return register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, + &xfrm_interface_kfunc_set); +} diff --git a/net/xfrm/xfrm_interface.c b/net/xfrm/xfrm_interface_core.c index 5a67b120c4db..1f99dc469027 100644 --- a/net/xfrm/xfrm_interface.c +++ b/net/xfrm/xfrm_interface_core.c @@ -396,6 +396,14 @@ xfrmi_xmit2(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) if_id = md_info->if_id; fl->flowi_oif = md_info->link; + if (md_info->dst_orig) { + struct dst_entry *tmp_dst = dst; + + dst = md_info->dst_orig; + skb_dst_set(skb, dst); + md_info->dst_orig = NULL; + dst_release(tmp_dst); + } } else { if_id = xi->p.if_id; } @@ -1162,12 +1170,18 @@ static int __init xfrmi_init(void) if (err < 0) goto rtnl_link_failed; + err = register_xfrm_interface_bpf(); + if (err < 0) + goto kfunc_failed; + lwtunnel_encap_add_ops(&xfrmi_encap_ops, LWTUNNEL_ENCAP_XFRM); xfrm_if_register_cb(&xfrm_if_cb); return err; +kfunc_failed: + rtnl_link_unregister(&xfrmi_link_ops); rtnl_link_failed: xfrmi6_fini(); xfrmi6_failed: diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c index 9a5e79a38c67..ff114d68cc43 100644 --- a/net/xfrm/xfrm_output.c +++ b/net/xfrm/xfrm_output.c @@ -209,8 +209,6 @@ static int xfrm6_ro_output(struct xfrm_state *x, struct sk_buff *skb) __skb_pull(skb, hdr_len); memmove(ipv6_hdr(skb), iph, hdr_len); - x->lastused = ktime_get_real_seconds(); - return 0; #else WARN_ON_ONCE(1); @@ -494,7 +492,7 @@ static int xfrm_output_one(struct sk_buff *skb, int err) struct xfrm_state *x = dst->xfrm; struct net *net = xs_net(x); - if (err <= 0) + if (err <= 0 || x->xso.type == XFRM_DEV_OFFLOAD_PACKET) goto resume; do { @@ -534,6 +532,7 @@ static int xfrm_output_one(struct sk_buff *skb, int err) x->curlft.bytes += skb->len; x->curlft.packets++; + x->lastused = ktime_get_real_seconds(); spin_unlock_bh(&x->lock); @@ -718,6 +717,16 @@ int xfrm_output(struct sock *sk, struct sk_buff *skb) break; } + if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET) { + if (!xfrm_dev_offload_ok(skb, x)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + kfree_skb(skb); + return -EHOSTUNREACH; + } + + return xfrm_output_resume(sk, skb, 0); + } + secpath_reset(skb); if (xfrm_dev_offload_ok(skb, x)) { diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index e392d8d05e0c..e9eb82c5457d 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -425,6 +425,7 @@ void xfrm_policy_destroy(struct xfrm_policy *policy) if (del_timer(&policy->timer) || del_timer(&policy->polq.hold_timer)) BUG(); + xfrm_dev_policy_free(policy); call_rcu(&policy->rcu, xfrm_policy_destroy_rcu); } EXPORT_SYMBOL(xfrm_policy_destroy); @@ -535,7 +536,7 @@ redo: __get_hash_thresh(net, pol->family, dir, &dbits, &sbits); h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr, pol->family, nhashmask, dbits, sbits); - if (!entry0) { + if (!entry0 || pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) { hlist_del_rcu(&pol->bydst); hlist_add_head_rcu(&pol->bydst, ndsttable + h); h0 = h; @@ -605,7 +606,7 @@ static void xfrm_bydst_resize(struct net *net, int dir) xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head)); } -static void xfrm_byidx_resize(struct net *net, int total) +static void xfrm_byidx_resize(struct net *net) { unsigned int hmask = net->xfrm.policy_idx_hmask; unsigned int nhashmask = xfrm_new_hash_mask(hmask); @@ -683,7 +684,7 @@ static void xfrm_hash_resize(struct work_struct *work) xfrm_bydst_resize(net, dir); } if (xfrm_byidx_should_resize(net, total)) - xfrm_byidx_resize(net, total); + xfrm_byidx_resize(net); mutex_unlock(&hash_resize_mutex); } @@ -866,7 +867,7 @@ static void xfrm_policy_inexact_list_reinsert(struct net *net, break; } - if (newpos) + if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET) hlist_add_behind_rcu(&policy->bydst, newpos); else hlist_add_head_rcu(&policy->bydst, &n->hhead); @@ -1347,7 +1348,7 @@ static void xfrm_hash_rebuild(struct work_struct *work) else break; } - if (newpos) + if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET) hlist_add_behind_rcu(&policy->bydst, newpos); else hlist_add_head_rcu(&policy->bydst, chain); @@ -1524,7 +1525,7 @@ static void xfrm_policy_insert_inexact_list(struct hlist_head *chain, break; } - if (newpos) + if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET) hlist_add_behind_rcu(&policy->bydst_inexact_list, newpos); else hlist_add_head_rcu(&policy->bydst_inexact_list, chain); @@ -1561,9 +1562,12 @@ static struct xfrm_policy *xfrm_policy_insert_list(struct hlist_head *chain, break; } - if (newpos) + if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET) hlist_add_behind_rcu(&policy->bydst, &newpos->bydst); else + /* Packet offload policies enter to the head + * to speed-up lookups. + */ hlist_add_head_rcu(&policy->bydst, chain); return delpol; @@ -1769,12 +1773,41 @@ xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid) } return err; } + +static inline int xfrm_dev_policy_flush_secctx_check(struct net *net, + struct net_device *dev, + bool task_valid) +{ + struct xfrm_policy *pol; + int err = 0; + + list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) { + if (pol->walk.dead || + xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX || + pol->xdo.dev != dev) + continue; + + err = security_xfrm_policy_delete(pol->security); + if (err) { + xfrm_audit_policy_delete(pol, 0, task_valid); + return err; + } + } + return err; +} #else static inline int xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid) { return 0; } + +static inline int xfrm_dev_policy_flush_secctx_check(struct net *net, + struct net_device *dev, + bool task_valid) +{ + return 0; +} #endif int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) @@ -1814,6 +1847,44 @@ out: } EXPORT_SYMBOL(xfrm_policy_flush); +int xfrm_dev_policy_flush(struct net *net, struct net_device *dev, + bool task_valid) +{ + int dir, err = 0, cnt = 0; + struct xfrm_policy *pol; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + + err = xfrm_dev_policy_flush_secctx_check(net, dev, task_valid); + if (err) + goto out; + +again: + list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) { + dir = xfrm_policy_id2dir(pol->index); + if (pol->walk.dead || + dir >= XFRM_POLICY_MAX || + pol->xdo.dev != dev) + continue; + + __xfrm_policy_unlink(pol, dir); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + cnt++; + xfrm_audit_policy_delete(pol, 1, task_valid); + xfrm_policy_kill(pol); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + goto again; + } + if (cnt) + __xfrm_policy_inexact_flush(net); + else + err = -ESRCH; +out: + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return err; +} +EXPORT_SYMBOL(xfrm_dev_policy_flush); + int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk, int (*func)(struct xfrm_policy *, int, int, void*), void *data) @@ -2113,6 +2184,9 @@ static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type, break; } } + if (ret && ret->xdo.type == XFRM_DEV_OFFLOAD_PACKET) + goto skip_inexact; + bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id); if (!bin || !xfrm_policy_find_inexact_candidates(&cand, bin, saddr, daddr)) @@ -2245,6 +2319,7 @@ int xfrm_policy_delete(struct xfrm_policy *pol, int dir) pol = __xfrm_policy_unlink(pol, dir); spin_unlock_bh(&net->xfrm.xfrm_policy_lock); if (pol) { + xfrm_dev_policy_delete(pol); xfrm_policy_kill(pol); return 0; } @@ -4333,7 +4408,8 @@ static int migrate_tmpl_match(const struct xfrm_migrate *m, const struct xfrm_tm /* update endpoint address(es) of template(s) */ static int xfrm_policy_migrate(struct xfrm_policy *pol, - struct xfrm_migrate *m, int num_migrate) + struct xfrm_migrate *m, int num_migrate, + struct netlink_ext_ack *extack) { struct xfrm_migrate *mp; int i, j, n = 0; @@ -4341,6 +4417,7 @@ static int xfrm_policy_migrate(struct xfrm_policy *pol, write_lock_bh(&pol->lock); if (unlikely(pol->walk.dead)) { /* target policy has been deleted */ + NL_SET_ERR_MSG(extack, "Target policy not found"); write_unlock_bh(&pol->lock); return -ENOENT; } @@ -4372,17 +4449,22 @@ static int xfrm_policy_migrate(struct xfrm_policy *pol, return 0; } -static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate) +static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate, + struct netlink_ext_ack *extack) { int i, j; - if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH) + if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH) { + NL_SET_ERR_MSG(extack, "Invalid number of SAs to migrate, must be 0 < num <= XFRM_MAX_DEPTH (6)"); return -EINVAL; + } for (i = 0; i < num_migrate; i++) { if (xfrm_addr_any(&m[i].new_daddr, m[i].new_family) || - xfrm_addr_any(&m[i].new_saddr, m[i].new_family)) + xfrm_addr_any(&m[i].new_saddr, m[i].new_family)) { + NL_SET_ERR_MSG(extack, "Addresses in the MIGRATE attribute's list cannot be null"); return -EINVAL; + } /* check if there is any duplicated entry */ for (j = i + 1; j < num_migrate; j++) { @@ -4393,8 +4475,10 @@ static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate) m[i].proto == m[j].proto && m[i].mode == m[j].mode && m[i].reqid == m[j].reqid && - m[i].old_family == m[j].old_family) + m[i].old_family == m[j].old_family) { + NL_SET_ERR_MSG(extack, "Entries in the MIGRATE attribute's list must be unique"); return -EINVAL; + } } } @@ -4404,7 +4488,8 @@ static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate) int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, struct xfrm_migrate *m, int num_migrate, struct xfrm_kmaddress *k, struct net *net, - struct xfrm_encap_tmpl *encap, u32 if_id) + struct xfrm_encap_tmpl *encap, u32 if_id, + struct netlink_ext_ack *extack) { int i, err, nx_cur = 0, nx_new = 0; struct xfrm_policy *pol = NULL; @@ -4414,16 +4499,20 @@ int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, struct xfrm_migrate *mp; /* Stage 0 - sanity checks */ - if ((err = xfrm_migrate_check(m, num_migrate)) < 0) + err = xfrm_migrate_check(m, num_migrate, extack); + if (err < 0) goto out; if (dir >= XFRM_POLICY_MAX) { + NL_SET_ERR_MSG(extack, "Invalid policy direction"); err = -EINVAL; goto out; } /* Stage 1 - find policy */ - if ((pol = xfrm_migrate_policy_find(sel, dir, type, net, if_id)) == NULL) { + pol = xfrm_migrate_policy_find(sel, dir, type, net, if_id); + if (!pol) { + NL_SET_ERR_MSG(extack, "Target policy not found"); err = -ENOENT; goto out; } @@ -4445,7 +4534,8 @@ int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, } /* Stage 3 - update policy */ - if ((err = xfrm_policy_migrate(pol, m, num_migrate)) < 0) + err = xfrm_policy_migrate(pol, m, num_migrate, extack); + if (err < 0) goto restore_state; /* Stage 4 - delete old state(s) */ diff --git a/net/xfrm/xfrm_replay.c b/net/xfrm/xfrm_replay.c index 9f4d42eb090f..ce56d659c55a 100644 --- a/net/xfrm/xfrm_replay.c +++ b/net/xfrm/xfrm_replay.c @@ -714,7 +714,7 @@ static int xfrm_replay_overflow_offload_esn(struct xfrm_state *x, struct sk_buff oseq += skb_shinfo(skb)->gso_segs; } - if (unlikely(oseq < replay_esn->oseq)) { + if (unlikely(xo->seq.low < replay_esn->oseq)) { XFRM_SKB_CB(skb)->seq.output.hi = ++oseq_hi; xo->seq.hi = oseq_hi; replay_esn->oseq_hi = oseq_hi; diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 3d2fe7712ac5..89c731f4f0c7 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -84,6 +84,25 @@ static unsigned int xfrm_seq_hash(struct net *net, u32 seq) return __xfrm_seq_hash(seq, net->xfrm.state_hmask); } +#define XFRM_STATE_INSERT(by, _n, _h, _type) \ + { \ + struct xfrm_state *_x = NULL; \ + \ + if (_type != XFRM_DEV_OFFLOAD_PACKET) { \ + hlist_for_each_entry_rcu(_x, _h, by) { \ + if (_x->xso.type == XFRM_DEV_OFFLOAD_PACKET) \ + continue; \ + break; \ + } \ + } \ + \ + if (!_x || _x->xso.type == XFRM_DEV_OFFLOAD_PACKET) \ + /* SAD is empty or consist from HW SAs only */ \ + hlist_add_head_rcu(_n, _h); \ + else \ + hlist_add_before_rcu(_n, &_x->by); \ + } + static void xfrm_hash_transfer(struct hlist_head *list, struct hlist_head *ndsttable, struct hlist_head *nsrctable, @@ -100,23 +119,25 @@ static void xfrm_hash_transfer(struct hlist_head *list, h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr, x->props.reqid, x->props.family, nhashmask); - hlist_add_head_rcu(&x->bydst, ndsttable + h); + XFRM_STATE_INSERT(bydst, &x->bydst, ndsttable + h, x->xso.type); h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr, x->props.family, nhashmask); - hlist_add_head_rcu(&x->bysrc, nsrctable + h); + XFRM_STATE_INSERT(bysrc, &x->bysrc, nsrctable + h, x->xso.type); if (x->id.spi) { h = __xfrm_spi_hash(&x->id.daddr, x->id.spi, x->id.proto, x->props.family, nhashmask); - hlist_add_head_rcu(&x->byspi, nspitable + h); + XFRM_STATE_INSERT(byspi, &x->byspi, nspitable + h, + x->xso.type); } if (x->km.seq) { h = __xfrm_seq_hash(x->km.seq, nhashmask); - hlist_add_head_rcu(&x->byseq, nseqtable + h); + XFRM_STATE_INSERT(byseq, &x->byseq, nseqtable + h, + x->xso.type); } } } @@ -549,6 +570,8 @@ static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me) int err = 0; spin_lock(&x->lock); + xfrm_dev_state_update_curlft(x); + if (x->km.state == XFRM_STATE_DEAD) goto out; if (x->km.state == XFRM_STATE_EXPIRED) @@ -951,6 +974,49 @@ xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl, x->props.family = tmpl->encap_family; } +static struct xfrm_state *__xfrm_state_lookup_all(struct net *net, u32 mark, + const xfrm_address_t *daddr, + __be32 spi, u8 proto, + unsigned short family, + struct xfrm_dev_offload *xdo) +{ + unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family); + struct xfrm_state *x; + + hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) { +#ifdef CONFIG_XFRM_OFFLOAD + if (xdo->type == XFRM_DEV_OFFLOAD_PACKET) { + if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET) + /* HW states are in the head of list, there is + * no need to iterate further. + */ + break; + + /* Packet offload: both policy and SA should + * have same device. + */ + if (xdo->dev != x->xso.dev) + continue; + } else if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET) + /* Skip HW policy for SW lookups */ + continue; +#endif + if (x->props.family != family || + x->id.spi != spi || + x->id.proto != proto || + !xfrm_addr_equal(&x->id.daddr, daddr, family)) + continue; + + if ((mark & x->mark.m) != x->mark.v) + continue; + if (!xfrm_state_hold_rcu(x)) + continue; + return x; + } + + return NULL; +} + static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32 spi, u8 proto, @@ -1092,6 +1158,23 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, rcu_read_lock(); h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) { +#ifdef CONFIG_XFRM_OFFLOAD + if (pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) { + if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET) + /* HW states are in the head of list, there is + * no need to iterate further. + */ + break; + + /* Packet offload: both policy and SA should + * have same device. + */ + if (pol->xdo.dev != x->xso.dev) + continue; + } else if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET) + /* Skip HW policy for SW lookups */ + continue; +#endif if (x->props.family == encap_family && x->props.reqid == tmpl->reqid && (mark & x->mark.m) == x->mark.v && @@ -1109,6 +1192,23 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family); hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) { +#ifdef CONFIG_XFRM_OFFLOAD + if (pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) { + if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET) + /* HW states are in the head of list, there is + * no need to iterate further. + */ + break; + + /* Packet offload: both policy and SA should + * have same device. + */ + if (pol->xdo.dev != x->xso.dev) + continue; + } else if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET) + /* Skip HW policy for SW lookups */ + continue; +#endif if (x->props.family == encap_family && x->props.reqid == tmpl->reqid && (mark & x->mark.m) == x->mark.v && @@ -1126,8 +1226,10 @@ found: x = best; if (!x && !error && !acquire_in_progress) { if (tmpl->id.spi && - (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi, - tmpl->id.proto, encap_family)) != NULL) { + (x0 = __xfrm_state_lookup_all(net, mark, daddr, + tmpl->id.spi, tmpl->id.proto, + encap_family, + &pol->xdo)) != NULL) { to_put = x0; error = -EEXIST; goto out; @@ -1161,21 +1263,53 @@ found: x = NULL; goto out; } - +#ifdef CONFIG_XFRM_OFFLOAD + if (pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) { + struct xfrm_dev_offload *xdo = &pol->xdo; + struct xfrm_dev_offload *xso = &x->xso; + + xso->type = XFRM_DEV_OFFLOAD_PACKET; + xso->dir = xdo->dir; + xso->dev = xdo->dev; + xso->real_dev = xdo->real_dev; + netdev_tracker_alloc(xso->dev, &xso->dev_tracker, + GFP_ATOMIC); + error = xso->dev->xfrmdev_ops->xdo_dev_state_add(x); + if (error) { + xso->dir = 0; + netdev_put(xso->dev, &xso->dev_tracker); + xso->dev = NULL; + xso->real_dev = NULL; + xso->type = XFRM_DEV_OFFLOAD_UNSPECIFIED; + x->km.state = XFRM_STATE_DEAD; + to_put = x; + x = NULL; + goto out; + } + } +#endif if (km_query(x, tmpl, pol) == 0) { spin_lock_bh(&net->xfrm.xfrm_state_lock); x->km.state = XFRM_STATE_ACQ; list_add(&x->km.all, &net->xfrm.state_all); - hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + XFRM_STATE_INSERT(bydst, &x->bydst, + net->xfrm.state_bydst + h, + x->xso.type); h = xfrm_src_hash(net, daddr, saddr, encap_family); - hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + XFRM_STATE_INSERT(bysrc, &x->bysrc, + net->xfrm.state_bysrc + h, + x->xso.type); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family); - hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + XFRM_STATE_INSERT(byspi, &x->byspi, + net->xfrm.state_byspi + h, + x->xso.type); } if (x->km.seq) { h = xfrm_seq_hash(net, x->km.seq); - hlist_add_head_rcu(&x->byseq, net->xfrm.state_byseq + h); + XFRM_STATE_INSERT(byseq, &x->byseq, + net->xfrm.state_byseq + h, + x->xso.type); } x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; hrtimer_start(&x->mtimer, @@ -1185,6 +1319,18 @@ found: xfrm_hash_grow_check(net, x->bydst.next != NULL); spin_unlock_bh(&net->xfrm.xfrm_state_lock); } else { +#ifdef CONFIG_XFRM_OFFLOAD + struct xfrm_dev_offload *xso = &x->xso; + + if (xso->type == XFRM_DEV_OFFLOAD_PACKET) { + xso->dev->xfrmdev_ops->xdo_dev_state_delete(x); + xso->dir = 0; + netdev_put(xso->dev, &xso->dev_tracker); + xso->dev = NULL; + xso->real_dev = NULL; + xso->type = XFRM_DEV_OFFLOAD_UNSPECIFIED; + } +#endif x->km.state = XFRM_STATE_DEAD; to_put = x; x = NULL; @@ -1280,22 +1426,26 @@ static void __xfrm_state_insert(struct xfrm_state *x) h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr, x->props.reqid, x->props.family); - hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + XFRM_STATE_INSERT(bydst, &x->bydst, net->xfrm.state_bydst + h, + x->xso.type); h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family); - hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + XFRM_STATE_INSERT(bysrc, &x->bysrc, net->xfrm.state_bysrc + h, + x->xso.type); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); - hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + XFRM_STATE_INSERT(byspi, &x->byspi, net->xfrm.state_byspi + h, + x->xso.type); } if (x->km.seq) { h = xfrm_seq_hash(net, x->km.seq); - hlist_add_head_rcu(&x->byseq, net->xfrm.state_byseq + h); + XFRM_STATE_INSERT(byseq, &x->byseq, net->xfrm.state_byseq + h, + x->xso.type); } hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL_SOFT); @@ -1409,9 +1559,11 @@ static struct xfrm_state *__find_acq_core(struct net *net, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL_SOFT); list_add(&x->km.all, &net->xfrm.state_all); - hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + XFRM_STATE_INSERT(bydst, &x->bydst, net->xfrm.state_bydst + h, + x->xso.type); h = xfrm_src_hash(net, daddr, saddr, family); - hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + XFRM_STATE_INSERT(bysrc, &x->bysrc, net->xfrm.state_bysrc + h, + x->xso.type); net->xfrm.state_num++; @@ -1786,6 +1938,8 @@ EXPORT_SYMBOL(xfrm_state_update); int xfrm_state_check_expire(struct xfrm_state *x) { + xfrm_dev_state_update_curlft(x); + if (!x->curlft.use_time) x->curlft.use_time = ktime_get_real_seconds(); @@ -2017,7 +2171,7 @@ u32 xfrm_get_acqseq(void) } EXPORT_SYMBOL(xfrm_get_acqseq); -int verify_spi_info(u8 proto, u32 min, u32 max) +int verify_spi_info(u8 proto, u32 min, u32 max, struct netlink_ext_ack *extack) { switch (proto) { case IPPROTO_AH: @@ -2026,22 +2180,28 @@ int verify_spi_info(u8 proto, u32 min, u32 max) case IPPROTO_COMP: /* IPCOMP spi is 16-bits. */ - if (max >= 0x10000) + if (max >= 0x10000) { + NL_SET_ERR_MSG(extack, "IPCOMP SPI must be <= 65535"); return -EINVAL; + } break; default: + NL_SET_ERR_MSG(extack, "Invalid protocol, must be one of AH, ESP, IPCOMP"); return -EINVAL; } - if (min > max) + if (min > max) { + NL_SET_ERR_MSG(extack, "Invalid SPI range: min > max"); return -EINVAL; + } return 0; } EXPORT_SYMBOL(verify_spi_info); -int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) +int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high, + struct netlink_ext_ack *extack) { struct net *net = xs_net(x); unsigned int h; @@ -2053,8 +2213,10 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) u32 mark = x->mark.v & x->mark.m; spin_lock_bh(&x->lock); - if (x->km.state == XFRM_STATE_DEAD) + if (x->km.state == XFRM_STATE_DEAD) { + NL_SET_ERR_MSG(extack, "Target ACQUIRE is in DEAD state"); goto unlock; + } err = 0; if (x->id.spi) @@ -2065,6 +2227,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) if (minspi == maxspi) { x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family); if (x0) { + NL_SET_ERR_MSG(extack, "Requested SPI is already in use"); xfrm_state_put(x0); goto unlock; } @@ -2072,7 +2235,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) } else { u32 spi = 0; for (h = 0; h < high-low+1; h++) { - spi = low + prandom_u32_max(high - low + 1); + spi = get_random_u32_inclusive(low, high); x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family); if (x0 == NULL) { newspi = htonl(spi); @@ -2085,10 +2248,13 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) spin_lock_bh(&net->xfrm.xfrm_state_lock); x->id.spi = newspi; h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); - hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + XFRM_STATE_INSERT(byspi, &x->byspi, net->xfrm.state_byspi + h, + x->xso.type); spin_unlock_bh(&net->xfrm.xfrm_state_lock); err = 0; + } else { + NL_SET_ERR_MSG(extack, "No SPI available in the requested range"); } unlock: diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index e73f9efc54c1..cf5172d4ce68 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -515,7 +515,8 @@ static int attach_aead(struct xfrm_state *x, struct nlattr *rta, } static inline int xfrm_replay_verify_len(struct xfrm_replay_state_esn *replay_esn, - struct nlattr *rp) + struct nlattr *rp, + struct netlink_ext_ack *extack) { struct xfrm_replay_state_esn *up; unsigned int ulen; @@ -528,13 +529,25 @@ static inline int xfrm_replay_verify_len(struct xfrm_replay_state_esn *replay_es /* Check the overall length and the internal bitmap length to avoid * potential overflow. */ - if (nla_len(rp) < (int)ulen || - xfrm_replay_state_esn_len(replay_esn) != ulen || - replay_esn->bmp_len != up->bmp_len) + if (nla_len(rp) < (int)ulen) { + NL_SET_ERR_MSG(extack, "ESN attribute is too short"); return -EINVAL; + } + + if (xfrm_replay_state_esn_len(replay_esn) != ulen) { + NL_SET_ERR_MSG(extack, "New ESN size doesn't match the existing SA's ESN size"); + return -EINVAL; + } + + if (replay_esn->bmp_len != up->bmp_len) { + NL_SET_ERR_MSG(extack, "New ESN bitmap size doesn't match the existing SA's ESN bitmap"); + return -EINVAL; + } - if (up->replay_window > up->bmp_len * sizeof(__u32) * 8) + if (up->replay_window > up->bmp_len * sizeof(__u32) * 8) { + NL_SET_ERR_MSG(extack, "ESN replay window is longer than the bitmap"); return -EINVAL; + } return 0; } @@ -862,12 +875,12 @@ static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh, goto out; if (xfrm_state_kern(x)) { + NL_SET_ERR_MSG(extack, "SA is in use by tunnels"); err = -EPERM; goto out; } err = xfrm_state_delete(x); - if (err < 0) goto out; @@ -943,6 +956,8 @@ static int copy_user_offload(struct xfrm_dev_offload *xso, struct sk_buff *skb) xuo->ifindex = xso->dev->ifindex; if (xso->dir == XFRM_DEV_OFFLOAD_IN) xuo->flags = XFRM_OFFLOAD_INBOUND; + if (xso->type == XFRM_DEV_OFFLOAD_PACKET) + xuo->flags |= XFRM_OFFLOAD_PACKET; return 0; } @@ -1354,20 +1369,28 @@ static int xfrm_set_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh, if (attrs[XFRMA_SPD_IPV4_HTHRESH]) { struct nlattr *rta = attrs[XFRMA_SPD_IPV4_HTHRESH]; - if (nla_len(rta) < sizeof(*thresh4)) + if (nla_len(rta) < sizeof(*thresh4)) { + NL_SET_ERR_MSG(extack, "Invalid SPD_IPV4_HTHRESH attribute length"); return -EINVAL; + } thresh4 = nla_data(rta); - if (thresh4->lbits > 32 || thresh4->rbits > 32) + if (thresh4->lbits > 32 || thresh4->rbits > 32) { + NL_SET_ERR_MSG(extack, "Invalid hash threshold (must be <= 32 for IPv4)"); return -EINVAL; + } } if (attrs[XFRMA_SPD_IPV6_HTHRESH]) { struct nlattr *rta = attrs[XFRMA_SPD_IPV6_HTHRESH]; - if (nla_len(rta) < sizeof(*thresh6)) + if (nla_len(rta) < sizeof(*thresh6)) { + NL_SET_ERR_MSG(extack, "Invalid SPD_IPV6_HTHRESH attribute length"); return -EINVAL; + } thresh6 = nla_data(rta); - if (thresh6->lbits > 128 || thresh6->rbits > 128) + if (thresh6->lbits > 128 || thresh6->rbits > 128) { + NL_SET_ERR_MSG(extack, "Invalid hash threshold (must be <= 128 for IPv6)"); return -EINVAL; + } } if (thresh4 || thresh6) { @@ -1510,7 +1533,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, u32 if_id = 0; p = nlmsg_data(nlh); - err = verify_spi_info(p->info.id.proto, p->min, p->max); + err = verify_spi_info(p->info.id.proto, p->min, p->max, extack); if (err) goto out_noput; @@ -1538,10 +1561,12 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, &p->info.saddr, 1, family); err = -ENOENT; - if (x == NULL) + if (!x) { + NL_SET_ERR_MSG(extack, "Target ACQUIRE not found"); goto out_noput; + } - err = xfrm_alloc_spi(x, p->min, p->max); + err = xfrm_alloc_spi(x, p->min, p->max, extack); if (err) goto out; @@ -1867,6 +1892,15 @@ static struct xfrm_policy *xfrm_policy_construct(struct net *net, if (attrs[XFRMA_IF_ID]) xp->if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + /* configure the hardware if offload is requested */ + if (attrs[XFRMA_OFFLOAD_DEV]) { + err = xfrm_dev_policy_add(net, xp, + nla_data(attrs[XFRMA_OFFLOAD_DEV]), + p->dir, extack); + if (err) + goto error; + } + return xp; error: *errp = err; @@ -1906,6 +1940,7 @@ static int xfrm_add_policy(struct sk_buff *skb, struct nlmsghdr *nlh, xfrm_audit_policy_add(xp, err ? 0 : 1, true); if (err) { + xfrm_dev_policy_delete(xp); security_xfrm_policy_free(xp->security); kfree(xp); return err; @@ -2018,6 +2053,8 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr err = xfrm_mark_put(skb, &xp->mark); if (!err) err = xfrm_if_id_put(skb, xp->if_id); + if (!err && xp->xdo.dev) + err = copy_user_offload(&xp->xdo, skb); if (err) { nlmsg_cancel(skb, nlh); return err; @@ -2433,12 +2470,16 @@ static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh, struct nlattr *et = attrs[XFRMA_ETIMER_THRESH]; struct nlattr *rt = attrs[XFRMA_REPLAY_THRESH]; - if (!lt && !rp && !re && !et && !rt) + if (!lt && !rp && !re && !et && !rt) { + NL_SET_ERR_MSG(extack, "Missing required attribute for AE"); return err; + } /* pedantic mode - thou shalt sayeth replaceth */ - if (!(nlh->nlmsg_flags&NLM_F_REPLACE)) + if (!(nlh->nlmsg_flags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG(extack, "NLM_F_REPLACE flag is required"); return err; + } mark = xfrm_mark_get(attrs, &m); @@ -2446,10 +2487,12 @@ static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh, if (x == NULL) return -ESRCH; - if (x->km.state != XFRM_STATE_VALID) + if (x->km.state != XFRM_STATE_VALID) { + NL_SET_ERR_MSG(extack, "SA must be in VALID state"); goto out; + } - err = xfrm_replay_verify_len(x->replay_esn, re); + err = xfrm_replay_verify_len(x->replay_esn, re, extack); if (err) goto out; @@ -2584,8 +2627,11 @@ static int xfrm_add_sa_expire(struct sk_buff *skb, struct nlmsghdr *nlh, spin_lock_bh(&x->lock); err = -EINVAL; - if (x->km.state != XFRM_STATE_VALID) + if (x->km.state != XFRM_STATE_VALID) { + NL_SET_ERR_MSG(extack, "SA must be in VALID state"); goto out; + } + km_state_expired(x, ue->hard, nlh->nlmsg_pid); if (ue->hard) { @@ -2665,7 +2711,8 @@ nomem: #ifdef CONFIG_XFRM_MIGRATE static int copy_from_user_migrate(struct xfrm_migrate *ma, struct xfrm_kmaddress *k, - struct nlattr **attrs, int *num) + struct nlattr **attrs, int *num, + struct netlink_ext_ack *extack) { struct nlattr *rt = attrs[XFRMA_MIGRATE]; struct xfrm_user_migrate *um; @@ -2684,8 +2731,10 @@ static int copy_from_user_migrate(struct xfrm_migrate *ma, um = nla_data(rt); num_migrate = nla_len(rt) / sizeof(*um); - if (num_migrate <= 0 || num_migrate > XFRM_MAX_DEPTH) + if (num_migrate <= 0 || num_migrate > XFRM_MAX_DEPTH) { + NL_SET_ERR_MSG(extack, "Invalid number of SAs to migrate, must be 0 < num <= XFRM_MAX_DEPTH (6)"); return -EINVAL; + } for (i = 0; i < num_migrate; i++, um++, ma++) { memcpy(&ma->old_daddr, &um->old_daddr, sizeof(ma->old_daddr)); @@ -2718,8 +2767,10 @@ static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, struct xfrm_encap_tmpl *encap = NULL; u32 if_id = 0; - if (attrs[XFRMA_MIGRATE] == NULL) + if (!attrs[XFRMA_MIGRATE]) { + NL_SET_ERR_MSG(extack, "Missing required MIGRATE attribute"); return -EINVAL; + } kmp = attrs[XFRMA_KMADDRESS] ? &km : NULL; @@ -2727,7 +2778,7 @@ static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, if (err) return err; - err = copy_from_user_migrate((struct xfrm_migrate *)m, kmp, attrs, &n); + err = copy_from_user_migrate(m, kmp, attrs, &n, extack); if (err) return err; @@ -2744,7 +2795,8 @@ static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, if (attrs[XFRMA_IF_ID]) if_id = nla_get_u32(attrs[XFRMA_IF_ID]); - err = xfrm_migrate(&pi->sel, pi->dir, type, m, n, kmp, net, encap, if_id); + err = xfrm_migrate(&pi->sel, pi->dir, type, m, n, kmp, net, encap, + if_id, extack); kfree(encap); @@ -3341,6 +3393,8 @@ static int build_acquire(struct sk_buff *skb, struct xfrm_state *x, err = xfrm_mark_put(skb, &xp->mark); if (!err) err = xfrm_if_id_put(skb, xp->if_id); + if (!err && xp->xdo.dev) + err = copy_user_offload(&xp->xdo, skb); if (err) { nlmsg_cancel(skb, nlh); return err; @@ -3459,6 +3513,8 @@ static int build_polexpire(struct sk_buff *skb, struct xfrm_policy *xp, err = xfrm_mark_put(skb, &xp->mark); if (!err) err = xfrm_if_id_put(skb, xp->if_id); + if (!err && xp->xdo.dev) + err = copy_user_offload(&xp->xdo, skb); if (err) { nlmsg_cancel(skb, nlh); return err; @@ -3542,6 +3598,8 @@ static int xfrm_notify_policy(struct xfrm_policy *xp, int dir, const struct km_e err = xfrm_mark_put(skb, &xp->mark); if (!err) err = xfrm_if_id_put(skb, xp->if_id); + if (!err && xp->xdo.dev) + err = copy_user_offload(&xp->xdo, skb); if (err) goto out_free_skb; |